Skip to content

Commit ae2cec8

Browse files
committed
Fix: Resolve OOM DoS via payload truncation/memory bounds and optimize FAISS cache
1 parent 6a26d68 commit ae2cec8

16 files changed

Lines changed: 166 additions & 101 deletions

File tree

chatbot-core/api/models/embedding_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66

77
logger = LoggerFactory.instance().get_logger("api")
88

9-
EMBEDDING_MODEL = load_embedding_model(CONFIG["retrieval"]["embedding_model_name"], logger)
9+
EMBEDDING_MODEL = load_embedding_model(
10+
CONFIG["retrieval"]["embedding_model_name"])

chatbot-core/api/routes/chatbot.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Third-party imports
1818
# =========================
1919
from typing import List, Optional
20+
from urllib import request
2021
from fastapi import (
2122
APIRouter,
2223
HTTPException,
@@ -106,13 +107,19 @@ async def chatbot_stream(websocket: WebSocket, session_id: str):
106107
message_data = json.loads(data)
107108
user_message = message_data.get("message", "")
108109

110+
if len(user_message) > 2000:
111+
logger.warning(
112+
f"Truncated massive WebSocket payload from session {session_id}")
113+
user_message = user_message[:2000]
114+
109115
if not user_message:
110116
continue
111117

112118
async for token in get_chatbot_reply_stream(
113119
session_id,
114120
user_message,
115121
):
122+
116123
await websocket.send_text(
117124
json.dumps({"token": token})
118125
)
@@ -166,6 +173,7 @@ def start_chat(response: Response):
166173
)
167174
return SessionResponse(session_id=session_id)
168175

