Skip to content

Commit 69320c9

Browse files
Merge pull request #780 from ag-ui-protocol/contextablemark/test770
Addressing issue #769
2 parents 77a1a75 + 614f9a5 commit 69320c9

File tree

4 files changed

+682
-58
lines changed

4 files changed

+682
-58
lines changed

integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py

Lines changed: 97 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,14 @@ def _create_runner(self, adk_agent: BaseAgent, user_id: str, app_name: str) -> R
350350

351351
async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
352352
"""Run the ADK agent with client-side tool support.
353-
353+
354354
All client-side tools are long-running. For tool result submissions,
355355
we continue existing executions. For new requests, we start new executions.
356356
ADK sessions handle conversation continuity and tool result processing.
357-
357+
358358
Args:
359359
input: The AG-UI run input
360-
360+
361361
Yields:
362362
AG-UI protocol events
363363
"""
@@ -374,6 +374,27 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
374374
app_name = self._get_app_name(input)
375375
skip_tool_message_batch = False
376376

377+
# Check if there are pending tool calls AND tool results in unseen messages
378+
has_pending_tools = await self._has_pending_tool_calls(input.thread_id)
379+
has_tool_results_in_unseen = any(getattr(msg, "role", None) == "tool" for msg in unseen_messages)
380+
381+
if has_pending_tools and has_tool_results_in_unseen:
382+
# HITL/Frontend tool scenario: skip to the tool results first
383+
for i, msg in enumerate(unseen_messages):
384+
if getattr(msg, "role", None) == "tool":
385+
# Mark all messages before the tool result as processed (they're already in the ADK session)
386+
skipped_ids = []
387+
for j in range(i):
388+
msg_id = getattr(unseen_messages[j], "id", None)
389+
if msg_id:
390+
skipped_ids.append(msg_id)
391+
if skipped_ids:
392+
self._session_manager.mark_messages_processed(app_name, input.thread_id, skipped_ids)
393+
index = i
394+
break
395+
396+
logger.debug(f"[RUN_LOOP] Starting message loop for thread={input.thread_id}, total_unseen={total_unseen}, starting_index={index}")
397+
377398
while index < total_unseen:
378399
current = unseen_messages[index]
379400
role = getattr(current, "role", None)
@@ -487,6 +508,38 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
487508
else:
488509
skip_tool_message_batch = False
489510

511+
# Check if there's an upcoming tool batch that will be skipped
512+
# If so, this non-tool batch is part of historical backend tool interaction
513+
# and should also be skipped
514+
upcoming_tool_batch_skipped = False
515+
if index < total_unseen and getattr(unseen_messages[index], "role", None) == "tool":
516+
# Peek at the upcoming tool batch
517+
peek_idx = index
518+
upcoming_tool_call_ids = []
519+
while peek_idx < total_unseen and getattr(unseen_messages[peek_idx], "role", None) == "tool":
520+
tool_call_id = getattr(unseen_messages[peek_idx], "tool_call_id", None)
521+
if tool_call_id:
522+
upcoming_tool_call_ids.append(tool_call_id)
523+
peek_idx += 1
524+
525+
if upcoming_tool_call_ids:
526+
pending_ids = await self._get_pending_tool_call_ids(input.thread_id)
527+
if pending_ids is not None:
528+
pending_set = set(pending_ids)
529+
# If NONE of the upcoming tool results match pending, they're historical
530+
if not any(tc_id in pending_set for tc_id in upcoming_tool_call_ids):
531+
upcoming_tool_batch_skipped = True
532+
533+
if upcoming_tool_batch_skipped:
534+
# Skip this message batch - it's part of historical backend tool interaction
535+
# Mark the messages as processed
536+
logger.debug(f"[RUN_LOOP] Skipping message batch (upcoming tool batch will be skipped)")
537+
batch_ids = self._collect_message_ids(message_batch)
538+
if batch_ids:
539+
self._session_manager.mark_messages_processed(app_name, input.thread_id, batch_ids)
540+
continue
541+
542+
logger.debug(f"[RUN_LOOP] Calling _start_new_execution with message_batch of {len(message_batch)} messages")
490543
async for event in self._start_new_execution(input, message_batch=message_batch):
491544
yield event
492545

