12 changes: 10 additions & 2 deletions pyiceberg/catalog/rest/__init__.py
@@ -382,6 +382,10 @@ class ListViewsResponse(IcebergBaseModel):
identifiers: list[ListViewResponseEntry] = Field()


class LoadCredentialsResponse(IcebergBaseModel):
credentials: list[StorageCredential] = Field(alias="storage-credentials")


_PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse)


@@ -469,11 +473,13 @@ def _resolve_storage_credentials(storage_credentials: list[StorageCredential], l

return best_match.config if best_match else {}

def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO:
def _load_file_io(
self, properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None
) -> FileIO:
merged_properties = {**self.properties, **properties}
if self._auth_manager:
merged_properties[AUTH_MANAGER] = self._auth_manager
return load_file_io(merged_properties, location)
return load_file_io(merged_properties, location, session)

def supports_server_side_planning(self) -> bool:
"""Check if the catalog supports server-side scan planning."""
@@ -820,6 +826,7 @@ def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response:
io=self._load_file_io(
{**table_response.metadata.properties, **table_response.config, **credential_config},
table_response.metadata_location,
self._session,
),
catalog=self,
config=table_response.config,
@@ -837,6 +844,7 @@ def _response_to_staged_table(self, identifier_tuple: tuple[str, ...], table_res
io=self._load_file_io(
{**table_response.metadata.properties, **table_response.config, **credential_config},
table_response.metadata_location,
self._session,
),
catalog=self,
)
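
A quick sketch of the payload the new LoadCredentialsResponse model parses, assuming the storage-credentials shape from the Iceberg REST spec; the prefix and config values below are illustrative only:

payload = (
    '{"storage-credentials": [{"prefix": "s3://bucket/warehouse", '
    '"config": {"s3.access-key-id": "AKIA...", '
    '"s3.secret-access-key": "...", "s3.session-token": "..."}}]}'
)
response = LoadCredentialsResponse.model_validate_json(payload)
assert response.credentials[0].config["s3.session-token"] == "..."
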
120 changes: 110 additions & 10 deletions pyiceberg/io/__init__.py
@@ -29,6 +29,7 @@
import logging
import warnings
from abc import ABC, abstractmethod
from datetime import datetime
from io import SEEK_SET
from types import TracebackType
from typing import (
@@ -37,7 +38,11 @@
)
from urllib.parse import urlparse

from requests import HTTPError, Session

from pyiceberg.exceptions import ValidationException
from pyiceberg.typedef import EMPTY_DICT, Properties
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int

logger = logging.getLogger(__name__)

@@ -67,6 +72,7 @@
S3_ROLE_SESSION_NAME = "s3.role-session-name"
S3_FORCE_VIRTUAL_ADDRESSING = "s3.force-virtual-addressing"
S3_RETRY_STRATEGY_IMPL = "s3.retry-strategy-impl"
S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
HDFS_HOST = "hdfs.host"
HDFS_PORT = "hdfs.port"
HDFS_USER = "hdfs.user"
@@ -99,6 +105,9 @@
GCS_VERSION_AWARE = "gcs.version-aware"
HF_ENDPOINT = "hf.endpoint"
HF_TOKEN = "hf.token"
CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"
CATALOG_URI = "uri"


@runtime_checkable
@@ -258,9 +267,11 @@ class FileIO(ABC):
"""A base class for FileIO implementations."""

properties: Properties
session: Session | None

def __init__(self, properties: Properties = EMPTY_DICT):
def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None):
self.properties = properties
self.session = session

@abstractmethod
def new_input(self, location: str) -> InputFile:
@@ -317,15 +328,15 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
}


def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None:
def _import_file_io(io_impl: str, properties: Properties, session: Session | None = None) -> FileIO | None:
try:
path_parts = io_impl.split(".")
if len(path_parts) < 2:
raise ValueError(f"py-io-impl should be full path (module.CustomFileIO), got: {io_impl}")
module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1]
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_(properties)
return class_(properties, session)
except ModuleNotFoundError:
logger.warning(f"Could not initialize FileIO: {io_impl}", exc_info=logger.isEnabledFor(logging.DEBUG))
return None
@@ -334,45 +345,134 @@ def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None:
PY_IO_IMPL = "py-io-impl"


def _infer_file_io_from_scheme(path: str, properties: Properties) -> FileIO | None:
def _infer_file_io_from_scheme(path: str, properties: Properties, session: Session | None = None) -> FileIO | None:
parsed_url = urlparse(path)
if parsed_url.scheme:
if file_ios := SCHEMA_TO_FILE_IO.get(parsed_url.scheme):
for file_io_path in file_ios:
if file_io := _import_file_io(file_io_path, properties):
if file_io := _import_file_io(file_io_path, properties, session):
return file_io
else:
warnings.warn(f"No preferred file implementation for scheme: {parsed_url.scheme}", stacklevel=2)
return None


def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO:
def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None) -> FileIO:
# First look for the py-io-impl property to directly load the class
if io_impl := properties.get(PY_IO_IMPL):
if file_io := _import_file_io(io_impl, properties):
if file_io := _import_file_io(io_impl, properties, session):
logger.info("Loaded FileIO: %s", io_impl)
return file_io
else:
raise ValueError(f"Could not initialize FileIO: {io_impl}")

# Check the table location
if location:
if file_io := _infer_file_io_from_scheme(location, properties):
if file_io := _infer_file_io_from_scheme(location, properties, session):
return file_io

# Look at the schema of the warehouse
if warehouse_location := properties.get(WAREHOUSE):
if file_io := _infer_file_io_from_scheme(warehouse_location, properties):
if file_io := _infer_file_io_from_scheme(warehouse_location, properties, session):
return file_io

try:
# Default to PyArrow
logger.info("Defaulting to PyArrow FileIO")
from pyiceberg.io.pyarrow import PyArrowFileIO

return PyArrowFileIO(properties)
return PyArrowFileIO(properties, session)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Could not load a FileIO, please consider installing one: "
'pip3 install "pyiceberg[pyarrow]", for more options refer to the docs.'
) from e


def _extract_s3_credentials(properties: Properties) -> Properties:
"""Extract only S3 credential keys from properties, normalizing AWS_ prefixes to S3_."""
creds: Properties = {}
if access_key := get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID):
creds[S3_ACCESS_KEY_ID] = access_key
if secret_key := get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY):
creds[S3_SECRET_ACCESS_KEY] = secret_key
if session_token := get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN):
creds[S3_SESSION_TOKEN] = session_token
if expiry := get_first_property_value(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS):
creds[S3_SESSION_TOKEN_EXPIRES_AT_MS] = expiry
return creds


def _credential_from_properties(properties: Properties) -> Properties:
"""Retrieve current S3 credentials from properties returns empty if expired."""
access_key = get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
secret_access_key = get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
session_token = get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
expiration_ms = property_as_int(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS)

if not access_key or not secret_access_key or not session_token or not expiration_ms:
return EMPTY_DICT

# Refresh only when the session token is within five minutes of expiring.
expires_at = datetime.fromtimestamp(expiration_ms / 1000)
seconds_until_expiry = (expires_at - datetime.now()).total_seconds()

if seconds_until_expiry > 300:
return EMPTY_DICT

return {
S3_ACCESS_KEY_ID: access_key,
S3_SECRET_ACCESS_KEY: secret_access_key,
S3_SESSION_TOKEN: session_token,
S3_SESSION_TOKEN_EXPIRES_AT_MS: expiration_ms,
}


def _credential_refresh_endpoint(properties: Properties) -> str:
"""Build credential refresh endpoint from properties."""
catalog_uri = get_first_property_value(properties, CATALOG_URI)
credentials_path = get_first_property_value(properties, CREDENTIALS_ENDPOINT)

if catalog_uri is None:
raise ValidationException("Invalid catalog endpoint: None")

if credentials_path is None:
raise ValidationException("Invalid credentials endpoint: None")

return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/")


def _get_or_refresh_credentials(properties: Properties, session: Session | None) -> Properties:
"""Retrieve current S3 credentials from properties, refreshing them if they are close to expiration."""
refresh_enabled = property_as_bool(properties, REFRESH_CREDENTIALS_ENABLED, False)
if not refresh_enabled or session is None:
return _extract_s3_credentials(properties)

# Returns empty if credentials missing or not yet expiring
creds = _credential_from_properties(properties)

if not creds:
return _extract_s3_credentials(properties)

from pyiceberg.catalog.rest import LoadCredentialsResponse
from pyiceberg.catalog.rest.response import _handle_non_200_response

load_response: LoadCredentialsResponse | None = None

try:
http_response = session.get(_credential_refresh_endpoint(properties))
http_response.raise_for_status()
load_response = LoadCredentialsResponse.model_validate_json(http_response.text)
except HTTPError as exc:
_handle_non_200_response(exc, {})

if load_response is None:
raise ValidationException("Load credential response is None")

if not load_response.credentials:
raise ValueError("Invalid S3 Credentials: empty")

if len(load_response.credentials) > 1:
raise ValueError("Invalid S3 Credentials: expected exactly one credential")

credentials = load_response.credentials[0].config
return _extract_s3_credentials(credentials)
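
A hedged usage sketch of _get_or_refresh_credentials, with illustrative property values: the refresh only fires when client.refresh-credentials-enabled is set, a Session is available, and s3.session-token-expires-at-ms falls within five minutes of the current time; otherwise the existing s3.* keys are returned as-is.

from requests import Session

props = {
    "uri": "https://catalog.example.com/catalog",  # illustrative catalog URI
    "client.refresh-credentials-enabled": "true",
    "client.refresh-credentials-endpoint": "v1/namespaces/db/tables/t/credentials",  # illustrative path
    "s3.access-key-id": "AKIA...",
    "s3.secret-access-key": "...",
    "s3.session-token": "...",
    "s3.session-token-expires-at-ms": "1735689600000",  # illustrative epoch millis
}
# GETs <uri>/<endpoint> and returns the normalized s3.* credential keys.
creds = _get_or_refresh_credentials(props, Session())
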
21 changes: 14 additions & 7 deletions pyiceberg/io/fsspec.py
@@ -34,7 +34,7 @@
import requests
from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem
from requests import HTTPError
from requests import HTTPError, Session

from pyiceberg.catalog import TOKEN, URI
from pyiceberg.catalog.rest.auth import AUTH_MANAGER
@@ -88,6 +88,7 @@
InputStream,
OutputFile,
OutputStream,
_get_or_refresh_credentials,
)
from pyiceberg.typedef import Properties
from pyiceberg.types import strtobool
@@ -165,14 +166,16 @@ def _file(_: Properties) -> LocalFileSystem:
return LocalFileSystem(auto_mkdir=True)


def _s3(properties: Properties) -> AbstractFileSystem:
def _s3(properties: Properties, session: Session | None = None) -> AbstractFileSystem:
from s3fs import S3FileSystem

creds = _get_or_refresh_credentials(properties, session)

client_kwargs = {
"endpoint_url": properties.get(S3_ENDPOINT),
"aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"aws_secret_access_key": get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"aws_session_token": get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"aws_access_key_id": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"aws_secret_access_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"aws_session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"region_name": get_first_property_value(properties, S3_REGION, AWS_REGION),
}
config_kwargs = {}
@@ -318,6 +321,7 @@ def _hf(properties: Properties) -> AbstractFileSystem:
}

_ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"})
_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"})


class FsspecInputFile(InputFile):
@@ -419,10 +423,10 @@ def to_input_file(self) -> FsspecInputFile:
class FsspecFileIO(FileIO):
"""A FileIO implementation that uses fsspec."""

def __init__(self, properties: Properties):
def __init__(self, properties: Properties, session: Session | None = None):
self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS)
self._thread_locals = threading.local()
super().__init__(properties=properties)
super().__init__(properties=properties, session=session)

def new_input(self, location: str) -> FsspecInputFile:
"""Get an FsspecInputFile instance to read bytes from the file at the given location.
@@ -488,6 +492,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste
if scheme in _ADLS_SCHEMES:
return _adls(self.properties, hostname)

if scheme in _S3_SCHEMES:
return _s3(self.properties, self.session)

return self._scheme_to_fs[scheme](self.properties)

def __getstate__(self) -> dict[str, Any]:
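
A short sketch, reusing props from the example above: the Session handed to FsspecFileIO is stored on the FileIO and forwarded to _s3 whenever an s3/s3a/s3n filesystem is requested, so s3fs clients are built from refreshed credentials.

from requests import Session

io = FsspecFileIO(props, session=Session())
# new_input resolves the s3 scheme through _get_fs, which now routes to
# _s3(self.properties, self.session).
input_file = io.new_input("s3://bucket/warehouse/db/t/metadata/00000.metadata.json")
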
22 changes: 14 additions & 8 deletions pyiceberg/io/pyarrow.py
@@ -63,6 +63,7 @@
FileSystem,
FileType,
)
from requests import Session

from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
@@ -119,6 +120,7 @@
InputStream,
OutputFile,
OutputStream,
_get_or_refresh_credentials,
)
from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics
from pyiceberg.manifest import (
@@ -386,9 +388,9 @@ def to_input_file(self) -> PyArrowFile:
class PyArrowFileIO(FileIO):
fs_by_scheme: Callable[[str, str | None], FileSystem]

def __init__(self, properties: Properties = EMPTY_DICT):
def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None):
self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs)
super().__init__(properties=properties)
super().__init__(properties=properties, session=session)

@staticmethod
def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]:
@@ -433,11 +435,13 @@ def _initialize_fs(self, scheme: str, netloc: str | None = None) -> FileSystem:
def _initialize_oss_fs(self) -> FileSystem:
from pyarrow.fs import S3FileSystem

creds = _get_or_refresh_credentials(self.properties, self.session)

client_kwargs: dict[str, Any] = {
"endpoint_override": self.properties.get(S3_ENDPOINT),
"access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"region": get_first_property_value(self.properties, S3_REGION, AWS_REGION),
"force_virtual_addressing": property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, True),
}
@@ -480,11 +484,13 @@ def _initialize_s3_fs(self, netloc: str | None) -> FileSystem:
else:
bucket_region = provided_region

creds = _get_or_refresh_credentials(self.properties, self.session)

client_kwargs: dict[str, Any] = {
"endpoint_override": self.properties.get(S3_ENDPOINT),
"access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"region": bucket_region,
}

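
One hedged caveat on the PyArrow side: fs_by_scheme memoizes filesystems per scheme and netloc, so a credential refresh only takes effect when a filesystem is first constructed. A sketch of forcing re-initialization, assuming props from the earlier example:

from requests import Session

io = PyArrowFileIO(props, session=Session())
# fs_by_scheme wraps _initialize_fs in lru_cache; clearing it makes the next
# access rebuild the S3FileSystem with freshly refreshed credentials.
io.fs_by_scheme.cache_clear()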