Skip to content

Commit 9467ba2

Browse files
committed
fix: prevent duplicate tool calls when processing tool results
When a tool result submission arrived, the message processing loop was handling system/assistant messages first, which started a new execution before the model saw the function response. This caused the model to call the same tool again. Changes: - Skip directly to tool results when pending tool calls exist - Handle plain string tool results by wrapping them in a result object (fixes JSON parse errors for non-JSON tool responses)
1 parent 6bd0a77 commit 9467ba2

File tree

1 file changed

+55
-39
lines changed
  • integrations/adk-middleware/python/src/ag_ui_adk

1 file changed

+55
-39
lines changed

integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,14 @@ def _create_runner(self, adk_agent: BaseAgent, user_id: str, app_name: str) -> R
350350

351351
async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
352352
"""Run the ADK agent with client-side tool support.
353-
353+
354354
All client-side tools are long-running. For tool result submissions,
355355
we continue existing executions. For new requests, we start new executions.
356356
ADK sessions handle conversation continuity and tool result processing.
357-
357+
358358
Args:
359359
input: The AG-UI run input
360-
360+
361361
Yields:
362362
AG-UI protocol events
363363
"""
@@ -374,6 +374,27 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
374374
app_name = self._get_app_name(input)
375375
skip_tool_message_batch = False
376376

