Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 97 additions & 173 deletions python/packages/core/agent_framework/_workflows/_agent.py

Large diffs are not rendered by default.

39 changes: 35 additions & 4 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,15 +429,30 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
await ctx.yield_output(response)

# Handle any user input requests
if response.user_input_requests:
user_input_request_count = len(response.user_input_requests)
total_message_content_count = sum(len(msg.contents) for msg in response.messages)
if user_input_request_count != total_message_content_count:
logger.warning(
"Response %s contains %d user input requests but total message contents are %d. "
"This indicates the response contains both user input requests and message contents. "
"Double check if this is the intended behavior, as non user input request contents in "
"this response will not be emitted.",
response.response_id,
user_input_request_count,
total_message_content_count,
)
for user_input_request in response.user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
await ctx.request_info(user_input_request, Content)
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
Comment thread
moonbox3 marked this conversation as resolved.
return None

# Only yield output if the response is complete and not waiting for user input.
# This is to avoid emitting two events of different types ('output' and 'request_info')
# that carry the same payload.
Comment thread
TaoChenOSU marked this conversation as resolved.
await ctx.yield_output(response)
return response
Comment thread
TaoChenOSU marked this conversation as resolved.

async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUpdate]) -> AgentResponse | None:
Expand Down Expand Up @@ -472,9 +487,25 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
)
async for update in stream:
updates.append(update)
await ctx.yield_output(update)
if update.user_input_requests:
user_input_request_count = len(update.user_input_requests)
total_message_content_count = len(update.contents)
if user_input_request_count != total_message_content_count:
logger.warning(
"Response update %s contains %d user input requests but total message contents are %d. "
"This indicates the response update contains both user input requests and message contents. "
"Double check if this is the intended behavior, as non user input request contents will "
"not be emitted.",
update.response_id,
user_input_request_count,
total_message_content_count,
)
streamed_user_input_requests.extend(update.user_input_requests)
else:
# Only yield output events for updates that do not contain user input requests.
# This is to avoid emitting two events of different types ('output' and 'request_info')
# that carry the same payload.
await ctx.yield_output(update)
Comment thread
TaoChenOSU marked this conversation as resolved.
Comment thread
moonbox3 marked this conversation as resolved.

Comment thread
TaoChenOSU marked this conversation as resolved.
# Prefer stream finalization when available so result hooks run
# (e.g., thread conversation updates). Fall back to reconstructing from updates
Expand Down Expand Up @@ -509,7 +540,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
await ctx.request_info(user_input_request, Content)
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
return None

return response
Expand Down
29 changes: 25 additions & 4 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def __init__(
# Flag to prevent concurrent workflow executions
self._is_running = False

# Current run-level status of this workflow instance. Updated in lockstep with
# the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE
# for a freshly built workflow that has not yet been run.
self._status: WorkflowRunState = WorkflowRunState.IDLE

@property
def status(self) -> WorkflowRunState:
"""Return the current run-level status of this workflow instance.

Mirrors the most recent status event emitted by the workflow. Safe to read at
any time: workflows run on a single asyncio event loop, and the underlying
attribute is a single enum reference whose assignment is atomic under the
CPython GIL, so no locking is required.
"""
return self._status

def _ensure_not_running(self) -> None:
"""Ensure the workflow is not already running."""
if self._is_running:
Expand Down Expand Up @@ -513,8 +529,9 @@ async def _run_workflow_with_tracing(
with _framework_event_origin():
started = WorkflowEvent.started()
yield started # noqa: RUF070
self._status = WorkflowRunState.IN_PROGRESS
with _framework_event_origin():
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
in_progress = WorkflowEvent.status(self._status)
yield in_progress # noqa: RUF070

# Per-run reset for fresh-message runs only. We deliberately
Expand Down Expand Up @@ -569,17 +586,20 @@ async def _run_workflow_with_tracing(

if event.type == "request_info" and not emitted_in_progress_pending:
emitted_in_progress_pending = True
self._status = WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
with _framework_event_origin():
pending_status = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)
pending_status = WorkflowEvent.status(self._status)
yield pending_status # noqa: RUF070
# Workflow runs until idle - emit final status based on whether requests are pending
if saw_request:
self._status = WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
with _framework_event_origin():
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)
terminal_status = WorkflowEvent.status(self._status)
yield terminal_status
else:
self._status = WorkflowRunState.IDLE
with _framework_event_origin():
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE)
terminal_status = WorkflowEvent.status(self._status)
yield terminal_status