@@ -621,24 +674,19 @@ async def _handle_tool_result_submission(
621674
return
622675

623676
try:
624-
# Remove tool calls from pending list
677+
# Remove tool calls from pending list and track which ones we processed
678+
processed_tool_ids = []
625679
for tool_result in tool_results:
626680
tool_call_id = tool_result['message'].tool_call_id
627681
has_pending = await self._has_pending_tool_calls(thread_id)
628682

629683
if has_pending:
630-
# Could add more specific check here for the exact tool_call_id
631-
# but for now just log that we're processing a tool result while tools are pending
632-
logger.debug(f"Processing tool result {tool_call_id} for thread {thread_id} with pending tools")
633684
# Remove from pending tool calls now that we're processing it
634685
await self._remove_pending_tool_call(thread_id, tool_call_id)
635-
else:
636-
# No pending tools - this could be a stale result or from a different session
637-
logger.warning(f"No pending tool calls found for tool result {tool_call_id} in thread {thread_id}")
686+
processed_tool_ids.append(tool_call_id)
638687

639688
# Since all tools are long-running, all tool results are standalone
640689
# and should start new executions with the tool results
641-
logger.info(f"Starting new execution for tool result in thread {thread_id}")
642690

643691
# Use trailing_messages if provided, otherwise fall back to candidate_messages
644692
message_batch = trailing_messages if trailing_messages else (candidate_messages if include_message_batch else None)
@@ -884,14 +932,11 @@ async def _start_new_execution(
884932
if input.thread_id in self._active_executions:
885933
execution = self._active_executions[input.thread_id]
886934
execution.is_complete = True
887-
935+
888936
# Check if session has pending tool calls before cleanup
889937
has_pending = await self._has_pending_tool_calls(input.thread_id)
890938
if not has_pending:
891939
del self._active_executions[input.thread_id]
892-
logger.debug(f"Cleaned up execution for thread {input.thread_id}")
893-
else:
894-
logger.info(f"Preserving execution for thread {input.thread_id} - has pending tool calls (HITL scenario)")
895940

896941
async def _start_background_execution(
897942
self,
@@ -1036,6 +1081,8 @@ async def _run_adk_in_background(
10361081
event_queue: Queue for emitting events
10371082
"""
10381083
runner: Optional[Runner] = None
1084+
logger.debug(f"[BG_EXEC] _run_adk_in_background called for thread={input.thread_id}")
1085+
logger.debug(f"[BG_EXEC] tool_results={len(tool_results) if tool_results else 0}, message_batch={len(message_batch) if message_batch else 0}")
10391086
try:
10401087
# Agent is already prepared with tools and SystemMessage instructions (if any)
10411088
# from _start_background_execution, so no additional agent copying needed here
@@ -1077,7 +1124,10 @@ async def _run_adk_in_background(
10771124
self._session_manager.mark_messages_processed(app_name, input.thread_id, message_ids)
10781125

10791126
# Convert user messages first (if any)
1080-
user_message = await self._convert_latest_message(input, unseen_messages) if message_batch else None
1127+
# Note: We pass unseen_messages which is already set from message_batch or _get_unseen_messages
1128+
# The original code had a bug: `if message_batch else None` would skip conversion when
1129+
# message_batch was None but unseen_messages contained valid user messages
1130+
user_message = await self._convert_latest_message(input, unseen_messages)
10811131

10821132
# if there is a tool response submission by the user, add FunctionResponse to session first
10831133
if active_tool_results and user_message:
@@ -1091,24 +1141,24 @@ async def _run_adk_in_background(
10911141
# Debug: Log the actual tool message content we received
10921142
logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}")
10931143

1094-
# Parse JSON content, handling empty or invalid JSON gracefully
1144+
# Parse content - try JSON first, fall back to plain string
10951145
try:
10961146
if content and content.strip():
1097-
result = json.loads(content)
1147+
# Try to parse as JSON first
1148+
try:
1149+
result = json.loads(content)
1150+
except json.JSONDecodeError:
1151+
# Not valid JSON - treat as plain string result
1152+
result = {"success": True, "result": content, "status": "completed"}
1153+
logger.debug(f"Tool result for {tool_call_id} is plain string, wrapped in result object")
10981154
else:
10991155
# Handle empty content as a success with empty result
1100-
result = {"success": True, "result": None}
1156+
result = {"success": True, "result": None, "status": "completed"}
11011157
logger.warning(f"Empty tool result content for tool call {tool_call_id}, using empty success result")
1102-
except json.JSONDecodeError as json_error:
1103-
# Handle invalid JSON by providing detailed error result
1104-
result = {
1105-
"error": f"Invalid JSON in tool result: {str(json_error)}",
1106-
"raw_content": content,
1107-
"error_type": "JSON_DECODE_ERROR",
1108-
"line": getattr(json_error, 'lineno', None),
1109-
"column": getattr(json_error, 'colno', None)
1110-
}
1111-
logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}")
1158+
except Exception as e:
1159+
# Handle any other error
1160+
result = {"success": True, "result": str(content) if content else None, "status": "completed"}
1161+
logger.warning(f"Error processing tool result for {tool_call_id}: {e}, using string fallback")
11121162

11131163
updated_function_response_part = types.Part(
11141164
function_response=types.FunctionResponse(
@@ -1157,21 +1207,23 @@ async def _run_adk_in_background(
11571207

11581208
logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}")
11591209

1210+
# Parse content - try JSON first, fall back to plain string
11601211
try:
11611212
if content and content.strip():
1162-
result = json.loads(content)
1213+
# Try to parse as JSON first
1214+
try:
1215+
result = json.loads(content)
1216+
except json.JSONDecodeError:
1217+
# Not valid JSON - treat as plain string result
1218+
result = {"success": True, "result": content, "status": "completed"}
1219+
logger.debug(f"Tool result for {tool_call_id} is plain string, wrapped in result object")
11631220
else:
1164-
result = {"success": True, "result": None}
1221+
result = {"success": True, "result": None, "status": "completed"}
11651222
logger.warning(f"Empty tool result content for tool call {tool_call_id}, using empty success result")
1166-
except json.JSONDecodeError as json_error:
1167-
result = {
1168-
"error": f"Invalid JSON in tool result: {str(json_error)}",
1169-
"raw_content": content,
1170-
"error_type": "JSON_DECODE_ERROR",
1171-
"line": getattr(json_error, 'lineno', None),
1172-
"column": getattr(json_error, 'colno', None)
1173-
}
1174-
logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}")
1223+
except Exception as e:
1224+
# Handle any other error
1225+
result = {"success": True, "result": str(content) if content else None, "status": "completed"}
1226+
logger.warning(f"Error processing tool result for {tool_call_id}: {e}, using string fallback")
11751227

11761228
updated_function_response_part = types.Part(
11771229
function_response=types.FunctionResponse(
@@ -1185,6 +1237,10 @@ async def _run_adk_in_background(
11851237
new_message = types.Content(parts=function_response_parts, role='user')
11861238
else:
11871239
# No tool results, just use the user message
1240+
# If user_message is None (e.g., unseen_messages was empty because all were
1241+
# already processed), fall back to extracting the latest user message from input.messages
1242+
if user_message is None and input.messages:
1243+
user_message = await self._convert_latest_message(input, input.messages)
11881244
new_message = user_message
11891245

11901246
# Create event translator
@@ -1286,6 +1342,7 @@ async def _run_adk_in_background(
12861342
# hard stop the execution if we find any long running tool
12871343
if is_long_running_tool:
12881344
return
1345+
12891346
# Force close any streaming messages
12901347
async for ag_ui_event in event_translator.force_close_streaming_message():
12911348
await event_queue.put(ag_ui_event)

integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -279,23 +279,10 @@ async def _translate_text_content(
279279
return
280280

281281
combined_text = "".join(text_parts)
282-
if not combined_text:
283-
return
284-
285-
# Use proper ADK streaming detection (handle None values)
286-
is_partial = getattr(adk_event, 'partial', False)
287-
turn_complete = getattr(adk_event, 'turn_complete', False)
288-
289-
# (is_final_response is already calculated above)
290-
291-
# Handle None values: if a turn is complete or a final chunk arrives, end streaming
292-
has_finish_reason = bool(getattr(adk_event, 'finish_reason', None))
293-
should_send_end = (
294-
(turn_complete and not is_partial)
295-
or (is_final_response and not is_partial)
296-
or (has_finish_reason and self._is_streaming)
297-
)
298282

283+
# Handle is_final_response BEFORE the empty text early return.
284+
# An empty final response is a valid stream-closing signal that must close
285+
# any active stream, even if there's no new text content.
299286
if is_final_response:
300287
# This is the final, complete message event.
301288

@@ -365,7 +352,23 @@ async def _translate_text_content(
365352
self._last_streamed_run_id = None
366353
return
367354

368-
355+
# Early return for empty text (non-final responses only).
356+
# Final responses with empty text are handled above to close active streams.
357+
if not combined_text:
358+
return
359+
360+
# Use proper ADK streaming detection (handle None values)
361+
is_partial = getattr(adk_event, 'partial', False)
362+
turn_complete = getattr(adk_event, 'turn_complete', False)
363+
364+
# Handle None values: if a turn is complete or a final chunk arrives, end streaming
365+
has_finish_reason = bool(getattr(adk_event, 'finish_reason', None))
366+
should_send_end = (
367+
(turn_complete and not is_partial)
368+
or (is_final_response and not is_partial)
369+
or (has_finish_reason and self._is_streaming)
370+
)
371+
369372
# Handle streaming logic (if not is_final_response)
370373
if not self._is_streaming:
371374
# Start of new message - emit START event

integrations/adk-middleware/python/src/ag_ui_adk/session_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,12 @@ async def update_session_state(
196196
# This depends on ADK's behavior - may need to explicitly clear
197197

198198
# Create event with state changes
199+
# Use "user" as author since state updates come from the frontend
200+
# Note: Using "system" causes ADK runner warnings in _find_agent_to_run
199201
actions = EventActions(state_delta=state_delta)
200202
event = Event(
201203
invocation_id=f"state_update_{int(time.time())}",
202-
author="system",
204+
author="user",
203205
actions=actions,
204206
timestamp=time.time()
205207
)

0 commit comments

Comments
 (0)