Skip to content

Commit 6470423

Browse files
authored
Edge extraction efficiency (#1140)
* Add node splitting when a large number of nodes are extracted * update * add tests * update * update * update * update * update * update * update
1 parent 79712af commit 6470423

File tree

5 files changed

+553
-41
lines changed

5 files changed

+553
-41
lines changed

examples/podcast/podcast_runner.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,7 @@ class IsPresidentOf(BaseModel):
7777

7878
async def main(use_bulk: bool = False):
7979
setup_logging()
80-
client = Graphiti(
81-
neo4j_uri,
82-
neo4j_user,
83-
neo4j_password,
84-
)
80+
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
8581
await clear_data(client.driver)
8682
await client.build_indices_and_constraints()
8783
messages = parse_podcast_messages()

graphiti_core/utils/content_chunking.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
import json
1818
import logging
19+
import random
1920
import re
21+
from itertools import combinations
22+
from math import comb
23+
from typing import TypeVar
2024

2125
from graphiti_core.helpers import (
2226
CHUNK_DENSITY_THRESHOLD,
@@ -700,3 +704,123 @@ def _chunk_by_lines(
700704
chunks.append('\n'.join(current_lines))
701705

702706
return chunks if chunks else [content]
707+
708+
709+
T = TypeVar('T')
710+
711+
MAX_COMBINATIONS_TO_EVALUATE = 1000
712+
713+
714+
def _random_combination(n: int, k: int) -> tuple[int, ...]:
715+
"""Generate a random combination of k items from range(n)."""
716+
return tuple(sorted(random.sample(range(n), k)))
717+
718+
719+
def generate_covering_chunks(items: list[T], k: int) -> list[tuple[list[T], list[int]]]:
720+
"""Generate chunks of items that cover all pairs using a greedy approach.
721+
722+
Based on the Handshake Flights Problem / Covering Design problem.
723+
Each chunk of K items covers C(K,2) = K(K-1)/2 pairs. We greedily select
724+
chunks to maximize coverage of uncovered pairs, minimizing the total number
725+
of chunks needed to ensure every pair of items appears in at least one chunk.
726+
727+
For large inputs where C(n,k) > MAX_COMBINATIONS_TO_EVALUATE, random sampling
728+
is used instead of exhaustive search to maintain performance.
729+
730+
Lower bound (Schönheim): F >= ceil(N/K * ceil((N-1)/(K-1)))
731+
732+
Args:
733+
items: List of items to partition into covering chunks
734+
k: Maximum number of items per chunk
735+
736+
Returns:
737+
List of tuples (chunk_items, global_indices) where global_indices maps
738+
each position in chunk_items to its index in the original items list.
739+
"""
740+
n = len(items)
741+
if n <= k:
742+
return [(items, list(range(n)))]
743+
744+
# Track uncovered pairs using frozensets of indices
745+
uncovered_pairs: set[frozenset[int]] = {
746+
frozenset([i, j]) for i in range(n) for j in range(i + 1, n)
747+
}
748+
749+
chunks: list[tuple[list[T], list[int]]] = []
750+
751+
# Determine if we need to sample or can enumerate all combinations
752+
total_combinations = comb(n, k)
753+
use_sampling = total_combinations > MAX_COMBINATIONS_TO_EVALUATE
754+
755+
while uncovered_pairs:
756+
# Greedy selection: find the chunk that covers the most uncovered pairs
757+
best_chunk_indices: tuple[int, ...] | None = None
758+
best_covered_count = 0
759+
760+
if use_sampling:
761+
# Sample random combinations when there are too many to enumerate
762+
seen_combinations: set[tuple[int, ...]] = set()
763+
# Limit total attempts (including duplicates) to prevent infinite loops
764+
max_total_attempts = MAX_COMBINATIONS_TO_EVALUATE * 3
765+
total_attempts = 0
766+
samples_evaluated = 0
767+
while samples_evaluated < MAX_COMBINATIONS_TO_EVALUATE:
768+
total_attempts += 1
769+
if total_attempts > max_total_attempts:
770+
# Too many total attempts, break to avoid infinite loop
771+
break
772+
chunk_indices = _random_combination(n, k)
773+
if chunk_indices in seen_combinations:
774+
continue
775+
seen_combinations.add(chunk_indices)
776+
samples_evaluated += 1
777+
778+
# Count how many uncovered pairs this chunk covers
779+
covered_count = sum(
780+
1
781+
for i, idx_i in enumerate(chunk_indices)
782+
for idx_j in chunk_indices[i + 1 :]
783+
if frozenset([idx_i, idx_j]) in uncovered_pairs
784+
)
785+
786+
if covered_count > best_covered_count:
787+
best_covered_count = covered_count
788+
best_chunk_indices = chunk_indices
789+
else:
790+
# Enumerate all combinations when feasible
791+
for chunk_indices in combinations(range(n), k):
792+
# Count how many uncovered pairs this chunk covers
793+
covered_count = sum(
794+
1
795+
for i, idx_i in enumerate(chunk_indices)
796+
for idx_j in chunk_indices[i + 1 :]
797+
if frozenset([idx_i, idx_j]) in uncovered_pairs
798+
)
799+
800+
if covered_count > best_covered_count:
801+
best_covered_count = covered_count
802+
best_chunk_indices = chunk_indices
803+
804+
if best_chunk_indices is None or best_covered_count == 0:
805+
# Greedy search couldn't find a chunk covering uncovered pairs.
806+
# This can happen with random sampling. Fall back to creating
807+
# small chunks that directly cover remaining pairs.
808+
break
809+
810+
# Mark pairs in this chunk as covered
811+
for i, idx_i in enumerate(best_chunk_indices):
812+
for idx_j in best_chunk_indices[i + 1 :]:
813+
uncovered_pairs.discard(frozenset([idx_i, idx_j]))
814+
815+
chunk_items = [items[idx] for idx in best_chunk_indices]
816+
chunks.append((chunk_items, list(best_chunk_indices)))
817+
818+
# Handle any remaining uncovered pairs that the greedy algorithm missed.
819+
# This can happen when random sampling fails to find covering chunks.
820+
# Create minimal chunks (size 2) to guarantee all pairs are covered.
821+
for pair in uncovered_pairs:
822+
pair_indices = sorted(pair)
823+
chunk_items = [items[idx] for idx in pair_indices]
824+
chunks.append((chunk_items, pair_indices))
825+
826+
return chunks

graphiti_core/utils/maintenance/edge_operations.py

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,18 @@
3535
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
3636
from graphiti_core.prompts import prompt_library
3737
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
38+
from graphiti_core.prompts.extract_edges import Edge as ExtractedEdge
3839
from graphiti_core.prompts.extract_edges import ExtractedEdges
3940
from graphiti_core.search.search import search
4041
from graphiti_core.search.search_config import SearchResults
4142
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
4243
from graphiti_core.search.search_filters import SearchFilters
44+
from graphiti_core.utils.content_chunking import generate_covering_chunks
4345
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
4446
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
4547

4648
DEFAULT_EDGE_NAME = 'RELATES_TO'
49+
MAX_NODES = 15
4750

4851
logger = logging.getLogger(__name__)
4952

@@ -120,27 +123,110 @@ async def extract_edges(
120123
else []
121124
)
122125

123-
# Prepare context for LLM
124-
context = {
125-
'episode_content': episode.content,
126-
'nodes': [
127-
{'id': idx, 'name': node.name, 'entity_types': node.labels}
128-
for idx, node in enumerate(nodes)
129-
],
130-
'previous_episodes': [ep.content for ep in previous_episodes],
131-
'reference_time': episode.valid_at,
132-
'edge_types': edge_types_context,
133-
'custom_extraction_instructions': custom_extraction_instructions or '',
134-
}
126+
# Generate covering chunks to ensure all node pairs are processed.
127+
# Uses a greedy approach based on the Handshake Flights Problem.
128+
covering_chunks = generate_covering_chunks(nodes, MAX_NODES)
129+
130+
# Pre-assign pairs to chunks to avoid duplicate edge extraction.
131+
# Each pair is assigned to the first chunk that contains it.
132+
processed_pairs: set[frozenset[int]] = set()
133+
chunk_assigned_pairs: list[set[frozenset[int]]] = []
134+
135+
for _, global_indices in covering_chunks:
136+
assigned_pairs: set[frozenset[int]] = set()
137+
for i, idx_i in enumerate(global_indices):
138+
for idx_j in global_indices[i + 1 :]:
139+
pair = frozenset([idx_i, idx_j])
140+
if pair not in processed_pairs:
141+
processed_pairs.add(pair)
142+
assigned_pairs.add(pair)
143+
chunk_assigned_pairs.append(assigned_pairs)
144+
145+
async def extract_edges_for_chunk(
146+
chunk: list[EntityNode],
147+
global_indices: list[int],
148+
assigned_pairs: set[frozenset[int]],
149+
) -> list[ExtractedEdge]:
150+
# Skip chunks with no assigned pairs (all pairs already processed)
151+
if not assigned_pairs:
152+
return []
153+
154+
# Prepare context for LLM
155+
context = {
156+
'episode_content': episode.content,
157+
'nodes': [
158+
{'id': idx, 'name': node.name, 'entity_types': node.labels}
159+
for idx, node in enumerate(chunk)
160+
],
161+
'previous_episodes': [ep.content for ep in previous_episodes],
162+
'reference_time': episode.valid_at,
163+
'edge_types': edge_types_context,
164+
'custom_extraction_instructions': custom_extraction_instructions or '',
165+
}
135166

136-
llm_response = await llm_client.generate_response(
137-
prompt_library.extract_edges.edge(context),
138-
response_model=ExtractedEdges,
139-
max_tokens=extract_edges_max_tokens,
140-
group_id=group_id,
141-
prompt_name='extract_edges.edge',
167+
llm_response = await llm_client.generate_response(
168+
prompt_library.extract_edges.edge(context),
169+
response_model=ExtractedEdges,
170+
max_tokens=extract_edges_max_tokens,
171+
group_id=group_id,
172+
prompt_name='extract_edges.edge',
173+
)
174+
chunk_edges_data = ExtractedEdges(**llm_response).edges
175+
176+
# Map chunk-local indices to global indices in the original nodes list
177+
# Note: global_indices are guaranteed valid by generate_covering_chunks,
178+
# but LLM-returned local indices need validation
179+
valid_edges: list[ExtractedEdge] = []
180+
chunk_size = len(global_indices)
181+
182+
for edge_data in chunk_edges_data:
183+
source_local_idx = edge_data.source_entity_id
184+
target_local_idx = edge_data.target_entity_id
185+
186+
# Validate LLM-returned indices are within chunk bounds
187+
if not (0 <= source_local_idx < chunk_size):
188+
logger.warning(
189+
f'Source index {source_local_idx} out of bounds for chunk of size '
190+
f'{chunk_size} in edge {edge_data.relation_type}'
191+
)
192+
continue
193+
194+
if not (0 <= target_local_idx < chunk_size):
195+
logger.warning(
196+
f'Target index {target_local_idx} out of bounds for chunk of size '
197+
f'{chunk_size} in edge {edge_data.relation_type}'
198+
)
199+
continue
200+
201+
# Map to global indices (guaranteed valid by generate_covering_chunks)
202+
mapped_source = global_indices[source_local_idx]
203+
mapped_target = global_indices[target_local_idx]
204+
edge_data.source_entity_id = mapped_source
205+
edge_data.target_entity_id = mapped_target
206+
207+
# Only include edges for pairs assigned to this chunk
208+
edge_pair = frozenset([mapped_source, mapped_target])
209+
if edge_pair in assigned_pairs:
210+
valid_edges.append(edge_data)
211+
212+
return valid_edges
213+
214+
# Extract edges from all chunks in parallel
215+
chunk_results: list[list[ExtractedEdge]] = list(
216+
await semaphore_gather(
217+
*[
218+
extract_edges_for_chunk(chunk, global_indices, assigned_pairs)
219+
for (chunk, global_indices), assigned_pairs in zip(
220+
covering_chunks, chunk_assigned_pairs, strict=True
221+
)
222+
]
223+
)
142224
)
143-
edges_data = ExtractedEdges(**llm_response).edges
225+
226+
# Combine results from all chunks
227+
edges_data: list[ExtractedEdge] = []
228+
for chunk_edges in chunk_results:
229+
edges_data.extend(chunk_edges)
144230

145231
end = time()
146232
logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
@@ -161,22 +247,9 @@ async def extract_edges(
161247
if not edge_data.fact.strip():
162248
continue
163249

164-
source_node_idx = edge_data.source_entity_id
165-
target_node_idx = edge_data.target_entity_id
166-
167-
if len(nodes) == 0:
168-
logger.warning('No entities provided for edge extraction')
169-
continue
170-
171-
if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
172-
logger.warning(
173-
f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
174-
f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
175-
f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
176-
)
177-
continue
178-
source_node_uuid = nodes[source_node_idx].uuid
179-
target_node_uuid = nodes[target_node_idx].uuid
250+
# Indices already validated in extract_edges_for_chunk
251+
source_node_uuid = nodes[edge_data.source_entity_id].uuid
252+
target_node_uuid = nodes[edge_data.target_entity_id].uuid
180253

181254
if valid_at:
182255
try:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "graphiti-core"
33
description = "A temporal graph building library"
4-
version = "0.25.4"
4+
version = "0.25.5"
55
authors = [
66
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
77
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

0 commit comments

Comments
 (0)