Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,34 @@ def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]:
]

# Check if there are more messages after the current toolUse message
tool_result_ids = [
content["toolResult"]["toolUseId"]
for content in messages[index + 1]["content"]
if "toolResult" in content
]
next_message_content = messages[index + 1]["content"]
seen_tool_result_ids: set[str] = set()
cleaned_next_message_content = []
removed_orphaned_tool_results = False
for content in next_message_content:
if "toolResult" not in content:
cleaned_next_message_content.append(content)
continue

tool_result_id = content["toolResult"]["toolUseId"]
if tool_result_id in tool_use_ids and tool_result_id not in seen_tool_result_ids:
seen_tool_result_ids.add(tool_result_id)
cleaned_next_message_content.append(content)
else:
removed_orphaned_tool_results = True

missing_tool_use_ids = list(set(tool_use_ids) - set(tool_result_ids))
if removed_orphaned_tool_results:
logger.warning(
"Session message history has orphaned or duplicate toolResult blocks. "
"Removing them to keep toolUse/toolResult pairs valid."
)
messages[index + 1]["content"] = cleaned_next_message_content

tool_result_ids = list(seen_tool_result_ids)

missing_tool_use_ids = [
tool_use_id for tool_use_id in tool_use_ids if tool_use_id not in tool_result_ids
]
# If there are missing tool use ids, that means the messages history is broken
if missing_tool_use_ids:
logger.warning(
Expand Down
29 changes: 29 additions & 0 deletions tests/strands/session/test_repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,35 @@ def test_fix_broken_tool_use_extends_partial_tool_results(existing_session_manag
assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted."


def test_fix_broken_tool_use_prunes_extra_tool_results(session_manager):
"""Test fixing a user message with extra toolResults before adding missing ones."""
messages = [
{
"role": "assistant",
"content": [
{"toolUse": {"toolUseId": "complete-123", "name": "test_tool", "input": {"input": "test1"}}},
{"toolUse": {"toolUseId": "missing-456", "name": "test_tool", "input": {"input": "test2"}}},
],
},
{
"role": "user",
"content": [
{"toolResult": {"toolUseId": "complete-123", "status": "success", "content": [{"text": "result"}]}},
{"toolResult": {"toolUseId": "complete-123", "status": "success", "content": [{"text": "dup"}]}},
{"toolResult": {"toolUseId": "stale-999", "status": "success", "content": [{"text": "stale"}]}},
],
},
]

fixed_messages = session_manager._fix_broken_tool_use(messages)

tool_results = [content["toolResult"] for content in fixed_messages[1]["content"] if "toolResult" in content]
assert [tool_result["toolUseId"] for tool_result in tool_results] == ["complete-123", "missing-456"]
missing_result = next(tool_result for tool_result in tool_results if tool_result["toolUseId"] == "missing-456")
assert missing_result["status"] == "error"
assert missing_result["content"][0]["text"] == "Tool was interrupted."


def test_fix_broken_tool_use_handles_multiple_orphaned_tools(existing_session_manager):
"""Test fixing multiple orphaned toolUse messages."""

Expand Down