Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions mcp_server/src/graphiti_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ def configure_uvicorn_logging():
semaphore: asyncio.Semaphore


def normalize_group_ids(group_ids: str | list[str] | None) -> list[str] | None:
"""Accept either a single group ID string or a list of group IDs."""
if group_ids is None:
return None
if isinstance(group_ids, str):
return [group_ids]
return group_ids


class GraphitiService:
"""Graphiti service using the unified configuration system."""

Expand Down Expand Up @@ -407,7 +416,7 @@ async def add_memory(
@mcp.tool()
async def search_nodes(
query: str,
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_nodes: int = 10,
entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
Expand All @@ -427,10 +436,12 @@ async def search_nodes(
try:
client = await graphiti_service.get_client()

normalized_group_ids = normalize_group_ids(group_ids)

# Use the provided group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -487,7 +498,7 @@ async def search_nodes(
@mcp.tool()
async def search_memory_facts(
query: str,
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_facts: int = 10,
center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
Expand All @@ -511,10 +522,12 @@ async def search_memory_facts(

client = await graphiti_service.get_client()

normalized_group_ids = normalize_group_ids(group_ids)

# Use the provided group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -619,7 +632,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:

@mcp.tool()
async def get_episodes(
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
"""Get episodes from the graph memory.
Expand All @@ -636,10 +649,12 @@ async def get_episodes(
try:
client = await graphiti_service.get_client()

normalized_group_ids = normalize_group_ids(group_ids)

# Use the provided group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -686,7 +701,7 @@ async def get_episodes(


@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
async def clear_graph(group_ids: str | list[str] | None = None) -> SuccessResponse | ErrorResponse:
"""Clear all data from the graph for specified group IDs.

Args:
Expand All @@ -700,9 +715,11 @@ async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | E
try:
client = await graphiti_service.get_client()

normalized_group_ids = normalize_group_ids(group_ids)

# Use the provided group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids or [config.graphiti.group_id] if config.graphiti.group_id else []
normalized_group_ids or [config.graphiti.group_id] if config.graphiti.group_id else []
)

if not effective_group_ids:
Expand Down
76 changes: 76 additions & 0 deletions mcp_server/tests/test_group_id_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest
from graphiti_core.nodes import EpisodicNode

import graphiti_mcp_server as server


class DummyGraphitiService:
def __init__(self, client):
self._client = client
self.entity_types = None

async def get_client(self):
return self._client


@pytest.fixture
def patched_server(monkeypatch):
monkeypatch.setattr(
server,
"config",
SimpleNamespace(graphiti=SimpleNamespace(group_id="default-group")),
raising=False,
)


@pytest.mark.asyncio
async def test_search_nodes_accepts_scalar_group_id(monkeypatch, patched_server):
client = SimpleNamespace(search_=AsyncMock(return_value=SimpleNamespace(nodes=[])))
monkeypatch.setattr(server, "graphiti_service", DummyGraphitiService(client))

result = await server.search_nodes(query="workspace memory", group_ids="ideadb")

assert result["message"] == "No relevant nodes found"
assert client.search_.await_args.kwargs["group_ids"] == ["ideadb"]


@pytest.mark.asyncio
async def test_search_memory_facts_accepts_scalar_group_id(monkeypatch, patched_server):
client = SimpleNamespace(search=AsyncMock(return_value=[]))
monkeypatch.setattr(server, "graphiti_service", DummyGraphitiService(client))

result = await server.search_memory_facts(query="workspace memory", group_ids="ideadb")

assert result["message"] == "No relevant facts found"
assert client.search.await_args.kwargs["group_ids"] == ["ideadb"]


@pytest.mark.asyncio
async def test_get_episodes_accepts_scalar_group_id(monkeypatch, patched_server):
get_by_group_ids = AsyncMock(return_value=[])
monkeypatch.setattr(EpisodicNode, "get_by_group_ids", get_by_group_ids)
monkeypatch.setattr(
server, "graphiti_service", DummyGraphitiService(SimpleNamespace(driver=object()))
)

result = await server.get_episodes(group_ids="ideadb")

assert result["message"] == "No episodes found"
assert get_by_group_ids.await_args.args[1] == ["ideadb"]


@pytest.mark.asyncio
async def test_clear_graph_accepts_scalar_group_id(monkeypatch, patched_server):
clear_data = AsyncMock(return_value=None)
monkeypatch.setattr(server, "clear_data", clear_data)
monkeypatch.setattr(
server, "graphiti_service", DummyGraphitiService(SimpleNamespace(driver=object()))
)

result = await server.clear_graph(group_ids="ideadb")

assert result["message"] == "Graph data cleared successfully for group IDs: ideadb"
assert clear_data.await_args.kwargs["group_ids"] == ["ideadb"]
Loading