-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_entity_extractor.py
More file actions
262 lines (207 loc) · 11.3 KB
/
test_entity_extractor.py
File metadata and controls
262 lines (207 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Tests for the Entity Extractor (HCE Phase 3, Component 2)."""
from __future__ import annotations
import pytest
from hce_core import EntityGraph, NodeType, EdgeType
from entity_extractor import (
Entity,
EntityExtractor,
_default_ner,
_normalize_entity_id,
)
# ── Helpers ─────────────────────────────────────────────────────────────────
def _entity_texts(entities: list[Entity]) -> list[str]:
    """Collect every entity's ``text`` in lowercase, preserving input order."""
    lowered: list[str] = []
    for entity in entities:
        lowered.append(entity.text.lower())
    return lowered
def _entity_types(entities: list[Entity]) -> dict[str, NodeType]:
    """Build a lookup from each entity's lowercased text to its entity type.

    When several entities share the same lowercased text, the last one wins.
    """
    by_text: dict[str, NodeType] = {}
    for entity in entities:
        by_text[entity.text.lower()] = entity.entity_type
    return by_text
# ── _normalize_entity_id ────────────────────────────────────────────────────
class TestNormalizeEntityId:
    """Checks of ``_normalize_entity_id`` on simple string inputs."""

    def test_basic(self):
        """A two-word name maps to its lowercase, underscore-joined form."""
        assert _normalize_entity_id("John Smith") == "john_smith"

    def test_strips_whitespace(self):
        """Leading and trailing whitespace must not show up in the id."""
        assert _normalize_entity_id(" Hello World ") == "hello_world"

    def test_collapses_internal_spaces(self):
        """A run of internal spaces must yield a single underscore."""
        assert _normalize_entity_id("foo bar") == "foo_bar"
        # A single space cannot exercise collapsing, so also feed a run of
        # spaces.
        # NOTE(review): expectation is taken from this test's name; confirm
        # against the _normalize_entity_id implementation.
        assert _normalize_entity_id("foo   bar") == "foo_bar"
# ── PERSON extraction ──────────────────────────────────────────────────────
class TestPersonExtraction:
    """PERSON detection by the default NER function."""

    @staticmethod
    def _person_names(sentence):
        """Run ``_default_ner`` on *sentence* and return the PERSON texts."""
        return [
            found.text
            for found in _default_ner(sentence)
            if found.entity_type == NodeType.PERSON
        ]

    def test_basic_person(self):
        """A plain two-word name is reported verbatim as a PERSON."""
        names = self._person_names("I met John Smith at the conference")
        assert "John Smith" in names

    def test_title_prefix(self):
        """A name introduced by a title still appears in some PERSON text."""
        names = self._person_names("Please ask Dr. Alice Brown about it")
        assert any("Alice Brown" in name for name in names)

    def test_three_word_name(self):
        """A three-word name appears in some PERSON text."""
        names = self._person_names("I know Mary Jane Watson personally")
        assert any("Mary Jane Watson" in name for name in names)

    def test_sentence_start_filter(self):
        """A capitalized sentence opener must not be reported as a PERSON."""
        names = self._person_names(
            "The quick brown fox jumps over the lazy dog."
        )
        assert len(names) == 0
# ── CONCEPT extraction ─────────────────────────────────────────────────────
class TestConceptExtraction:
    """CONCEPT detection by the default NER function."""

    @staticmethod
    def _concept_texts(sentence):
        """Run ``_default_ner`` on *sentence* and return the CONCEPT texts."""
        return [
            found.text
            for found in _default_ner(sentence)
            if found.entity_type == NodeType.CONCEPT
        ]

    def test_quoted_concept(self):
        """A single-quoted phrase is reported as a CONCEPT (case-insensitive)."""
        found = self._concept_texts(
            "We discussed the concept of 'machine learning'"
        )
        assert "machine learning" in [text.lower() for text in found]

    def test_double_quoted(self):
        """A double-quoted phrase is reported as a CONCEPT (case-insensitive)."""
        found = self._concept_texts(
            'The idea of "neural networks" is fascinating'
        )
        assert "neural networks" in [text.lower() for text in found]

    def test_signal_preceded_concept(self):
        """A capitalized word after 'talked about' is reported as a CONCEPT."""
        found = self._concept_texts("She talked about Python for hours")
        assert "Python" in found

    def test_repeated_capitalized_word(self):
        """A capitalized word used in two sentences is reported as a CONCEPT."""
        found = self._concept_texts(
            "Rust is a language. Rust has a borrow checker."
        )
        assert "Rust" in found

    def test_long_quoted_string_ignored(self):
        """A 55-character quoted string must not come back as a CONCEPT."""
        long_quote = "a" * 55
        found = self._concept_texts(f'He said "{long_quote}" in the meeting')
        matches = [text for text in found if text == long_quote]
        assert len(matches) == 0
# ── EVENT extraction ───────────────────────────────────────────────────────
class TestEventExtraction:
    """EVENT detection by the default NER function."""

    @staticmethod
    def _event_texts(sentence):
        """Run ``_default_ner`` on *sentence*; return lowercased EVENT texts."""
        return [
            found.text.lower()
            for found in _default_ner(sentence)
            if found.entity_type == NodeType.EVENT
        ]

    def test_meeting_on_day(self):
        """'meeting on <day>' yields at least one EVENT mentioning the day."""
        found = self._event_texts("meeting on Monday at the office")
        assert len(found) >= 1
        assert any("monday" in text for text in found)

    def test_gerund_to_place(self):
        """'traveling to <place>' yields an EVENT mentioning the place."""
        found = self._event_texts("We are traveling to Japan next week")
        assert any("japan" in text for text in found)

    def test_meeting_at_place(self):
        """'meeting at <place>' yields an EVENT mentioning the place."""
        found = self._event_texts("meeting at Google was productive")
        assert any("google" in text for text in found)
# ── EntityExtractor.extract (deduplication) ────────────────────────────────
class TestExtractDeduplication:
    """Deduplication performed by ``EntityExtractor.extract``."""

    def test_dedup_same_entity(self):
        """An entity named twice in the text is returned only once."""
        extractor = EntityExtractor()
        found = extractor.extract(
            "John Smith met Alice Brown. Later John Smith called Alice Brown."
        )
        normalized = [_normalize_entity_id(item.text) for item in found]
        for expected_id in ("john_smith", "alice_brown"):
            assert normalized.count(expected_id) == 1
# ── EntityExtractor.update_graph ───────────────────────────────────────────
class TestUpdateGraph:
    """Behaviour of ``EntityExtractor.update_graph`` on an ``EntityGraph``."""

    def test_adds_nodes(self):
        """Updating from a sentence with one name adds a node for that name."""
        extractor, graph = EntityExtractor(), EntityGraph()
        added = extractor.update_graph(
            "I met John Smith at the conference", graph
        )
        assert len(added) >= 1
        assert graph.has_node("john_smith")

    def test_multiple_entities_create_relates_to_edges(self):
        """Several entities in one text produce nodes and at least one edge."""
        extractor, graph = EntityExtractor(), EntityGraph()
        added = extractor.update_graph(
            "Dr. Alice Brown discussed 'quantum computing' with Bob Carter",
            graph,
        )
        # Two or more nodes are expected from this sentence ...
        assert len(added) >= 2
        # ... and the graph must hold at least one edge afterwards.
        assert graph.edge_count >= 1

    def test_dedup_increments_mentions(self):
        """Seeing the same entity in two updates leaves its mentions at 2."""
        extractor, graph = EntityExtractor(), EntityGraph()
        for sentence in ("John Smith is here", "John Smith left the building"):
            extractor.update_graph(sentence, graph)
        john = graph.get_node("john_smith")
        assert john is not None
        assert john["metadata"]["mentions"] == 2

    def test_interaction_id_creates_links(self):
        """Entities get an outgoing RELATES_TO edge to an existing interaction."""
        extractor, graph = EntityExtractor(), EntityGraph()
        # The interaction node has to exist before the update runs.
        graph.add_node("turn_42", NodeType.EVENT, label="Turn 42")
        added = extractor.update_graph(
            "John Smith talked about 'deep learning'",
            graph,
            interaction_id="turn_42",
        )
        # Every returned node id must point at turn_42.
        for nid in added:
            outgoing = graph.get_neighbors(
                nid, edge_type=EdgeType.RELATES_TO, direction="out"
            )
            targets = [target for target, _ in outgoing]
            assert "turn_42" in targets, (
                f"Node {nid} should link to interaction turn_42"
            )

    def test_interaction_id_missing_no_error(self):
        """An interaction id absent from the graph must not cause an error."""
        extractor, graph = EntityExtractor(), EntityGraph()
        # The call itself is the check: it must complete without raising.
        added = extractor.update_graph(
            "John Smith is here", graph, interaction_id="nonexistent"
        )
        assert len(added) >= 1
# ── Custom NER function ────────────────────────────────────────────────────
class TestCustomNer:
    """Injection of a caller-supplied NER function into ``EntityExtractor``."""

    def test_custom_ner_is_used(self):
        """The extractor returns exactly what the supplied NER function yields."""

        def stub_ner(text: str) -> list[Entity]:
            # Always report one fixed CONCEPT spanning the whole input.
            return [Entity("custom_entity", NodeType.CONCEPT, 0, len(text))]

        found = EntityExtractor(ner_func=stub_ner).extract("any text at all")
        assert len(found) == 1
        only = found[0]
        assert only.text == "custom_entity"
        assert only.entity_type == NodeType.CONCEPT
# ── Empty text ──────────────────────────────────────────────────────────────
class TestEmptyText:
    """Extraction from empty or entity-free input."""

    def test_empty_string(self):
        """An empty string produces an empty list."""
        found = EntityExtractor().extract("")
        assert found == []

    def test_no_entities_in_lowercase(self):
        """All-lowercase filler text must not crash the extractor."""
        found = EntityExtractor().extract("nothing special here at all")
        # Only the return type is pinned; the list may be empty or short.
        assert isinstance(found, list)
# ── Co-occurrence edges ────────────────────────────────────────────────────
class TestCoOccurrence:
    """RELATES_TO edges between entities that appear in the same text."""

    def test_alice_bob_python(self):
        """Every pair of entities from one sentence is linked by RELATES_TO."""
        extractor, graph = EntityExtractor(), EntityGraph()
        node_ids = extractor.update_graph(
            "Alice Carter told Bob Davis about Python", graph
        )

        # Alice Carter, Bob Davis and Python are the minimum expected.
        assert len(node_ids) >= 3

        def related_ids(node):
            """Ids reachable from *node* over RELATES_TO, in either direction."""
            outgoing = graph.get_neighbors(
                node, edge_type=EdgeType.RELATES_TO, direction="out"
            )
            incoming = graph.get_neighbors(
                node, edge_type=EdgeType.RELATES_TO, direction="in"
            )
            return {other for other, _ in outgoing} | {
                other for other, _ in incoming
            }

        # Walk each unordered pair once; an edge in either direction suffices.
        for position, a in enumerate(node_ids):
            for b in node_ids[position + 1:]:
                if a != b:
                    assert b in related_ids(a), (
                        f"Expected RELATES_TO edge between {a} and {b}"
                    )