|
2 | 2 | import json |
3 | 3 | import os |
4 | 4 | import uuid |
5 | | -from typing import Any, Callable, Dict, overload |
| 5 | +from typing import Any, Callable, Dict, Optional, Type, overload |
6 | 6 |
|
| 7 | +from llama_index.core.agent.workflow import BaseWorkflowAgent |
7 | 8 | from llama_index.core.workflow import ( |
8 | 9 | HumanResponseEvent, |
9 | 10 | InputRequiredEvent, |
|
15 | 16 | get_steps_from_class, |
16 | 17 | get_steps_from_instance, |
17 | 18 | ) |
| 19 | +from pydantic import BaseModel |
18 | 20 | from uipath._cli._utils._console import ConsoleLogger |
19 | 21 | from uipath._cli._utils._parse_ast import generate_bindings_json # type: ignore |
20 | 22 | from uipath._cli.middlewares import MiddlewareResult |
@@ -71,18 +73,46 @@ def generate_schema_from_workflow(workflow: Workflow) -> Dict[str, Any]: |
71 | 73 |
|
72 | 74 | # Generate input schema from StartEvent using Pydantic's schema method |
73 | 75 | try: |
74 | | - input_schema = start_event_class.model_json_schema() |
75 | | - # Resolve references and handle nullable types |
76 | | - input_schema = resolve_refs(input_schema) |
77 | | - schema["input"]["properties"] = process_nullable_types( |
78 | | - input_schema.get("properties", {}) |
79 | | - ) |
80 | | - schema["input"]["required"] = input_schema.get("required", []) |
| 76 | + if isinstance(workflow, BaseWorkflowAgent): |
| 77 | + # For workflow agents, define a simple schema with just user_msg |
| 78 | + schema["input"] = { |
| 79 | + "type": "object", |
| 80 | + "properties": { |
| 81 | + "user_msg": { |
| 82 | + "type": "string", |
| 83 | + "title": "User Message", |
| 84 | + "description": "The user's question or request", |
| 85 | + } |
| 86 | + }, |
| 87 | + "required": ["user_msg"], |
| 88 | + } |
| 89 | + else: |
| 90 | + input_schema = start_event_class.model_json_schema() |
| 91 | + # Resolve references and handle nullable types |
| 92 | + input_schema = resolve_refs(input_schema) |
| 93 | + schema["input"]["properties"] = process_nullable_types( |
| 94 | + input_schema.get("properties", {}) |
| 95 | + ) |
| 96 | + schema["input"]["required"] = input_schema.get("required", []) |
81 | 97 | except (AttributeError, Exception): |
82 | 98 | pass |
83 | 99 |
|
84 | | - # For output schema, check if it's the base StopEvent or a custom subclass |
85 | | - if stop_event_class is StopEvent: |
| 100 | + # Handle output schema - check if it's a workflow agent with output_cls first |
| 101 | + if isinstance(workflow, BaseWorkflowAgent): |
| 102 | + output_cls: Optional[Type[BaseModel]] = getattr(workflow, "output_cls", None) |
| 103 | + if output_cls is not None: |
| 104 | + try: |
| 105 | + output_schema = output_cls.model_json_schema() |
| 106 | + # Resolve references and handle nullable types |
| 107 | + output_schema = resolve_refs(output_schema) |
| 108 | + schema["output"]["properties"] = process_nullable_types( |
| 109 | + output_schema.get("properties", {}) |
| 110 | + ) |
| 111 | + schema["output"]["required"] = output_schema.get("required", []) |
| 112 | + except (AttributeError, Exception): |
| 113 | + pass |
| 114 | + # Check if it's the base StopEvent or a custom subclass |
| 115 | + elif stop_event_class is StopEvent: |
86 | 116 | # base StopEvent |
87 | 117 | schema["output"] = { |
88 | 118 | "type": "object", |
|
0 commit comments