Merged
Commits
34 commits
3e121a4
Fix #2191: sanitize embedding inputs before API call
Aroool Feb 28, 2026
351f4e0
Address CodeRabbit feedback for #2191
Aroool Feb 28, 2026
fbe94d4
Merge branch 'dev' into fix-2191-embedding-sanitize
Aroool Mar 15, 2026
cc9452a
chore: update versions
dexters1 Mar 27, 2026
4b38042
chore: format code
dexters1 Mar 27, 2026
f15d41f
feat: add embedding sanitization
dexters1 Mar 27, 2026
1e42ea0
test: update unit test
dexters1 Mar 27, 2026
182ba1b
Merge branch 'release-candidate-v0.5.6' into aroool-sanitize-embedding
dexters1 Mar 28, 2026
6cee1e7
refactor: sanitize text input for all embedding models
dexters1 Mar 28, 2026
f38c8fd
Merge branch 'aroool-sanitize-embedding' of github.com:topoteretes/co…
dexters1 Mar 28, 2026
67675b8
chore: add todo comment
dexters1 Mar 28, 2026
b2269f6
chore: add comments to Ollama embedding
dexters1 Mar 28, 2026
bf302c1
chore: update todo
dexters1 Mar 28, 2026
32d0903
fix: resolve issue with conditional auth test
dexters1 Mar 28, 2026
0839f4d
refactor: comment out unit tests
dexters1 Mar 28, 2026
7f2976f
Aroool sanitize embedding (#2508)
dexters1 Mar 28, 2026
9c6d617
refactor: return removed test
dexters1 Mar 28, 2026
95cfc36
feat: add automigrate for LanceDB
dexters1 Mar 30, 2026
69d23bc
test: add schema migration test
dexters1 Mar 30, 2026
8b9630d
fix: resolve conditional auth test issue
dexters1 Mar 30, 2026
8bd3590
feat: adds back properties that are missing
hajdul88 Mar 30, 2026
c630bd8
feat: only expected properties assert
hajdul88 Mar 30, 2026
694814d
Revert "feat: only expected properties assert"
hajdul88 Mar 30, 2026
4da584c
Merge branch 'release-candidate-v0.5.6' into automigrate-lancedb-rows
dexters1 Mar 30, 2026
a612c79
refactor: remove unnecessary code
dexters1 Mar 30, 2026
4fa0257
Merge branch 'dev' into release-candidate-v0.5.6
dexters1 Mar 30, 2026
f3d0c00
Merge branch 'release-candidate-v0.5.6' into automigrate-lancedb-rows
dexters1 Mar 30, 2026
611df62
feat: add automigrate for LanceDB (#2520)
Vasilije1990 Mar 30, 2026
e517bf5
chore: update lock files
dexters1 Mar 30, 2026
37252c4
Merge branch 'release-candidate-v0.5.6' of github.com:topoteretes/cog…
dexters1 Mar 30, 2026
f481842
Merge branch 'dev' into release-candidate-v0.5.6
dexters1 Mar 30, 2026
9bb69cf
chore: ruff format
dexters1 Mar 30, 2026
ac87032
refactor: return inherited field
dexters1 Mar 30, 2026
c506d3a
chore: reduce lanceDB version for uv lock
dexters1 Mar 30, 2026
33 changes: 33 additions & 0 deletions .github/workflows/e2e_tests.yml
@@ -1278,3 +1278,36 @@ jobs:
CACHE_BACKEND: 'redis'
CACHE_HOST: ${{ inputs.ci-image != '' && 'redis' || 'localhost' }}
run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO

  run_conditional_auth_test:
    name: Conditional Authentication Test
    runs-on: ubuntu-latest
    container: ${{ inputs.ci-image != '' && fromJSON(format('{{"image":"{0}","credentials":{{"username":"{1}","password":"{2}"}}}}', inputs.ci-image, github.actor, github.token)) || null }}
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run Conditional Authentication Test
        env:
          ENV: 'dev'
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_DIMENSIONS: 300
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          ENABLE_BACKEND_ACCESS_CONTROL: "false"
        run: uv run python ./cognee/tests/api/test_conditional_authentication_endpoints.py
2 changes: 1 addition & 1 deletion cognee-mcp/pyproject.toml
@@ -8,7 +8,7 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
#"cognee[postgres,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
-    "cognee[postgres,docs,neo4j]==0.5.4",
+    "cognee[postgres,docs,neo4j]==0.5.5",
⚠️ Potential issue | 🔴 Critical

Critical: Pin to cognee==0.5.6 to match the main package version.

This line pins cognee to version 0.5.5, but the main package in pyproject.toml line 4 is being bumped to 0.5.6 in this same PR. This creates a version skew where:

  • The cognee-mcp package will fetch cognee 0.5.5 from PyPI
  • The main repository is releasing version 0.5.6
  • Users installing cognee-mcp after the v0.5.6 release will get an outdated cognee dependency

For a consistent release, this should be updated to cognee[postgres,docs,neo4j]==0.5.6.

📦 Proposed fix to align versions
-    "cognee[postgres,docs,neo4j]==0.5.5",
+    "cognee[postgres,docs,neo4j]==0.5.6",

After making this change, regenerate the lock file:

cd cognee-mcp && uv lock
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee-mcp/pyproject.toml` at line 11, update the pinned dependency string
"cognee[postgres,docs,neo4j]==0.5.5" to "cognee[postgres,docs,neo4j]==0.5.6" in
the pyproject.toml for cognee-mcp to match the main package bump, then
regenerate the lock file (e.g., run `uv lock` in the cognee-mcp directory) so
the lockfile reflects the updated version.

"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",
"uv>=0.6.3,<1.0.0",
2,732 changes: 1,421 additions & 1,311 deletions cognee-mcp/uv.lock

Large diffs are not rendered by default.

@@ -26,6 +26,10 @@
TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
from cognee.infrastructure.databases.vector.embeddings.utils import (
sanitize_embedding_text_inputs,
handle_embedding_response,
)

litellm.set_verbose = False
logger = get_logger("FastembedEmbeddingEngine")
@@ -101,18 +105,19 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:
- List[List[float]]: A list of embeddings, where each embedding is a list of floats
representing the vector form of the input text.
"""
sanitized_text_input = sanitize_embedding_text_inputs(text)
try:
if self.mock:
-                return [[0.0] * self.dimensions for _ in text]
+                return [[0.0] * self.dimensions for _ in sanitized_text_input]
else:
async with embedding_rate_limiter_context_manager():
embeddings = self.embedding_model.embed(
-                        text,
+                        sanitized_text_input,
batch_size=len(text),
parallel=None,
)

-                return list(embeddings)
+                embeddings = list(embeddings)
+                return handle_embedding_response(text, embeddings, self.dimensions)

except Exception as error:
logger.error(f"Embedding error in FastembedEmbeddingEngine: {str(error)}")
Expand Up @@ -28,6 +28,10 @@
TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
from cognee.infrastructure.databases.vector.embeddings.utils import (
sanitize_embedding_text_inputs,
handle_embedding_response,
)

litellm.set_verbose = False
logger = get_logger("LiteLLMEmbeddingEngine")
@@ -123,15 +127,20 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:

- List[List[float]]: A list of vectors representing the embedded texts.
"""

sanitized_text_input = sanitize_embedding_text_inputs(text)

try:
if self.mock:
-                response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
+                response = {
+                    "data": [{"embedding": [0.0] * self.dimensions} for _ in sanitized_text_input]
+                }
return [data["embedding"] for data in response["data"]]
else:
async with embedding_rate_limiter_context_manager():
embedding_kwargs = {
"model": self.model,
-                        "input": text,
+                        "input": sanitized_text_input,
"api_key": self.api_key,
"api_base": self.endpoint,
"api_version": self.api_version,
@@ -146,7 +155,8 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:
timeout=30.0,
)

-                    return [data["embedding"] for data in response.data]
+                    embedding_response = [data["embedding"] for data in response.data]
+                    return handle_embedding_response(text, embedding_response, self.dimensions)

except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list) and len(text) > 1:
@@ -20,6 +20,10 @@
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
from cognee.shared.utils import create_secure_ssl_context
from cognee.infrastructure.databases.vector.embeddings.utils import (
sanitize_embedding_text_inputs,
handle_embedding_response,
)

logger = get_logger("OllamaEmbeddingEngine")

@@ -90,15 +94,18 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:

- List[List[float]]: A list of embedding vectors corresponding to the text prompts.
"""
sanitized_text_input = sanitize_embedding_text_inputs(text)
if self.mock:
-            return [[0.0] * self.dimensions for _ in text]
+            return [[0.0] * self.dimensions for _ in sanitized_text_input]

# Handle case when a single string is passed instead of a list
-        if not isinstance(text, list):
-            text = [text]
+        if not isinstance(sanitized_text_input, list):
+            text = [sanitized_text_input]

-        embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
-        return embeddings
+        embeddings = await asyncio.gather(
+            *[self._get_embedding(prompt) for prompt in sanitized_text_input]
+        )
+        return handle_embedding_response(text, embeddings, self.dimensions)

def _truncate_text_to_token_limit(self, text: str, max_tokens: int = 2048) -> str:
"""
@@ -110,6 +117,12 @@ def _truncate_text_to_token_limit(self, text: str, max_tokens: int = 2048) -> st
logger.warning(
f"Text exceeds character limit ({len(text)} > {char_limit}), truncating..."
)
# TODO: Refactor to better handle truncation, handle it the same as it is handled in LiteLLMEmbeddingEngine
# when the ContextWindowExceededError happens.
# Also max_tokens is never provided to function call so it will always default to 2048, we should make
# it so that it is provided based on the model's context length.
# The char_limit is not a good estimate based on the average number of characters per token, and
# actual value should be based on actual token count using the tokenizer or when the ContextWindowExceededError happens.
return text[:char_limit]
return text

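The TODO above asks for truncation based on actual token count rather than a characters-per-token estimate. A minimal sketch of that approach, using a hypothetical stand-in tokenizer (the real engine would pass its own tokenizer, e.g. a tiktoken-backed one; this is not the PR's implementation):

```python
class WhitespaceTokenizer:
    # Hypothetical stand-in for a real tokenizer (e.g. tiktoken);
    # treats each whitespace-separated word as one token.
    def encode(self, text: str) -> list[str]:
        return text.split(" ")

    def decode(self, tokens: list[str]) -> str:
        return " ".join(tokens)


def truncate_to_token_limit(text: str, tokenizer, max_tokens: int = 2048) -> str:
    # Truncate on the actual token count instead of a character
    # heuristic, as the TODO proposes.
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return tokenizer.decode(tokens[:max_tokens])


tok = WhitespaceTokenizer()
print(truncate_to_token_limit("a b c d e", tok, max_tokens=3))  # a b c
```

With a model-aware tokenizer, `max_tokens` would come from the model's context length rather than a fixed default.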
51 changes: 51 additions & 0 deletions cognee/infrastructure/databases/vector/embeddings/utils.py
@@ -0,0 +1,51 @@
from typing import List, Union
from cognee.shared.logging_utils import setup_logging

logger = setup_logging()


def is_embeddable(s: str) -> bool:
    """
    Check if input string is embeddable, if not it will be replaced with a dummy value to prevent API errors.
    Empty strings and a string with only a space character are not embeddable.
    If input string contains at least one alphanumeric character, it is considered embeddable.
    """
    if not isinstance(s, str):
        return False
    # Strip whitespace to check if the string is empty or only contains spaces
    s = s.strip()
    if len(s) >= 1:
        return True
Comment on lines +8 to +18
⚠️ Potential issue | 🟡 Minor

Align the docstring with the actual embeddable rule.

The implementation accepts any non-whitespace string, but Lines 10-11 say the input must contain an alphanumeric character. That already changes how "(" is classified in the new test, so callers currently have two different contracts.

✏️ Suggested docstring fix
-    Empty strings and a string with only a space character are not embeddable.
-    If input string contains at least one alphanumeric character, it is considered embeddable.
+    Empty or all-whitespace strings are not embeddable.
+    Any non-whitespace string is considered embeddable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee/infrastructure/databases/vector/embeddings/utils.py` around lines 8 -
18, The docstring incorrectly states an alphanumeric requirement but the
implementation simply treats any non-whitespace string as embeddable; update the
docstring to reflect the actual rule (i.e., “any non-empty string after
s.strip() is embeddable”) or change the implementation to enforce the
alphanumeric rule — in this case, modify the docstring above the function that
uses parameter s (and mentions s.strip()) so it accurately describes that the
function returns True for any string with length >= 1 after stripping
whitespace.

    logger.debug(
        "Input string was not embeddable. Skipping embedding and using dummy value instead."
    )
    return False


def sanitize_embedding_text_inputs(text: Union[str, List[str]]) -> List[str]:
    """
    Transform invalid/empty inputs into a safe dummy to prevent API 422 embedding errors while
    keeping list length consistent.
    """
    # Ensure we are working with a list
    text_list = [text] if isinstance(text, str) else text
    dummy_value = "."

    return [t if is_embeddable(t) else dummy_value for t in text_list]


def handle_embedding_response(
    original_texts: Union[List[str], str], embeddings: List[List[float]], dimensions: int
) -> List[List[float]]:
    """
    Compare the original input strings against the results.
    If the original string was 'junk' that was not embeddable, overwrite its vector with zeros.
    """
    if isinstance(original_texts, str):
        original_texts = [original_texts]

    zero_vector = [0.0] * dimensions
    return [
        embeddings[i] if is_embeddable(original_texts[i]) else zero_vector
        for i in range(len(original_texts))
    ]
Comment on lines +37 to +51
⚠️ Potential issue | 🟠 Major

Don't assume dimensions is always set.

dimensions is optional in the engine constructors, but this helper now unconditionally does [0.0] * dimensions. That turns a previously valid “use provider default size” configuration into a TypeError on every response. Please resolve the zero-vector length from embeddings when dimensions is None, and fail fast if the provider returns a different number of vectors than inputs.

🛠️ Proposed fix
-from typing import List, Union
+from typing import List, Optional, Union
@@
 def handle_embedding_response(
-    original_texts: Union[List[str], str], embeddings: List[List[float]], dimensions: int
+    original_texts: Union[List[str], str],
+    embeddings: List[List[float]],
+    dimensions: Optional[int],
 ) -> List[List[float]]:
@@
     if isinstance(original_texts, str):
         original_texts = [original_texts]
+    if len(embeddings) != len(original_texts):
+        raise ValueError(
+            f"Expected {len(original_texts)} embeddings, received {len(embeddings)}."
+        )
 
-    zero_vector = [0.0] * dimensions
+    vector_size = dimensions if dimensions is not None else len(embeddings[0]) if embeddings else 0
     return [
-        embeddings[i] if is_embeddable(original_texts[i]) else zero_vector
+        embeddings[i] if is_embeddable(original_texts[i]) else [0.0] * vector_size
         for i in range(len(original_texts))
     ]

As per coding guidelines, "Prefer explicit, structured error handling in Python code."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee/infrastructure/databases/vector/embeddings/utils.py` around lines 37 -
51, handle_embedding_response currently assumes dimensions is provided and
blindly does [0.0] * dimensions; change it to derive the zero-vector length from
the returned embeddings when dimensions is None and validate input/output
counts: if dimensions is None and embeddings is non-empty, set zero_len =
len(embeddings[0]); if embeddings is empty and dimensions is None raise a
ValueError; also fail fast by raising a ValueError if len(embeddings) !=
len(original_texts); then build zero_vector = [0.0] * zero_len and continue
using is_embeddable(original_texts[i]) to choose embeddings[i] or zero_vector.
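Taken together, the sanitize-then-zero-out contract being reviewed can be sketched in a few lines. This is a standalone re-implementation for illustration (not the PR's code), assuming the rule the implementation actually enforces: any non-whitespace string is embeddable.

```python
from typing import List, Union


def is_embeddable(s) -> bool:
    # Any non-whitespace string counts as embeddable.
    return isinstance(s, str) and len(s.strip()) >= 1


def sanitize_embedding_text_inputs(text: Union[str, List[str]]) -> List[str]:
    # Replace non-embeddable entries with a dummy "." so the provider
    # never receives an empty input (avoiding 422 errors) while the
    # list length stays stable.
    text_list = [text] if isinstance(text, str) else text
    return [t if is_embeddable(t) else "." for t in text_list]


def handle_embedding_response(original_texts, embeddings, dimensions):
    # Zero out vectors whose original input was junk, so dummy
    # embeddings never pollute similarity search.
    if isinstance(original_texts, str):
        original_texts = [original_texts]
    zero_vector = [0.0] * dimensions
    return [
        emb if is_embeddable(txt) else zero_vector
        for txt, emb in zip(original_texts, embeddings)
    ]


texts = ["hello", "", "   "]
print(sanitize_embedding_text_inputs(texts))  # ['hello', '.', '.']
fake_vectors = [[1.0, 2.0], [9.0, 9.0], [9.0, 9.0]]
print(handle_embedding_response(texts, fake_vectors, 2))
# [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]
```

The reviewer's suggested hardening (deriving the zero-vector length from the response when `dimensions` is unset, and failing fast on a count mismatch) would slot into `handle_embedding_response` before the final comprehension.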

100 changes: 96 additions & 4 deletions cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -10,8 +10,10 @@
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage
-from cognee.modules.storage.utils import copy_model, get_own_properties
+from cognee.modules.storage.utils import copy_model
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.pgvector.serialize_data import serialize_data
from cognee.shared.logging_utils import get_logger

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
@@ -24,6 +26,8 @@
COGNEE_VECTOR_RESULT_COUNT,
)

logger = get_logger("LanceDBAdapter")


class IndexSchema(DataPoint):
"""
@@ -187,8 +191,10 @@ class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
payload: PayloadSchema

def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
-            properties = get_own_properties(data_point)
-            properties["id"] = str(properties["id"])
+            payload_model = self.get_data_point_schema(type(data_point))
+            properties = payload_model.model_validate(
+                serialize_data(data_point.model_dump())
+            ).model_dump()

return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
id=str(data_point.id),
@@ -203,14 +209,98 @@ def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> Lance

lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())

try:
async with self.VECTOR_DB_LOCK:
await (
collection.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(lance_data_points)
)
except (ValueError, OSError, RuntimeError) as e:
if "not found in target schema" not in str(e):
raise
logger.warning(
"Schema mismatch detected for collection '%s', migrating table: %s",
collection_name,
e,
)
await self._migrate_collection_schema(
collection_name, collection, payload_schema, lance_data_points
)

async def _migrate_collection_schema(
self,
collection_name: str,
old_collection,
payload_schema: type,
new_lance_data_points: list,
):
"""Migrate a LanceDB table to a new schema, preserving existing data."""
rows = (await old_collection.to_arrow()).to_pylist()

vector_size = self.embedding_engine.get_vector_size()
schema_model = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(schema_model)
valid_payload_fields = set(schema_model.model_fields.keys())
defaults = self._get_payload_defaults(payload_schema)

new_ids = {dp.id for dp in new_lance_data_points}
old_rows = []
for row in rows:
if row.get("id") in new_ids:
continue
if isinstance(row.get("payload"), dict):
# Strip payload to only fields in the new schema
row["payload"] = {
k: v for k, v in row["payload"].items() if k in valid_payload_fields
}
# Fill in defaults for any new fields
for key, val in defaults.items():
row["payload"].setdefault(key, val)
old_rows.append(row)

class MigrationLanceDataPoint(LanceModel):
id: data_point_types["id"]
vector: Vector(vector_size)
payload: schema_model

async with self.VECTOR_DB_LOCK:
connection = await self.get_connection()
await connection.drop_table(collection_name)
await connection.create_table(
name=collection_name,
schema=MigrationLanceDataPoint,
)
collection = await connection.open_table(collection_name)

if old_rows:
await collection.add(old_rows)

await (
collection.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
-                .execute(lance_data_points)
+                .execute(new_lance_data_points)
)

logger.info(
"Migrated collection '%s' schema (%d existing rows preserved)",
collection_name,
len(old_rows),
)

def _get_payload_defaults(self, payload_schema: type) -> dict:
"""Extract default values from the Pydantic payload model."""
schema_model = self.get_data_point_schema(payload_schema)
defaults = {}
for name, field_info in schema_model.model_fields.items():
if field_info.default is not None and not (
hasattr(field_info, "is_required") and field_info.is_required()
):
defaults[name] = field_info.default
return defaults

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
try:
collection = await self.get_collection(collection_name)
@@ -396,6 +486,7 @@ async def prune(self):

def get_data_point_schema(self, model_type: BaseModel):
related_models_fields = []

for field_name, field_config in model_type.model_fields.items():
if hasattr(field_config, "model_fields"):
related_models_fields.append(field_name)
@@ -426,6 +517,7 @@ def get_data_point_schema(self, model_type: BaseModel):
model_type,
include_fields={
"id": (str, ...),
"belongs_to_set": (Optional[List[str]], None),
},
exclude_fields=["metadata"] + related_models_fields,
)
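The row-migration loop in `_migrate_collection_schema` (drop rows superseded by incoming data points, strip each payload to the new schema's fields, backfill defaults for newly added fields) can be sketched with plain dicts. This is an illustrative simplification, not the adapter code:

```python
def migrate_rows(old_rows, valid_fields, defaults, new_ids):
    # Mirror the migration: skip rows that the incoming data points
    # will overwrite, filter payloads down to fields that exist in
    # the new schema, and backfill defaults for any new fields.
    migrated = []
    for row in old_rows:
        if row["id"] in new_ids:
            continue
        payload = {k: v for k, v in row["payload"].items() if k in valid_fields}
        for key, val in defaults.items():
            payload.setdefault(key, val)
        migrated.append({**row, "payload": payload})
    return migrated


rows = [
    {"id": "1", "payload": {"a": 1, "obsolete": True}},
    {"id": "2", "payload": {"a": 2}},
]
out = migrate_rows(rows, valid_fields={"a", "b"}, defaults={"b": 0}, new_ids={"2"})
print(out)  # [{'id': '1', 'payload': {'a': 1, 'b': 0}}]
```

In the adapter the surviving rows are then re-inserted into a freshly created table whose schema is the new `LanceModel`, followed by the `merge_insert` of the new data points.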