Skip to content

Commit c763d07

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: GenAI SDK client - Add get session call to create session sdk if an immediate success is returned
PiperOrigin-RevId: 881525769
1 parent 317bf40 commit c763d07

2 files changed

Lines changed: 17 additions & 6 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def test_create_session_with_ttl(client):
4242
assert operation.response.user_id == "test-user-123"
4343
assert operation.response.labels == {"label_key": "label_value"}
4444
assert operation.response.name.startswith(agent_engine.api_resource.name)
45+
assert operation.done
4546
# Expire time is calculated by the server, so we only check that it is
4647
# within a reasonable range to avoid flakiness.
4748
assert (
@@ -78,6 +79,7 @@ def test_create_session_with_expire_time(client):
7879
assert operation.response.user_id == "test-user-123"
7980
assert operation.response.name.startswith(agent_engine.api_resource.name)
8081
assert operation.response.expire_time == expire_time
82+
assert operation.done
8183
finally:
8284
# Clean up resources.
8385
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)

vertexai/_genai/sessions.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -651,13 +651,21 @@ def create(
651651
user_id=user_id,
652652
config=config,
653653
)
654-
if config.wait_for_completion and not operation.done:
655-
operation = _agent_engines_utils._await_operation(
656-
operation_name=operation.name,
657-
get_operation_fn=self._get_session_operation,
658-
poll_interval_seconds=0.5,
659-
)
654+
logger.info("Create session operation response: %s", operation)
655+
if config.wait_for_completion:
656+
logger.info("Wait for completion")
657+
if not operation.done:
658+
logger.info("Poll operation")
659+
operation = _agent_engines_utils._await_operation(
660+
operation_name=operation.name,
661+
get_operation_fn=self._get_session_operation,
662+
poll_interval_seconds=0.5,
663+
)
664+
logger.info("Poll operation done: %s", operation)
665+
# We need to make a call to get the session because the operation
666+
# response might not contain the relevant fields.
660667
if operation.response:
668+
logger.info("Get session from operation response")
661669
operation.response = self.get(name=operation.response.name)
662670
elif operation.error:
663671
raise RuntimeError(f"Failed to create session: {operation.error}")
@@ -666,6 +674,7 @@ def create(
666674
"Error retrieving session from the operation response. "
667675
f"Operation name: {operation.name}"
668676
)
677+
logger.info("Create session final operation: %s", operation)
669678
return operation
670679

671680
def list(

0 commit comments

Comments
 (0)