diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 8c5647e7de..e25d9ae8eb 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -382,6 +382,10 @@ class ListViewsResponse(IcebergBaseModel): identifiers: list[ListViewResponseEntry] = Field() +class LoadCredentialsResponse(IcebergBaseModel): + credentials: list[StorageCredential] = Field(alias="storage-credentials") + + _PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) @@ -469,11 +473,13 @@ def _resolve_storage_credentials(storage_credentials: list[StorageCredential], l return best_match.config if best_match else {} - def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO: + def _load_file_io( + self, properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None + ) -> FileIO: merged_properties = {**self.properties, **properties} if self._auth_manager: merged_properties[AUTH_MANAGER] = self._auth_manager - return load_file_io(merged_properties, location) + return load_file_io(merged_properties, location, session) def supports_server_side_planning(self) -> bool: """Check if the catalog supports server-side scan planning.""" @@ -820,6 +826,7 @@ def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response: io=self._load_file_io( {**table_response.metadata.properties, **table_response.config, **credential_config}, table_response.metadata_location, + self._session, ), catalog=self, config=table_response.config, @@ -837,6 +844,7 @@ def _response_to_staged_table(self, identifier_tuple: tuple[str, ...], table_res io=self._load_file_io( {**table_response.metadata.properties, **table_response.config, **credential_config}, table_response.metadata_location, + self._session, ), catalog=self, ) diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index 7dbc651214..aa4083cfc6 100644 --- a/pyiceberg/io/__init__.py +++ 
b/pyiceberg/io/__init__.py @@ -29,6 +29,7 @@ import logging import warnings from abc import ABC, abstractmethod +from datetime import datetime from io import SEEK_SET from types import TracebackType from typing import ( @@ -37,7 +38,11 @@ ) from urllib.parse import urlparse +from requests import HTTPError, Session + +from pyiceberg.exceptions import ValidationException from pyiceberg.typedef import EMPTY_DICT, Properties +from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int logger = logging.getLogger(__name__) @@ -67,6 +72,7 @@ S3_ROLE_SESSION_NAME = "s3.role-session-name" S3_FORCE_VIRTUAL_ADDRESSING = "s3.force-virtual-addressing" S3_RETRY_STRATEGY_IMPL = "s3.retry-strategy-impl" +S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms" HDFS_HOST = "hdfs.host" HDFS_PORT = "hdfs.port" HDFS_USER = "hdfs.user" @@ -99,6 +105,9 @@ GCS_VERSION_AWARE = "gcs.version-aware" HF_ENDPOINT = "hf.endpoint" HF_TOKEN = "hf.token" +CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint" +REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled" +CATALOG_URI = "uri" @runtime_checkable @@ -258,9 +267,11 @@ class FileIO(ABC): """A base class for FileIO implementations.""" properties: Properties + session: Session | None - def __init__(self, properties: Properties = EMPTY_DICT): + def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None): self.properties = properties + self.session = session @abstractmethod def new_input(self, location: str) -> InputFile: @@ -317,7 +328,7 @@ def delete(self, location: str | InputFile | OutputFile) -> None: } -def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None: +def _import_file_io(io_impl: str, properties: Properties, session: Session | None = None) -> FileIO | None: try: path_parts = io_impl.split(".") if len(path_parts) < 2: @@ -325,7 +336,7 @@ def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None: 
module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1] module = importlib.import_module(module_name) class_ = getattr(module, class_name) - return class_(properties) + return class_(properties, session) except ModuleNotFoundError: logger.warning(f"Could not initialize FileIO: {io_impl}", exc_info=logger.isEnabledFor(logging.DEBUG)) return None @@ -334,22 +345,22 @@ def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None: PY_IO_IMPL = "py-io-impl" -def _infer_file_io_from_scheme(path: str, properties: Properties) -> FileIO | None: +def _infer_file_io_from_scheme(path: str, properties: Properties, session: Session | None = None) -> FileIO | None: parsed_url = urlparse(path) if parsed_url.scheme: if file_ios := SCHEMA_TO_FILE_IO.get(parsed_url.scheme): for file_io_path in file_ios: - if file_io := _import_file_io(file_io_path, properties): + if file_io := _import_file_io(file_io_path, properties, session): return file_io else: warnings.warn(f"No preferred file implementation for scheme: {parsed_url.scheme}", stacklevel=2) return None -def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO: +def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None) -> FileIO: # First look for the py-io-impl property to directly load the class if io_impl := properties.get(PY_IO_IMPL): - if file_io := _import_file_io(io_impl, properties): + if file_io := _import_file_io(io_impl, properties, session): logger.info("Loaded FileIO: %s", io_impl) return file_io else: @@ -357,12 +368,12 @@ def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = Non # Check the table location if location: - if file_io := _infer_file_io_from_scheme(location, properties): + if file_io := _infer_file_io_from_scheme(location, properties, session): return file_io # Look at the schema of the warehouse if warehouse_location := properties.get(WAREHOUSE): - if file_io := 
_infer_file_io_from_scheme(warehouse_location, properties): + if file_io := _infer_file_io_from_scheme(warehouse_location, properties, session): + return file_io try: @@ -370,9 +381,98 @@ def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = Non logger.info("Defaulting to PyArrow FileIO") from pyiceberg.io.pyarrow import PyArrowFileIO - return PyArrowFileIO(properties) + return PyArrowFileIO(properties, session) except ModuleNotFoundError as e: raise ModuleNotFoundError( "Could not load a FileIO, please consider installing one: " 'pip3 install "pyiceberg[pyarrow]", for more options refer to the docs.' ) from e + + +def _extract_s3_credentials(properties: Properties) -> Properties: + """Extract only S3 credential keys from properties, normalizing AWS_ prefixes to S3_.""" + creds: Properties = {} + if access_key := get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID): + creds[S3_ACCESS_KEY_ID] = access_key + if secret_key := get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY): + creds[S3_SECRET_ACCESS_KEY] = secret_key + if session_token := get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN): + creds[S3_SESSION_TOKEN] = session_token + if expiry := get_first_property_value(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS): + creds[S3_SESSION_TOKEN_EXPIRES_AT_MS] = expiry + return creds + + +def _credential_from_properties(properties: Properties) -> Properties: + """Return S3 credentials from properties if they expire within 5 minutes (or are already expired); empty otherwise.""" + access_key = get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID) + secret_access_key = get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY) + session_token = get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN) + expiration_ms = property_as_int(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS) + + if not access_key or not secret_access_key or not session_token or not 
expiration_ms: + return EMPTY_DICT + + expires_at = datetime.fromtimestamp(expiration_ms / 1000) + seconds_until_expiry = (expires_at - datetime.now()).total_seconds() + + if seconds_until_expiry > 300: + return EMPTY_DICT + + return { + S3_ACCESS_KEY_ID: access_key, + S3_SECRET_ACCESS_KEY: secret_access_key, + S3_SESSION_TOKEN: session_token, + S3_SESSION_TOKEN_EXPIRES_AT_MS: expiration_ms, + } + + +def _credential_refresh_endpoint(properties: Properties) -> str: + """Build credential refresh endpoint from properties.""" + catalog_uri = get_first_property_value(properties, CATALOG_URI) + credentials_path = get_first_property_value(properties, CREDENTIALS_ENDPOINT) + + if catalog_uri is None: + raise ValidationException("Invalid catalog endpoint: None") + + if credentials_path is None: + raise ValidationException("Invalid credentials endpoint: None") + + return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/") + + +def _get_or_refresh_credentials(properties: Properties, session: Session | None) -> Properties: + """Retrieve current S3 credentials from properties, refreshing them if they are close to expiration.""" + refresh_enabled = property_as_bool(properties, REFRESH_CREDENTIALS_ENABLED, False) + if not refresh_enabled or session is None: + return _extract_s3_credentials(properties) + + # Returns empty if credentials missing or not yet expiring + creds = _credential_from_properties(properties) + + if not creds: + return _extract_s3_credentials(properties) + + from pyiceberg.catalog.rest import LoadCredentialsResponse + from pyiceberg.catalog.rest.response import _handle_non_200_response + + load_response: LoadCredentialsResponse | None = None + + try: + http_response = session.get(_credential_refresh_endpoint(properties)) + http_response.raise_for_status() + load_response = LoadCredentialsResponse.model_validate_json(http_response.text) + except HTTPError as exc: + _handle_non_200_response(exc, {}) + + if load_response is None: + raise ValidationException("Load
credential response is None") + + if not load_response.credentials: + raise ValueError("Invalid S3 Credentials: empty") + + if len(load_response.credentials) > 1: + raise ValueError("Invalid S3 Credentials: only one S3 credential should exist") + + credentials = load_response.credentials[0].config + return _extract_s3_credentials(credentials) diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index 63ec55bab4..bbb9034641 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -34,7 +34,7 @@ import requests from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem -from requests import HTTPError +from requests import HTTPError, Session from pyiceberg.catalog import TOKEN, URI from pyiceberg.catalog.rest.auth import AUTH_MANAGER @@ -88,6 +88,7 @@ InputStream, OutputFile, OutputStream, + _get_or_refresh_credentials, ) from pyiceberg.typedef import Properties from pyiceberg.types import strtobool @@ -165,14 +166,16 @@ def _file(_: Properties) -> LocalFileSystem: return LocalFileSystem(auto_mkdir=True) -def _s3(properties: Properties) -> AbstractFileSystem: +def _s3(properties: Properties, session: Session | None = None) -> AbstractFileSystem: from s3fs import S3FileSystem + creds = _get_or_refresh_credentials(properties, session) + client_kwargs = { "endpoint_url": properties.get(S3_ENDPOINT), - "aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "aws_secret_access_key": get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "aws_session_token": get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "aws_access_key_id": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "aws_secret_access_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "aws_session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), "region_name": 
get_first_property_value(properties, S3_REGION, AWS_REGION), } config_kwargs = {} @@ -318,6 +321,7 @@ def _hf(properties: Properties) -> AbstractFileSystem: } _ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"}) +_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"}) class FsspecInputFile(InputFile): @@ -419,10 +423,10 @@ def to_input_file(self) -> FsspecInputFile: class FsspecFileIO(FileIO): """A FileIO implementation that uses fsspec.""" - def __init__(self, properties: Properties): + def __init__(self, properties: Properties, session: Session | None = None): self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS) self._thread_locals = threading.local() - super().__init__(properties=properties) + super().__init__(properties=properties, session=session) def new_input(self, location: str) -> FsspecInputFile: """Get an FsspecInputFile instance to read bytes from the file at the given location. @@ -488,6 +492,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste if scheme in _ADLS_SCHEMES: return _adls(self.properties, hostname) + if scheme in _S3_SCHEMES: + return _s3(self.properties, self.session) + return self._scheme_to_fs[scheme](self.properties) def __getstate__(self) -> dict[str, Any]: diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4517ae7327..95e75867ef 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -63,6 +63,7 @@ FileSystem, FileType, ) +from requests import Session from pyiceberg.conversions import to_bytes from pyiceberg.exceptions import ResolveError @@ -119,6 +120,7 @@ InputStream, OutputFile, OutputStream, + _get_or_refresh_credentials, ) from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics from pyiceberg.manifest import ( @@ -386,9 +388,9 @@ def to_input_file(self) -> PyArrowFile: class PyArrowFileIO(FileIO): fs_by_scheme: Callable[[str, str | None], FileSystem] - def __init__(self, properties: Properties = EMPTY_DICT): + 
def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None): self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs) - super().__init__(properties=properties) + super().__init__(properties=properties, session=session) @staticmethod def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]: @@ -433,11 +435,13 @@ def _initialize_fs(self, scheme: str, netloc: str | None = None) -> FileSystem: def _initialize_oss_fs(self) -> FileSystem: from pyarrow.fs import S3FileSystem + creds = _get_or_refresh_credentials(self.properties, self.session) + client_kwargs: dict[str, Any] = { "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), "region": get_first_property_value(self.properties, S3_REGION, AWS_REGION), "force_virtual_addressing": property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, True), } @@ -480,11 +484,13 @@ def _initialize_s3_fs(self, netloc: str | None) -> FileSystem: else: bucket_region = provided_region + creds = _get_or_refresh_credentials(self.properties, self.session) + client_kwargs: dict[str, Any] = { "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": 
get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), "region": bucket_region, } diff --git a/tests/io/test_io.py b/tests/io/test_io.py index d9bee33f8b..b60e3443f0 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -15,21 +15,38 @@ # specific language governing permissions and limitations # under the License. +import json import os import pickle import tempfile +from datetime import datetime, timedelta from typing import Any +from unittest.mock import MagicMock import pytest +from requests import HTTPError +from requests.models import Response +from pyiceberg.exceptions import ServerError, ValidationException from pyiceberg.io import ( ARROW_FILE_IO, + CATALOG_URI, + CREDENTIALS_ENDPOINT, PY_IO_IMPL, + REFRESH_CREDENTIALS_ENABLED, + S3_ACCESS_KEY_ID, + S3_SECRET_ACCESS_KEY, + S3_SESSION_TOKEN, + S3_SESSION_TOKEN_EXPIRES_AT_MS, + _credential_from_properties, + _credential_refresh_endpoint, + _get_or_refresh_credentials, _import_file_io, _infer_file_io_from_scheme, load_file_io, ) from pyiceberg.io.pyarrow import PyArrowFileIO +from pyiceberg.typedef import Properties def test_custom_local_input_file() -> None: @@ -339,3 +356,203 @@ def test_infer_file_io_from_schema_unknown() -> None: _infer_file_io_from_scheme("unknown://bucket/path/", {}) assert str(w[0].message) == "No preferred file implementation for scheme: unknown" + + +def _expiry_ms(delta_seconds: int) -> int: + return int((datetime.now() + timedelta(seconds=delta_seconds)).timestamp() * 1000) + + +def _full_cred_props(expiry_ms: int) -> Properties: + return { + S3_ACCESS_KEY_ID: "AKID", + S3_SECRET_ACCESS_KEY: "SECRET", + S3_SESSION_TOKEN: "TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(expiry_ms), + 
CATALOG_URI: "https://catalog.example.com", + CREDENTIALS_ENDPOINT: "v1/credentials", + REFRESH_CREDENTIALS_ENABLED: "true", + } + + +def _make_session(response_body: Properties | None = None, status_code: int = 200) -> MagicMock: + session = MagicMock() + mock_resp = MagicMock() + mock_resp.text = json.dumps(response_body or {}) + mock_resp.raise_for_status.return_value = None + session.get.return_value = mock_resp + return session + + +def _make_http_error(status_code: int) -> HTTPError: + response = Response() + response.status_code = status_code + response._content = b'{"error": {"message": "server error", "type": "ServerError", "code": 500}}' + exc = HTTPError(response=response) + return exc + + +def test_credential_from_properties_missing_access_key() -> None: + props = {S3_SECRET_ACCESS_KEY: "SECRET", S3_SESSION_TOKEN: "TOKEN", S3_SESSION_TOKEN_EXPIRES_AT_MS: str(_expiry_ms(200))} + assert _credential_from_properties(props) == {} + + +def test_credential_from_properties_missing_secret_key() -> None: + props = {S3_ACCESS_KEY_ID: "AKID", S3_SESSION_TOKEN: "TOKEN", S3_SESSION_TOKEN_EXPIRES_AT_MS: str(_expiry_ms(200))} + assert _credential_from_properties(props) == {} + + +def test_credential_from_properties_missing_session_token() -> None: + props = {S3_ACCESS_KEY_ID: "AKID", S3_SECRET_ACCESS_KEY: "SECRET", S3_SESSION_TOKEN_EXPIRES_AT_MS: str(_expiry_ms(200))} + assert _credential_from_properties(props) == {} + + +def test_credential_from_properties_missing_expiry() -> None: + props = {S3_ACCESS_KEY_ID: "AKID", S3_SECRET_ACCESS_KEY: "SECRET", S3_SESSION_TOKEN: "TOKEN"} + assert _credential_from_properties(props) == {} + + +def test_credential_from_properties_not_expiring_soon() -> None: + props = { + S3_ACCESS_KEY_ID: "AKID", + S3_SECRET_ACCESS_KEY: "SECRET", + S3_SESSION_TOKEN: "TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(_expiry_ms(600)), + } + assert _credential_from_properties(props) == {} + + +def test_credential_from_properties_expiring_soon() -> None: 
+ expiry = _expiry_ms(200) + props = { + S3_ACCESS_KEY_ID: "AKID", + S3_SECRET_ACCESS_KEY: "SECRET", + S3_SESSION_TOKEN: "TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(expiry), + } + result = _credential_from_properties(props) + assert result[S3_ACCESS_KEY_ID] == "AKID" + assert result[S3_SECRET_ACCESS_KEY] == "SECRET" + assert result[S3_SESSION_TOKEN] == "TOKEN" + assert result[S3_SESSION_TOKEN_EXPIRES_AT_MS] == expiry + + +def test_credential_from_properties_already_expired() -> None: + expiry = _expiry_ms(-60) + props = { + S3_ACCESS_KEY_ID: "AKID", + S3_SECRET_ACCESS_KEY: "SECRET", + S3_SESSION_TOKEN: "TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(expiry), + } + result = _credential_from_properties(props) + assert result[S3_ACCESS_KEY_ID] == "AKID" + + +def test_credential_refresh_endpoint_missing_uri() -> None: + with pytest.raises(ValidationException, match="Invalid catalog endpoint"): + _credential_refresh_endpoint({CREDENTIALS_ENDPOINT: "v1/creds"}) + + +def test_credential_refresh_endpoint_missing_path() -> None: + with pytest.raises(ValidationException, match="Invalid credentials endpoint"): + _credential_refresh_endpoint({CATALOG_URI: "https://catalog.example.com"}) + + +def test_credential_refresh_endpoint_trailing_slash_handling() -> None: + props = {CATALOG_URI: "https://catalog.example.com/", CREDENTIALS_ENDPOINT: "/v1/creds"} + assert _credential_refresh_endpoint(props) == "https://catalog.example.com/v1/creds" + + +def test_credential_refresh_endpoint_no_slash() -> None: + props = {CATALOG_URI: "https://catalog.example.com", CREDENTIALS_ENDPOINT: "v1/creds"} + assert _credential_refresh_endpoint(props) == "https://catalog.example.com/v1/creds" + + +def _expected_s3_creds(expiry_ms: int) -> Properties: + return { + S3_ACCESS_KEY_ID: "AKID", + S3_SECRET_ACCESS_KEY: "SECRET", + S3_SESSION_TOKEN: "TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(expiry_ms), + } + + +def test_get_or_refresh_credentials_disabled() -> None: + expiry = _expiry_ms(200) + props 
= _full_cred_props(expiry) + props[REFRESH_CREDENTIALS_ENABLED] = "false" + assert _get_or_refresh_credentials(props, MagicMock()) == _expected_s3_creds(expiry) + + +def test_get_or_refresh_credentials_no_session() -> None: + expiry = _expiry_ms(200) + props = _full_cred_props(expiry) + assert _get_or_refresh_credentials(props, None) == _expected_s3_creds(expiry) + + +def test_get_or_refresh_credentials_not_expiring() -> None: + expiry = _expiry_ms(600) + props = _full_cred_props(expiry) + session = MagicMock() + assert _get_or_refresh_credentials(props, session) == _expected_s3_creds(expiry) + session.get.assert_not_called() + + +def test_get_or_refresh_credentials_success() -> None: + new_expiry = _expiry_ms(3600) + response_body = { + "storage-credentials": [ + { + "prefix": "s3://", + "config": { + S3_ACCESS_KEY_ID: "NEW_AKID", + S3_SECRET_ACCESS_KEY: "NEW_SECRET", + S3_SESSION_TOKEN: "NEW_TOKEN", + S3_SESSION_TOKEN_EXPIRES_AT_MS: str(new_expiry), + }, + } + ] + } + props = _full_cred_props(_expiry_ms(200)) + session = _make_session(response_body) + + result = _get_or_refresh_credentials(props, session) + + assert result[S3_ACCESS_KEY_ID] == "NEW_AKID" + assert result[S3_SECRET_ACCESS_KEY] == "NEW_SECRET" + assert result[S3_SESSION_TOKEN] == "NEW_TOKEN" + assert result[S3_SESSION_TOKEN_EXPIRES_AT_MS] == str(new_expiry) + + +def test_get_or_refresh_credentials_http_error() -> None: + props = _full_cred_props(_expiry_ms(200)) + session = MagicMock() + session.get.return_value.raise_for_status.side_effect = _make_http_error(500) + + with pytest.raises(ServerError): + _get_or_refresh_credentials(props, session) + + +def test_get_or_refresh_credentials_empty_credentials() -> None: + props = _full_cred_props(_expiry_ms(200)) + session = _make_session({"storage-credentials": []}) + + with pytest.raises(ValueError, match="Invalid S3 Credentials: empty"): + _get_or_refresh_credentials(props, session) + + +def test_get_or_refresh_credentials_multiple_credentials() -> 
None: + credential = { + "prefix": "s3://", + "config": { + S3_ACCESS_KEY_ID: "A", + S3_SECRET_ACCESS_KEY: "B", + S3_SESSION_TOKEN: "C", + S3_SESSION_TOKEN_EXPIRES_AT_MS: "123", + }, + } + props = _full_cred_props(_expiry_ms(200)) + session = _make_session({"storage-credentials": [credential, credential]}) + + with pytest.raises(ValueError, match="only one S3 credential should exist"): + _get_or_refresh_credentials(props, session)