Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -61,6 +62,21 @@
from llama_index.core.tools.types import BaseTool


def _parse_tool_input(raw_input: Any) -> Any:
"""Parse tool input from string to dict if needed.

During streaming, tool call input is accumulated as a concatenated JSON
string. This helper parses it into a dict so that ToolCallBlock.tool_kwargs
is always a dict, matching the non-streaming behavior.
"""
if isinstance(raw_input, str):
try:
return json.loads(raw_input)
except (json.JSONDecodeError, ValueError):
return {}
return raw_input


class BedrockConverse(FunctionCallingLLM):
"""
Bedrock Converse LLM.
Expand Down Expand Up @@ -590,7 +606,7 @@ def gen() -> ChatResponseGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down Expand Up @@ -646,7 +662,7 @@ def gen() -> ChatResponseGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down Expand Up @@ -690,7 +706,7 @@ def gen() -> ChatResponseGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down Expand Up @@ -874,7 +890,7 @@ async def gen() -> ChatResponseAsyncGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down Expand Up @@ -930,7 +946,7 @@ async def gen() -> ChatResponseAsyncGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down Expand Up @@ -975,7 +991,7 @@ async def gen() -> ChatResponseAsyncGen:
for tool_call in tool_calls:
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call.get("input", {}),
tool_kwargs=_parse_tool_input(tool_call.get("input", {})),
tool_name=tool_call.get("name", ""),
tool_call_id=tool_call.get("toolUseId"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,185 @@ def test_stream_chat(bedrock_converse):
assert final_response.additional_kwargs["total_tokens"] == 41


def test_stream_chat_tool_kwargs_parsed_as_dict(monkeypatch):
    """Test that streaming tool call input is parsed from string to dict.

    Bedrock ConverseStream delivers tool use input as string chunks.
    After accumulation, ToolCallBlock.tool_kwargs should be a dict,
    not a raw JSON string.

    Regression test for https://github.com/run-llama/llama_index/issues/21579
    """

    class ToolStreamMockClient:
        def __init__(self):
            self.exceptions = MockExceptions()

        def converse(self, *args, **kwargs):
            return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}

        def converse_stream(self, *args, **kwargs):
            def stream_generator():
                # contentBlockStart: tool use block begins
                yield {
                    "contentBlockStart": {
                        "start": {
                            "toolUse": {
                                "toolUseId": "tool-1",
                                "name": "get_weather",
                            }
                        },
                        "contentBlockIndex": 0,
                    }
                }
                # contentBlockDelta: the JSON argument object arrives split
                # across multiple partial string chunks.
                for fragment in ('{"locat', 'ion": "London"}'):
                    yield {
                        "contentBlockDelta": {
                            "delta": {"toolUse": {"input": fragment}},
                            "contentBlockIndex": 0,
                        }
                    }
                yield {"messageStop": {"stopReason": "tool_use"}}
                yield {
                    "metadata": {
                        "usage": {
                            "inputTokens": 10,
                            "outputTokens": 20,
                            "totalTokens": 30,
                        },
                    }
                }

            return {"stream": stream_generator()}

    monkeypatch.setattr(
        "boto3.Session.client", lambda *a, **kw: ToolStreamMockClient()
    )
    monkeypatch.setattr("aioboto3.Session", MockAsyncSession)

    llm = BedrockConverse(
        model=EXP_MODEL,
        max_tokens=EXP_MAX_TOKENS,
        temperature=EXP_TEMPERATURE,
    )

    # Drain the stream; only the fully-accumulated final response matters.
    responses = list(llm.stream_chat(messages))
    final = responses[-1]

    tool_blocks = [
        b for b in final.message.blocks if isinstance(b, ToolCallBlock)
    ]
    assert len(tool_blocks) == 1
    # tool_kwargs must be a dict, not a JSON string
    assert isinstance(tool_blocks[0].tool_kwargs, dict), (
        f"Expected dict, got {type(tool_blocks[0].tool_kwargs)}: {tool_blocks[0].tool_kwargs}"
    )
    assert tool_blocks[0].tool_kwargs == {"location": "London"}
    assert tool_blocks[0].tool_name == "get_weather"
    assert tool_blocks[0].tool_call_id == "tool-1"


@pytest.mark.asyncio
async def test_astream_chat_tool_kwargs_parsed_as_dict(monkeypatch):
    """Async variant: streaming tool call input is parsed from string to dict.

    Regression test for https://github.com/run-llama/llama_index/issues/21579
    """

    class ToolStreamAsyncMockClient:
        def __init__(self):
            self.exceptions = MockExceptions()

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            pass

        async def converse(self, *args, **kwargs):
            return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}

        async def converse_stream(self, *args, **kwargs):
            async def stream_generator():
                # Tool use block begins.
                yield {
                    "contentBlockStart": {
                        "start": {
                            "toolUse": {
                                "toolUseId": "tool-1",
                                "name": "get_weather",
                            }
                        },
                        "contentBlockIndex": 0,
                    }
                }
                # The JSON argument object arrives as partial string chunks.
                for fragment in ('{"locat', 'ion": "London"}'):
                    yield {
                        "contentBlockDelta": {
                            "delta": {"toolUse": {"input": fragment}},
                            "contentBlockIndex": 0,
                        }
                    }
                yield {"messageStop": {"stopReason": "tool_use"}}
                yield {
                    "metadata": {
                        "usage": {
                            "inputTokens": 10,
                            "outputTokens": 20,
                            "totalTokens": 30,
                        },
                    }
                }

            return {"stream": stream_generator()}

    class ToolStreamAsyncSession:
        def __init__(self, *args, **kwargs):
            pass

        def client(self, *args, **kwargs):
            return ToolStreamAsyncMockClient()

    monkeypatch.setattr(
        "boto3.Session.client", lambda *a, **kw: MockClient()
    )
    monkeypatch.setattr("aioboto3.Session", ToolStreamAsyncSession)

    llm = BedrockConverse(
        model=EXP_MODEL,
        max_tokens=EXP_MAX_TOKENS,
        temperature=EXP_TEMPERATURE,
    )

    # Drain the async stream; only the accumulated final response matters.
    response_stream = await llm.astream_chat(messages)
    responses = [r async for r in response_stream]
    final = responses[-1]

    tool_blocks = [
        b for b in final.message.blocks if isinstance(b, ToolCallBlock)
    ]
    assert len(tool_blocks) == 1
    # tool_kwargs must be a dict, not a JSON string.
    assert isinstance(tool_blocks[0].tool_kwargs, dict), (
        f"Expected dict, got {type(tool_blocks[0].tool_kwargs)}: {tool_blocks[0].tool_kwargs}"
    )
    assert tool_blocks[0].tool_kwargs == {"location": "London"}
    assert tool_blocks[0].tool_name == "get_weather"
    assert tool_blocks[0].tool_call_id == "tool-1"


@pytest.mark.asyncio
async def test_achat(bedrock_converse):
response = await bedrock_converse.achat(messages)
Expand Down