span.add_event(OtelAttr.WORKFLOW_COMPLETED)
Expand All @@ -593,6 +613,7 @@ async def _run_workflow_with_tracing(
with _framework_event_origin():
failed_event = WorkflowEvent.failed(details)
yield failed_event # noqa: RUF070
self._status = WorkflowRunState.FAILED
with _framework_event_origin():
failed_status = WorkflowEvent.status(WorkflowRunState.FAILED)
yield failed_status # noqa: RUF070
Expand Down
11 changes: 5 additions & 6 deletions python/packages/core/tests/core/test_observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
prepend_agent_framework_to_user_agent,
tool,
)
from agent_framework._serialization import make_json_safe
from agent_framework.observability import (
ROLE_EVENT_MAP,
AgentTelemetryLayer,
Expand Down Expand Up @@ -3195,17 +3196,15 @@ def test_capture_messages_with_prepared_request_info_function_call_arguments(spa

from opentelemetry import trace

from agent_framework import WorkflowAgent

@dataclasses.dataclass
class HandoffRequest:
target_agent: str
reason: str

arguments = WorkflowAgent.RequestInfoFunctionArgs(
request_id="call_dc",
data=HandoffRequest(target_agent="helper", reason="overflow"),
).to_dict()
arguments = {
"request_id": "call_dc",
"data": make_json_safe(HandoffRequest(target_agent="helper", reason="overflow")),
}
msg = Message(
role="assistant",
contents=[
Expand Down
168 changes: 168 additions & 0 deletions python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,171 @@ async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_g
resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {}


# region Tool approval emission


class _ApprovalEmittingAgent(BaseAgent):
"""Agent that returns a single ``function_approval_request`` Content.

Used to verify that ``AgentExecutor`` does *not* surface the approval
payload via both an ``output`` event and a ``request_info`` event in the
same superstep — only the ``request_info`` event must carry it.
"""

def __init__(
self,
*,
approval_request_id: str = "apr_1",
tool_name: str = "delete_file",
tool_arguments: dict[str, Any] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self._approval_request_id = approval_request_id
self._tool_name = tool_name
self._tool_arguments: dict[str, Any] = tool_arguments or {"path": "/tmp/secret.txt"}
self.run_count = 0

def _build_approval_content(self) -> Content:
function_call = Content.from_function_call(
call_id=self._approval_request_id,
name=self._tool_name,
arguments=self._tool_arguments,
)
return Content.from_function_approval_request(id=self._approval_request_id, function_call=function_call)

@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...

@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.run_count += 1
approval = self._build_approval_content()

if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[approval], role="assistant")

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [approval])])

return _run()


def _has_approval_payload(event: WorkflowEvent[Any]) -> bool:
"""Return True if the event's data carries a ``function_approval_request`` content."""
data: Any = event.data

def _contents_of(value: Any) -> list[Content]:
if isinstance(value, AgentResponseUpdate):
return list(value.contents)
if isinstance(value, AgentResponse):
return [c for m in value.messages for c in m.contents]
if isinstance(value, AgentExecutorResponse):
return [c for m in value.agent_response.messages for c in m.contents]
if isinstance(value, Message):
return list(value.contents)
if isinstance(value, Content):
return [value]
return []

return any(c.type == "function_approval_request" for c in _contents_of(data))


async def test_agent_executor_does_not_double_emit_approval_non_streaming() -> None:
"""Non-streaming: approval payload must only appear in the ``request_info`` event.

Regression test for the bug where ``AgentExecutor._run_agent`` first
``yield_output``-ed the response (carrying the approval Content) and then
additionally emitted a ``request_info`` event for the same payload.
"""
agent = _ApprovalEmittingAgent(id="approve_agent", name="ApproveAgent", approval_request_id="apr_ns_1")
executor = AgentExecutor(agent, id="approve_exec")
workflow = WorkflowBuilder(start_executor=executor).build()

request_info_events: list[WorkflowEvent[Any]] = []
output_events: list[WorkflowEvent[Any]] = []

for event in await workflow.run("please delete it"):
if event.type == "request_info":
request_info_events.append(event)
elif event.type == "output":
output_events.append(event)

assert len(request_info_events) == 1
assert _has_approval_payload(request_info_events[0])
# The approval payload must not also be surfaced as a workflow output.
assert not any(_has_approval_payload(e) for e in output_events)
assert agent.run_count == 1


async def test_agent_executor_does_not_double_emit_approval_streaming() -> None:
"""Streaming: per-update approval payload must not be ``yield_output``-ed."""
agent = _ApprovalEmittingAgent(id="approve_agent_s", name="ApproveAgentS", approval_request_id="apr_st_1")
executor = AgentExecutor(agent, id="approve_exec_s")
workflow = WorkflowBuilder(start_executor=executor).build()

request_info_events: list[WorkflowEvent[Any]] = []
output_events: list[WorkflowEvent[Any]] = []

async for event in workflow.run("please delete it", stream=True):
if event.type == "request_info":
request_info_events.append(event)
elif event.type == "output":
output_events.append(event)

assert len(request_info_events) == 1
assert _has_approval_payload(request_info_events[0])
assert not any(_has_approval_payload(e) for e in output_events)
assert agent.run_count == 1


async def test_agent_executor_request_info_uses_user_input_request_id() -> None:
"""``ctx.request_info`` must register the request under the agent's approval id.

This makes the workflow's pending-request id round-trip with the
``function_approval_response.id`` the caller echoes back, so
``Workflow._send_responses_internal`` can look it up directly.
"""
agent = _ApprovalEmittingAgent(id="approve_agent_id", name="ApproveAgentId", approval_request_id="apr_match")
executor = AgentExecutor(agent, id="approve_exec_id")
workflow = WorkflowBuilder(start_executor=executor).build()

request_info_events: list[WorkflowEvent[Any]] = []
async for event in workflow.run("please delete it", stream=True):
if event.type == "request_info":
request_info_events.append(event)

assert len(request_info_events) == 1
assert request_info_events[0].request_id == "apr_match"


# endregion Tool approval emission
Loading
Loading