From ba8ce6642a769bf3c713ad6b1524741b57fc03a3 Mon Sep 17 00:00:00 2001 From: Kristofer Jussmann Date: Sat, 28 Mar 2026 16:26:17 +0200 Subject: [PATCH] fix(mcp): accept scalar group_ids in read tools --- mcp_server/src/graphiti_mcp_server.py | 39 +++++++--- .../tests/test_group_id_normalization.py | 76 +++++++++++++++++++ 2 files changed, 104 insertions(+), 11 deletions(-) create mode 100644 mcp_server/tests/test_group_id_normalization.py diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py index 833bc5d93..352f8cbbe 100644 --- a/mcp_server/src/graphiti_mcp_server.py +++ b/mcp_server/src/graphiti_mcp_server.py @@ -158,6 +158,15 @@ def configure_uvicorn_logging(): semaphore: asyncio.Semaphore +def normalize_group_ids(group_ids: str | list[str] | None) -> list[str] | None: + """Accept either a single group ID string or a list of group IDs.""" + if group_ids is None: + return None + if isinstance(group_ids, str): + return [group_ids] + return group_ids + + class GraphitiService: """Graphiti service using the unified configuration system.""" @@ -407,7 +416,7 @@ async def add_memory( @mcp.tool() async def search_nodes( query: str, - group_ids: list[str] | None = None, + group_ids: str | list[str] | None = None, max_nodes: int = 10, entity_types: list[str] | None = None, ) -> NodeSearchResponse | ErrorResponse: @@ -427,10 +436,12 @@ async def search_nodes( try: client = await graphiti_service.get_client() + normalized_group_ids = normalize_group_ids(group_ids) + # Use the provided group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -487,7 +498,7 @@ async def search_nodes( @mcp.tool() async def search_memory_facts( query: str, - group_ids: list[str] | None = None, + group_ids: str | list[str] | None = None, max_facts: int = 10, center_node_uuid: str | None = None, ) -> FactSearchResponse | ErrorResponse: @@ -511,10 +522,12 @@ async def search_memory_facts( client = await graphiti_service.get_client() + normalized_group_ids = normalize_group_ids(group_ids) + # Use the provided group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -619,7 +632,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse: @mcp.tool() async def get_episodes( - group_ids: list[str] | None = None, + group_ids: str | list[str] | None = None, max_episodes: int = 10, ) -> EpisodeSearchResponse | ErrorResponse: """Get episodes from the graph memory. @@ -636,10 +649,12 @@ async def get_episodes( try: client = await graphiti_service.get_client() + normalized_group_ids = normalize_group_ids(group_ids) + # Use the provided group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -686,7 +701,7 @@ async def get_episodes( @mcp.tool() -async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse: +async def clear_graph(group_ids: str | list[str] | None = None) -> SuccessResponse | ErrorResponse: """Clear all data from the graph for specified group IDs. Args: @@ -700,9 +715,11 @@ async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | E try: client = await graphiti_service.get_client() + normalized_group_ids = normalize_group_ids(group_ids) + # Use the provided group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids or [config.graphiti.group_id] if config.graphiti.group_id else [] + normalized_group_ids or [config.graphiti.group_id] if config.graphiti.group_id else [] ) if not effective_group_ids: diff --git a/mcp_server/tests/test_group_id_normalization.py b/mcp_server/tests/test_group_id_normalization.py new file mode 100644 index 000000000..81a77eda6 --- /dev/null +++ b/mcp_server/tests/test_group_id_normalization.py @@ -0,0 +1,76 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from graphiti_core.nodes import EpisodicNode + +import graphiti_mcp_server as server + + +class DummyGraphitiService: + def __init__(self, client): + self._client = client + self.entity_types = None + + async def get_client(self): + return self._client + + +@pytest.fixture +def patched_server(monkeypatch): + monkeypatch.setattr( + server, + "config", + SimpleNamespace(graphiti=SimpleNamespace(group_id="default-group")), + raising=False, + ) + + +@pytest.mark.asyncio +async def test_search_nodes_accepts_scalar_group_id(monkeypatch, patched_server): + client = SimpleNamespace(search_=AsyncMock(return_value=SimpleNamespace(nodes=[]))) + monkeypatch.setattr(server, "graphiti_service", DummyGraphitiService(client)) + + result = await server.search_nodes(query="workspace memory", group_ids="ideadb") + + assert result["message"] == "No relevant nodes found" + assert client.search_.await_args.kwargs["group_ids"] == ["ideadb"] + + +@pytest.mark.asyncio +async def test_search_memory_facts_accepts_scalar_group_id(monkeypatch, patched_server): + client = SimpleNamespace(search=AsyncMock(return_value=[])) + monkeypatch.setattr(server, "graphiti_service", DummyGraphitiService(client)) + + result = await server.search_memory_facts(query="workspace memory", group_ids="ideadb") + + assert result["message"] == "No relevant facts found" + assert client.search.await_args.kwargs["group_ids"] == ["ideadb"] + + +@pytest.mark.asyncio +async def test_get_episodes_accepts_scalar_group_id(monkeypatch, patched_server): + get_by_group_ids = AsyncMock(return_value=[]) + monkeypatch.setattr(EpisodicNode, "get_by_group_ids", get_by_group_ids) + monkeypatch.setattr( + server, "graphiti_service", DummyGraphitiService(SimpleNamespace(driver=object())) + ) + + result = await server.get_episodes(group_ids="ideadb") + + assert result["message"] == "No episodes found" + assert get_by_group_ids.await_args.args[1] == ["ideadb"] + + +@pytest.mark.asyncio +async def test_clear_graph_accepts_scalar_group_id(monkeypatch, patched_server): + clear_data = AsyncMock(return_value=None) + monkeypatch.setattr(server, "clear_data", clear_data) + monkeypatch.setattr( + server, "graphiti_service", DummyGraphitiService(SimpleNamespace(driver=object())) + ) + + result = await server.clear_graph(group_ids="ideadb") + + assert result["message"] == "Graph data cleared successfully for group IDs: ideadb" + assert clear_data.await_args.kwargs["group_ids"] == ["ideadb"]