377+
# Check if there are pending tool calls AND tool results in unseen messages
378+
# If so, we should skip to the tool results first
379+
pending_tool_call_ids = await self._get_pending_tool_call_ids(input.thread_id)
380+
has_pending_tools = pending_tool_call_ids is not None and len(pending_tool_call_ids) > 0
381+
has_tool_results_in_unseen = any(getattr(msg, "role", None) == "tool" for msg in unseen_messages)
382+
383+
if has_pending_tools and has_tool_results_in_unseen:
384+
# Find the index of the first tool result and skip to it
385+
for i, msg in enumerate(unseen_messages):
386+
if getattr(msg, "role", None) == "tool":
387+
# Mark all messages before the tool result as processed (they're already in the ADK session)
388+
skipped_ids = []
389+
for j in range(i):
390+
msg_id = getattr(unseen_messages[j], "id", None)
391+
if msg_id:
392+
skipped_ids.append(msg_id)
393+
if skipped_ids:
394+
self._session_manager.mark_messages_processed(app_name, input.thread_id, skipped_ids)
395+
index = i
396+
break
397+
377398
while index < total_unseen:
378399
current = unseen_messages[index]
379400
role = getattr(current, "role", None)
@@ -621,24 +642,19 @@ async def _handle_tool_result_submission(
621642
return
622643

623644
try:
624-
# Remove tool calls from pending list
645+
# Remove tool calls from pending list and track which ones we processed
646+
processed_tool_ids = []
625647
for tool_result in tool_results:
626648
tool_call_id = tool_result['message'].tool_call_id
627649
has_pending = await self._has_pending_tool_calls(thread_id)
628650

629651
if has_pending:
630-
# Could add more specific check here for the exact tool_call_id
631-
# but for now just log that we're processing a tool result while tools are pending
632-
logger.debug(f"Processing tool result {tool_call_id} for thread {thread_id} with pending tools")
633652
# Remove from pending tool calls now that we're processing it
634653
await self._remove_pending_tool_call(thread_id, tool_call_id)
635-
else:
636-
# No pending tools - this could be a stale result or from a different session
637-
logger.warning(f"No pending tool calls found for tool result {tool_call_id} in thread {thread_id}")
654+
processed_tool_ids.append(tool_call_id)
638655

639656
# Since all tools are long-running, all tool results are standalone
640657
# and should start new executions with the tool results
641-
logger.info(f"Starting new execution for tool result in thread {thread_id}")
642658

643659
# Use trailing_messages if provided, otherwise fall back to candidate_messages
644660
message_batch = trailing_messages if trailing_messages else (candidate_messages if include_message_batch else None)
@@ -884,14 +900,11 @@ async def _start_new_execution(
884900
if input.thread_id in self._active_executions:
885901
execution = self._active_executions[input.thread_id]
886902
execution.is_complete = True
887-
903+
888904
# Check if session has pending tool calls before cleanup
889905
has_pending = await self._has_pending_tool_calls(input.thread_id)
890906
if not has_pending:
891907
del self._active_executions[input.thread_id]
892-
logger.debug(f"Cleaned up execution for thread {input.thread_id}")
893-
else:
894-
logger.info(f"Preserving execution for thread {input.thread_id} - has pending tool calls (HITL scenario)")
895908

896909
async def _start_background_execution(
897910
self,
@@ -1094,24 +1107,24 @@ async def _run_adk_in_background(
10941107
# Debug: Log the actual tool message content we received
10951108
logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}")
10961109

1097-
# Parse JSON content, handling empty or invalid JSON gracefully
1110+
# Parse content - try JSON first, fall back to plain string
10981111
try:
10991112
if content and content.strip():
1100-
result = json.loads(content)
1113+
# Try to parse as JSON first
1114+
try:
1115+
result = json.loads(content)
1116+
except json.JSONDecodeError:
1117+
# Not valid JSON - treat as plain string result
1118+
result = {"success": True, "result": content, "status": "completed"}
1119+
logger.debug(f"Tool result for {tool_call_id} is plain string, wrapped in result object")
11011120
else:
11021121
# Handle empty content as a success with empty result
1103-
result = {"success": True, "result": None}
1122+
result = {"success": True, "result": None, "status": "completed"}
11041123
logger.warning(f"Empty tool result content for tool call {tool_call_id}, using empty success result")
1105-
except json.JSONDecodeError as json_error:
1106-
# Handle invalid JSON by providing detailed error result
1107-
result = {
1108-
"error": f"Invalid JSON in tool result: {str(json_error)}",
1109-
"raw_content": content,
1110-
"error_type": "JSON_DECODE_ERROR",
1111-
"line": getattr(json_error, 'lineno', None),
1112-
"column": getattr(json_error, 'colno', None)
1113-
}
1114-
logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}")
1124+
except Exception as e:
1125+
# Handle any other error
1126+
result = {"success": True, "result": str(content) if content else None, "status": "completed"}
1127+
logger.warning(f"Error processing tool result for {tool_call_id}: {e}, using string fallback")
11151128

11161129
updated_function_response_part = types.Part(
11171130
function_response=types.FunctionResponse(
@@ -1160,21 +1173,23 @@ async def _run_adk_in_background(
11601173

11611174
logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}")
11621175

1176+
# Parse content - try JSON first, fall back to plain string
11631177
try:
11641178
if content and content.strip():
1165-
result = json.loads(content)
1179+
# Try to parse as JSON first
1180+
try:
1181+
result = json.loads(content)
1182+
except json.JSONDecodeError:
1183+
# Not valid JSON - treat as plain string result
1184+
result = {"success": True, "result": content, "status": "completed"}
1185+
logger.debug(f"Tool result for {tool_call_id} is plain string, wrapped in result object")
11661186
else:
1167-
result = {"success": True, "result": None}
1187+
result = {"success": True, "result": None, "status": "completed"}
11681188
logger.warning(f"Empty tool result content for tool call {tool_call_id}, using empty success result")
1169-
except json.JSONDecodeError as json_error:
1170-
result = {
1171-
"error": f"Invalid JSON in tool result: {str(json_error)}",
1172-
"raw_content": content,
1173-
"error_type": "JSON_DECODE_ERROR",
1174-
"line": getattr(json_error, 'lineno', None),
1175-
"column": getattr(json_error, 'colno', None)
1176-
}
1177-
logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}")
1189+
except Exception as e:
1190+
# Handle any other error
1191+
result = {"success": True, "result": str(content) if content else None, "status": "completed"}
1192+
logger.warning(f"Error processing tool result for {tool_call_id}: {e}, using string fallback")
11781193

11791194
updated_function_response_part = types.Part(
11801195
function_response=types.FunctionResponse(
@@ -1293,6 +1308,7 @@ async def _run_adk_in_background(
12931308
# hard stop the execution if we find any long running tool
12941309
if is_long_running_tool:
12951310
return
1311+
12961312
# Force close any streaming messages
12971313
async for ag_ui_event in event_translator.force_close_streaming_message():
12981314
await event_queue.put(ag_ui_event)

0 commit comments

Comments
 (0)