Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/openai/resources/vector_stores/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import time
from typing import TYPE_CHECKING, Dict, Union, Optional
from typing_extensions import Literal, assert_never

Expand Down Expand Up @@ -331,6 +332,7 @@ def create_and_poll(
vector_store_id: str,
attributes: Optional[Dict[str, Union[str, float, bool]]] | Omit = omit,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
chunking_strategy: FileChunkingStrategyParam | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand All @@ -355,6 +357,7 @@ def create_and_poll(
file_id,
vector_store_id=vector_store_id,
poll_interval_ms=poll_interval_ms,
max_wait_seconds=max_wait_seconds,
)

def poll(
Expand All @@ -363,12 +366,17 @@ def poll(
*,
vector_store_id: str,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
) -> VectorStoreFile:
"""Wait for the vector store file to finish processing.

Note: this will return even if the file failed to process, you need to check
file.last_error and file.status to handle these cases
"""
if is_given(max_wait_seconds) and max_wait_seconds < 0:
raise ValueError("Expected a non-negative value for `max_wait_seconds`")

start = time.monotonic()
headers: dict[str, str] = {"X-Stainless-Poll-Helper": "true"}
if is_given(poll_interval_ms):
headers["X-Stainless-Custom-Poll-Interval"] = str(poll_interval_ms)
Expand All @@ -389,7 +397,16 @@ def poll(
else:
poll_interval_ms = 1000

self._sleep(poll_interval_ms / 1000)
sleep_seconds = poll_interval_ms / 1000
if is_given(max_wait_seconds):
remaining = max_wait_seconds - (time.monotonic() - start)
if remaining <= 0:
raise TimeoutError(
f"Timed out waiting for vector store file {file_id!r} to finish processing"
)
sleep_seconds = min(sleep_seconds, remaining)

self._sleep(sleep_seconds)
elif file.status == "cancelled" or file.status == "completed" or file.status == "failed":
return file
else:
Expand Down Expand Up @@ -420,6 +437,7 @@ def upload_and_poll(
file: FileTypes,
attributes: Optional[Dict[str, Union[str, float, bool]]] | Omit = omit,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
chunking_strategy: FileChunkingStrategyParam | Omit = omit,
) -> VectorStoreFile:
"""Add a file to a vector store and poll until processing is complete."""
Expand All @@ -429,6 +447,7 @@ def upload_and_poll(
file_id=file_obj.id,
chunking_strategy=chunking_strategy,
poll_interval_ms=poll_interval_ms,
max_wait_seconds=max_wait_seconds,
attributes=attributes,
)

Expand Down Expand Up @@ -785,6 +804,7 @@ async def create_and_poll(
vector_store_id: str,
attributes: Optional[Dict[str, Union[str, float, bool]]] | Omit = omit,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
chunking_strategy: FileChunkingStrategyParam | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand All @@ -809,6 +829,7 @@ async def create_and_poll(
file_id,
vector_store_id=vector_store_id,
poll_interval_ms=poll_interval_ms,
max_wait_seconds=max_wait_seconds,
)

async def poll(
Expand All @@ -817,12 +838,17 @@ async def poll(
*,
vector_store_id: str,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
) -> VectorStoreFile:
"""Wait for the vector store file to finish processing.

Note: this will return even if the file failed to process, you need to check
file.last_error and file.status to handle these cases
"""
if is_given(max_wait_seconds) and max_wait_seconds < 0:
raise ValueError("Expected a non-negative value for `max_wait_seconds`")

start = time.monotonic()
headers: dict[str, str] = {"X-Stainless-Poll-Helper": "true"}
if is_given(poll_interval_ms):
headers["X-Stainless-Custom-Poll-Interval"] = str(poll_interval_ms)
Expand All @@ -843,7 +869,16 @@ async def poll(
else:
poll_interval_ms = 1000

await self._sleep(poll_interval_ms / 1000)
sleep_seconds = poll_interval_ms / 1000
if is_given(max_wait_seconds):
remaining = max_wait_seconds - (time.monotonic() - start)
if remaining <= 0:
raise TimeoutError(
f"Timed out waiting for vector store file {file_id!r} to finish processing"
)
sleep_seconds = min(sleep_seconds, remaining)

await self._sleep(sleep_seconds)
elif file.status == "cancelled" or file.status == "completed" or file.status == "failed":
return file
else:
Expand Down Expand Up @@ -876,6 +911,7 @@ async def upload_and_poll(
file: FileTypes,
attributes: Optional[Dict[str, Union[str, float, bool]]] | Omit = omit,
poll_interval_ms: int | Omit = omit,
max_wait_seconds: float | Omit = omit,
chunking_strategy: FileChunkingStrategyParam | Omit = omit,
) -> VectorStoreFile:
"""Add a file to a vector store and poll until processing is complete."""
Expand All @@ -884,6 +920,7 @@ async def upload_and_poll(
vector_store_id=vector_store_id,
file_id=file_obj.id,
poll_interval_ms=poll_interval_ms,
max_wait_seconds=max_wait_seconds,
chunking_strategy=chunking_strategy,
attributes=attributes,
)
Expand Down
57 changes: 57 additions & 0 deletions tests/api_resources/vector_stores/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


def make_vector_store_file(status: str) -> VectorStoreFile:
return VectorStoreFile(
id="file-abc123",
created_at=123,
last_error=None,
object="vector_store.file",
status=status, # type: ignore[arg-type]
usage_bytes=0,
vector_store_id="vs_abc123",
)


class FakeVectorStoreFileResponse:
headers: dict[str, str] = {}

def __init__(self, file: VectorStoreFile) -> None:
self._file = file

def parse(self) -> VectorStoreFile:
return self._file


class TestFiles:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])

Expand Down Expand Up @@ -322,6 +344,22 @@ def test_path_params_content(self, client: OpenAI) -> None:
vector_store_id="vs_abc123",
)

@parametrize
def test_poll_timeout(self, client: OpenAI, monkeypatch: pytest.MonkeyPatch) -> None:
def retrieve(*_args: Any, **_kwargs: Any) -> FakeVectorStoreFileResponse:
return FakeVectorStoreFileResponse(make_vector_store_file("in_progress"))

monkeypatch.setattr(client.vector_stores.files.with_raw_response, "retrieve", retrieve)
monkeypatch.setattr(client.vector_stores.files, "_sleep", lambda _: None)

with pytest.raises(TimeoutError, match="Timed out waiting for vector store file"):
client.vector_stores.files.poll(
"file-abc123",
vector_store_id="vs_abc123",
poll_interval_ms=1,
max_wait_seconds=0,
)


class TestAsyncFiles:
parametrize = pytest.mark.parametrize(
Expand Down Expand Up @@ -627,6 +665,25 @@ async def test_path_params_content(self, async_client: AsyncOpenAI) -> None:
vector_store_id="vs_abc123",
)

@parametrize
async def test_poll_timeout(self, async_client: AsyncOpenAI, monkeypatch: pytest.MonkeyPatch) -> None:
async def retrieve(*_args: Any, **_kwargs: Any) -> FakeVectorStoreFileResponse:
return FakeVectorStoreFileResponse(make_vector_store_file("in_progress"))

async def sleep(_: float) -> None:
return None

monkeypatch.setattr(async_client.vector_stores.files.with_raw_response, "retrieve", retrieve)
monkeypatch.setattr(async_client.vector_stores.files, "_sleep", sleep)

with pytest.raises(TimeoutError, match="Timed out waiting for vector store file"):
await async_client.vector_stores.files.poll(
"file-abc123",
vector_store_id="vs_abc123",
poll_interval_ms=1,
max_wait_seconds=0,
)


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
Expand Down