176+
169177
@router.delete(
170178
"/sessions/{session_id}",
171179
response_model=DeleteResponse,
@@ -191,7 +199,6 @@ def delete_chat(session_id: str):
191199
# Chat Endpoint
192200
@router.post("/sessions/{session_id}/message", response_model=ChatResponse)
193201
def chatbot_reply(session_id: str, request: ChatRequest, _background_tasks: BackgroundTasks):
194-
195202
"""
196203
POST endpoint to handle chatbot replies.
197204
@@ -210,11 +217,16 @@ def chatbot_reply(session_id: str, request: ChatRequest, _background_tasks: Back
210217
status_code=404,
211218
detail="Session not found.",
212219
)
213-
reply = get_chatbot_reply(session_id, request.message)
220+
221+
if len(request.message) > 2000:
222+
logger.warning(f"Truncated massive payload from session {session_id}")
223+
request.message = request.message[:2000]
224+
225+
reply = get_chatbot_reply(session_id, request.message)
214226
_background_tasks.add_task(
215227
persist_session,
216228
session_id,
217-
)
229+
)
218230

219231
return reply
220232

@@ -263,6 +275,10 @@ async def chatbot_reply_with_files(
263275
status_code=422,
264276
detail="Either message or files must be provided.",
265277
)
278+
if has_message and len(message) > 2000:
279+
logger.warning(
280+
f"Truncated massive file upload message from session {session_id}")
281+
message = message[:2000]
266282

267283
# Process uploaded files
268284
processed_files: List[FileAttachment] = []

chatbot-core/api/services/chat_service.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def get_chatbot_reply(
6262

6363
memory = get_session(session_id)
6464
if memory is None:
65-
raise RuntimeError(f"Session '{session_id}' not found in the memory store.")
65+
raise RuntimeError(
66+
f"Session '{session_id}' not found in the memory store.")
6667

6768
context = retrieve_context(user_input)
6869
logger.info("Context retrieved: %s", context)
@@ -333,7 +334,8 @@ def _execute_search_tools(tool_calls) -> str:
333334
})
334335

335336
return "\n\n".join(
336-
f"[Result of the search tool {res['tool']}]:\n{res.get('output', '')}".strip()
337+
f"[Result of the search tool {res['tool']}]:\n{res.get('output', '')}".strip(
338+
)
337339
for res in retrieved_results
338340
)
339341

@@ -381,7 +383,6 @@ def retrieve_context(user_input: str) -> str:
381383
data_retrieved, _ = get_relevant_documents(
382384
user_input,
383385
EMBEDDING_MODEL,
384-
logger=logger,
385386
source_name="plugins",
386387
top_k=retrieval_config["top_k"]
387388
)
@@ -434,10 +435,12 @@ def generate_answer(prompt: str, max_tokens: Optional[int] = None) -> str:
434435
logger.error("LLM provider unavailable: %s", e)
435436
return "LLM is not available. Please install llama-cpp-python and configure a model."
436437
except (ValueError, RuntimeError) as exc:
437-
logger.error("LLM generation failed for prompt: %r. Error: %r", prompt, exc)
438+
logger.error(
439+
"LLM generation failed for prompt: %r. Error: %r", prompt, exc)
438440
return "Sorry, I'm having trouble generating a response right now."
439441
except Exception: # pylint: disable=broad-except
440-
logger.exception("Unexpected error during LLM generation for prompt: %r", prompt)
442+
logger.exception(
443+
"Unexpected error during LLM generation for prompt: %r", prompt)
441444
return "Sorry, an unexpected error occurred. Please contact support."
442445

443446

@@ -532,6 +535,7 @@ def _extract_relevance_score(response: str) -> str:
532535

533536
return relevance_score
534537

538+
535539
def _generate_search_query_from_logs(log_text: str) -> str:
536540
"""
537541
Uses the LLM to extract a concise error signature from the logs

chatbot-core/api/services/memory.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import uuid
77
from datetime import datetime, timedelta
88
from threading import Lock
9-
from langchain.memory import ConversationBufferMemory
9+
from typing import Optional
10+
from langchain.memory import ConversationBufferWindowMemory
1011
from api.config.loader import CONFIG
11-
from api.services.sessionmanager import(
12+
from api.services.sessionmanager import (
1213
delete_session_file,
1314
load_session,
1415
session_exists_in_json,
@@ -31,13 +32,13 @@ def init_session() -> str:
3132
session_id = str(uuid.uuid4())
3233
with _lock:
3334
_sessions[session_id] = {
34-
"memory": ConversationBufferMemory(return_messages=True),
35+
"memory": ConversationBufferWindowMemory(k=10, return_messages=True),
3536
"last_accessed": datetime.now()
3637
}
3738
return session_id
3839

3940

40-
def get_session(session_id: str) -> ConversationBufferMemory | None:
41+
def get_session(session_id: str) -> Optional[ConversationBufferWindowMemory]:
4142
"""
4243
Retrieve the chat session memory for the given session ID.
4344
Lazily restores from disk if missing in memory.
@@ -46,24 +47,24 @@ def get_session(session_id: str) -> ConversationBufferMemory | None:
4647
session_id (str): The session identifier.
4748
4849
Returns:
49-
ConversationBufferMemory | None: The memory object if found, else None.
50+
Optional[ConversationBufferWindowMemory]: The memory object if found, else None.
5051
"""
5152

5253
with _lock:
5354

5455
session_data = _sessions.get(session_id)
5556

56-
if session_data :
57+
if session_data:
5758
session_data["last_accessed"] = datetime.now()
5859
return session_data["memory"]
5960

6061
history = load_session(session_id)
6162
if not history:
6263
return None
6364

64-
memory = ConversationBufferMemory(return_messages=True)
65+
memory = ConversationBufferWindowMemory(k=10, return_messages=True)
6566
for msg in history:
66-
memory.chat_memory.add_message(# pylint: disable=no-member
67+
memory.chat_memory.add_message( # pylint: disable=no-member
6768
{
6869
"role": msg["role"],
6970
"content": msg["content"],
@@ -77,14 +78,15 @@ def get_session(session_id: str) -> ConversationBufferMemory | None:
7778

7879
return memory
7980

80-
async def get_session_async(session_id: str) -> ConversationBufferMemory | None:
81+
82+
async def get_session_async(session_id: str) -> Optional[ConversationBufferWindowMemory]:
8183
"""
8284
Async wrapper for get_session to prevent event loop blocking.
8385
"""
8486
return await asyncio.to_thread(get_session, session_id)
8587

8688

87-
def persist_session(session_id: str)-> None:
89+
def persist_session(session_id: str) -> None:
8890
"""
8991
Persist the current session messages to disk.
9092
@@ -97,7 +99,6 @@ def persist_session(session_id: str)-> None:
9799
append_message(session_id, messages)
98100

99101

100-
101102
def delete_session(session_id: str) -> bool:
102103
"""
103104
Delete a chat session and its persisted data.
@@ -138,7 +139,8 @@ def reset_sessions():
138139
with _lock:
139140
_sessions.clear()
140141

141-
def get_last_accessed(session_id: str) -> datetime | None:
142+
143+
def get_last_accessed(session_id: str) -> Optional[datetime]:
142144
"""
143145
Get the last accessed timestamp for a given session.
144146
@@ -157,9 +159,9 @@ def get_last_accessed(session_id: str) -> datetime | None:
157159
if not history:
158160
return None
159161

160-
161162
return history["last_accessed"]
162163

164+
163165
def set_last_accessed(session_id: str, timestamp: datetime) -> bool:
164166
"""
165167
Set the last accessed timestamp for a given session (for testing purposes).
@@ -186,6 +188,7 @@ def set_last_accessed(session_id: str, timestamp: datetime) -> bool:
186188

187189
return False
188190

191+
189192
def get_session_count() -> int:
190193
"""
191194
Get the total number of active sessions (for testing purposes).
@@ -196,6 +199,7 @@ def get_session_count() -> int:
196199
with _lock:
197200
return len(_sessions)
198201

202+
199203
def cleanup_expired_sessions() -> int:
200204
"""
201205
Remove sessions that have not been accessed within the configured timeout period.

chatbot-core/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pytest_plugins = ["tests.unit.mocks.test_env"]

chatbot-core/rag/embedding/embedding_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
"""
44

55
from sentence_transformers import SentenceTransformer
6+
import logging
67

7-
def load_embedding_model(model_name, logger):
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def load_embedding_model(model_name):
812
"""
913
Load the sentence transformer model for generating text embeddings.
1014
@@ -14,7 +18,8 @@ def load_embedding_model(model_name, logger):
1418
logger.info(f"Loading embedding model: {model_name}")
1519
return SentenceTransformer(model_name)
1620

17-
def embed_documents(texts, model, logger, batch_size=32):
21+
22+
def embed_documents(texts, model, batch_size=32):
1823
"""
1924
Embed a list of text documents into dense vector representations using the given model.
2025

chatbot-core/rag/retriever/retrieve.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
from rag.embedding.embedding_utils import embed_documents
66
from rag.retriever.retriever_utils import load_vector_index, search_index
77
from api.config.loader import CONFIG
8+
import logging
89

9-
def get_relevant_documents(query, model, logger, source_name, top_k=5):
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def get_relevant_documents(query, model, source_name, top_k=5):
1014
"""
1115
Retrieve the top-k most relevant chunks for a given natural language query.
1216
@@ -24,13 +28,13 @@ def get_relevant_documents(query, model, logger, source_name, top_k=5):
2428
logger.warning("Empty query received.")
2529
return [], []
2630

27-
index, metadata = load_vector_index(logger, source_name)
31+
index, metadata = load_vector_index(source_name)
2832

2933
if not index or not metadata:
3034
return [], []
3135

32-
query_vector = embed_documents([query], model, logger)[0]
33-
data, scores = search_index(query_vector, index, metadata, logger, top_k)
36+
query_vector = embed_documents([query], model)[0]
37+
data, scores = search_index(query_vector, index, metadata, top_k)
3438

3539
filtered = [(d, s) for d, s in zip(data, scores)
3640
if s <= CONFIG["retrieval"]["semantic_threshold"]]

chatbot-core/rag/retriever/retriever_utils.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,21 @@
33
to retrieve relevant document chunks based on a query vector.
44
"""
55

6+
67
import os
78
import numpy as np
9+
from functools import lru_cache
10+
import logging
811
from rag.vectorstore.vectorstore_utils import load_faiss_index, load_metadata
912

10-
VECTOR_STORE_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data", "embeddings")
13+
VECTOR_STORE_DIR = os.path.join(os.path.dirname(
14+
__file__), "..", "..", "data", "embeddings")
15+
16+
logger = logging.getLogger(__name__)
1117

12-
def load_vector_index(logger, source_name):
18+
19+
@lru_cache(maxsize=1)
20+
def load_vector_index(source_name):
1321
"""
1422
Load the FAISS index and associated metadata from disk.
1523
@@ -24,14 +32,16 @@ def load_vector_index(logger, source_name):
2432
logger.warning("No source name provided. Returning empty results.")
2533
return [], []
2634
index_path = os.path.join(VECTOR_STORE_DIR, f"{source_name}_index.idx")
27-
metadata_path = os.path.join(VECTOR_STORE_DIR, f"{source_name}_metadata.pkl")
35+
metadata_path = os.path.join(
36+
VECTOR_STORE_DIR, f"{source_name}_metadata.pkl")
2837

29-
index = load_faiss_index(index_path, logger)
30-
metadata = load_metadata(metadata_path, logger)
38+
index = load_faiss_index(index_path)
39+
metadata = load_metadata(metadata_path)
3140

3241
return index, metadata
3342

34-
def search_index(query_vector, index, metadata, logger, top_k):
43+
44+
def search_index(query_vector, index, metadata, top_k):
3545
"""
3646
Search the FAISS index with a query vector and return the top-k closest metadata results.
3747
@@ -54,7 +64,7 @@ def search_index(query_vector, index, metadata, logger, top_k):
5464

5565
if index.ntotal != len(metadata):
5666
logger.warning(
57-
"Index contains %d vectors but metadata has %d entries." \
67+
"Index contains %d vectors but metadata has %d entries."
5868
" Some results may be missing or inconsistent.",
5969
index.ntotal,
6070
len(metadata)
@@ -73,9 +83,9 @@ def search_index(query_vector, index, metadata, logger, top_k):
7383
})
7484
else:
7585
logger.error("FAISS returned index %d out of range (metadata size: %d)",
76-
idx,
77-
len(metadata)
78-
)
86+
idx,
87+
len(metadata)
88+
)
7989

8090
data = []
8191
scores = []

0 commit comments

Comments (0)