diff --git a/changes/3799.misc.md b/changes/3799.misc.md new file mode 100644 index 0000000000..8e3461febe --- /dev/null +++ b/changes/3799.misc.md @@ -0,0 +1 @@ +Adds a protocol `JSONSerializable` that defines methods for classes that serialize to, and deserialize from, JSON-compatible Python data structures. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8277c3f752..4b6ba409f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'numpy>=2', 'numcodecs>=0.14', 'google-crc32c>=1.5', - 'typing_extensions>=4.12', + 'typing_extensions>=4.13', 'donfig>=0.8', ] @@ -233,7 +233,7 @@ extra-dependencies = [ 'fsspec==2023.10.0', 's3fs==2023.10.0', 'universal_pathlib==0.2.0', - 'typing_extensions==4.12.*', + 'typing_extensions==4.13.*', 'donfig==0.8.*', 'obstore==0.5.*', ] diff --git a/src/zarr/abc/serializable.py b/src/zarr/abc/serializable.py new file mode 100644 index 0000000000..3f4c4c9176 --- /dev/null +++ b/src/zarr/abc/serializable.py @@ -0,0 +1,16 @@ +from typing import Protocol, Self + + +class JSONSerializable[T_contra, T_co](Protocol): + @classmethod + def from_json(cls, obj: T_contra) -> Self: + """ + Deserialize from an instance of T_contra. + """ + ... + + def to_json(self) -> T_co: + """ + Serialize to JSON. + """ + ... diff --git a/src/zarr/core/metadata/common.py b/src/zarr/core/metadata/common.py index 44d3eb292b..bd62e08ac6 100644 --- a/src/zarr/core/metadata/common.py +++ b/src/zarr/core/metadata/common.py @@ -1,13 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Mapping + from zarr.core.common import JSON -def parse_attributes(data: dict[str, JSON] | None) -> dict[str, JSON]: +def parse_attributes(data: Mapping[str, Any] | None) -> dict[str, JSON]: if data is None: return {} - return data + return dict(data) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 2a5da50c7b..cdea80260b 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,11 +1,20 @@ from __future__ import annotations from collections.abc import Mapping -from typing import TYPE_CHECKING, NotRequired, TypedDict, TypeGuard, cast +from typing import TYPE_CHECKING, Final, NotRequired, TypeGuard, cast + +from typing_extensions import TypedDict from zarr.abc.metadata import Metadata +from zarr.abc.serializable import JSONSerializable from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.dtype import VariableLengthUTF8, ZDType, get_data_type_from_json +from zarr.core.dtype import ( + VariableLengthUTF8, + ZDType, + ZDTypeLike, + get_data_type_from_json, + parse_dtype, +) from zarr.core.dtype.common import check_dtype_spec_v3 if TYPE_CHECKING: @@ -35,6 +44,7 @@ ZARR_JSON, DimensionNamesLike, NamedConfig, + ShapeLike, parse_named_configuration, parse_shapelike, ) @@ -156,27 +166,35 @@ def check_allowed_extra_field(data: object) -> TypeGuard[AllowedExtraField]: def parse_extra_fields( - data: Mapping[str, AllowedExtraField] | None, + data: Mapping[str, object] | None, ) -> dict[str, AllowedExtraField]: if data is None: return {} - else: - conflict_keys = ARRAY_METADATA_KEYS & set(data.keys()) - if len(conflict_keys) > 0: - msg = ( - "Invalid extra fields. " - "The following keys: " - f"{sorted(conflict_keys)} " - "are invalid because they collide with keys reserved for use by the " - "array metadata document." - ) - raise ValueError(msg) - return dict(data) + conflict_keys = ARRAY_METADATA_KEYS & set(data.keys()) + if len(conflict_keys) > 0: + msg = ( + "Invalid extra fields. " + "The following keys: " + f"{sorted(conflict_keys)} " + "are invalid because they collide with keys reserved for use by the " + "array metadata document." + ) + raise ValueError(msg) -class ArrayMetadataJSON_V3(TypedDict): + disallowed = {k: v for k, v in data.items() if not check_allowed_extra_field(v)} + if disallowed: + raise MetadataValidationError( + f"Disallowed extra fields: {sorted(disallowed.keys())}. " + 'Extra fields must be a mapping with "must_understand" set to False.' + ) + + return dict(data) # type: ignore[arg-type] + + +class ArrayMetadataJSON_V3(TypedDict, extra_items=AllowedExtraField): # type: ignore[call-arg] """ - A typed dictionary model for zarr v3 metadata. + A typed dictionary model for Zarr v3 array metadata. """ zarr_format: Literal[3] @@ -194,9 +212,159 @@ class ArrayMetadataJSON_V3(TypedDict): ARRAY_METADATA_KEYS = set(ArrayMetadataJSON_V3.__annotations__.keys()) +ChunkGridLike = dict[str, JSON] | ChunkGrid | NamedConfig[str, Any] +CodecLike = Codec | dict[str, JSON] | NamedConfig[str, Any] | str + +# Required keys in ArrayMetadataJSONLike_V3 (excludes zarr_format and node_type, +# which are identity fields consumed by the I/O layer before reaching this point). +_REQUIRED_JSONLIKE_KEYS = frozenset( + {"shape", "data_type", "chunk_grid", "chunk_key_encoding", "codecs", "fill_value"} +) + +# All keys defined by the zarr v3 array metadata spec. +_ARRAY_METADATA_KNOWN_KEYS: Final[frozenset[str]] = frozenset( + { + "zarr_format", + "node_type", + "shape", + "data_type", + "chunk_grid", + "chunk_key_encoding", + "codecs", + "fill_value", + "attributes", + "dimension_names", + "storage_transformers", + } +) + + +def check_array_metadata_like(data: object) -> ArrayMetadataJSONLike_V3: + """ + Narrow an untrusted object to `ArrayMetadataJSONLike_V3`. + + Performs structural type checking — verifies that `data` is a mapping + with the expected keys and that each value has an acceptable Python type. + Raises `TypeError` if the input is structurally wrong. + Does **not** validate the semantic correctness of values (e.g. whether a + data type string is a recognized dtype). That validation is the + responsibility of `__init__`, which raises `ValueError`. + + This function allows `zarr_format` and `node_type` to be absent. The + expectation is that these values have already been validated as a + precondition for invoking this function. + """ + errors: list[str] = [] + + if not isinstance(data, Mapping): + raise TypeError(f"Expected a mapping, got {type(data).__name__}") + + # --- required keys --- + missing = _REQUIRED_JSONLIKE_KEYS - set(data.keys()) + if missing: + errors.append(f"Missing required keys: {sorted(missing)}") + + # --- shape: Iterable (but not str or Mapping) --- + shape = data.get("shape") + if shape is not None and (not isinstance(shape, Iterable) or isinstance(shape, str | Mapping)): + errors.append(f"Invalid shape: expected an iterable, got {type(shape).__name__}") + + # --- data_type: str, Mapping, or ZDType --- + data_type_json = data.get("data_type") + if data_type_json is not None and not isinstance(data_type_json, str | Mapping | ZDType): + errors.append( + f"Invalid data_type: expected a string, mapping, or ZDType, got {type(data_type_json).__name__}" + ) + + # --- chunk_grid: Mapping or ChunkGrid --- + chunk_grid = data.get("chunk_grid") + if chunk_grid is not None and not isinstance(chunk_grid, Mapping | ChunkGrid): + errors.append( + f"Invalid chunk_grid: expected a mapping or ChunkGrid, got {type(chunk_grid).__name__}" + ) + + # --- chunk_key_encoding: Mapping or ChunkKeyEncoding --- + chunk_key_encoding = data.get("chunk_key_encoding") + if chunk_key_encoding is not None and not isinstance( + chunk_key_encoding, Mapping | ChunkKeyEncoding + ): + errors.append( + f"Invalid chunk_key_encoding: expected a mapping or ChunkKeyEncoding, got {type(chunk_key_encoding).__name__}" + ) + + # --- codecs: Iterable (but not str or Mapping) --- + codecs = data.get("codecs") + if codecs is not None and ( + not isinstance(codecs, Iterable) or isinstance(codecs, str | Mapping) + ): + errors.append(f"Invalid codecs: expected an iterable, got {type(codecs).__name__}") + + # --- fill_value: any type is allowed, just must be present (checked via required keys) --- + + # --- attributes (optional): Mapping --- + attributes = data.get("attributes") + if attributes is not None and not isinstance(attributes, Mapping): + errors.append(f"Invalid attributes: expected a mapping, got {type(attributes).__name__}") + + # --- dimension_names (optional): Iterable (but not str) --- + dimension_names = data.get("dimension_names") + if dimension_names is not None and ( + not isinstance(dimension_names, Iterable) or isinstance(dimension_names, str) + ): + errors.append( + f"Invalid dimension_names: expected an iterable or None, got {type(dimension_names).__name__}" + ) + + # --- storage_transformers (optional): Iterable (but not str or Mapping) --- + storage_transformers = data.get("storage_transformers") + if storage_transformers is not None and ( + not isinstance(storage_transformers, Iterable) + or isinstance(storage_transformers, str | Mapping) + ): + errors.append( + f"Invalid storage_transformers: expected an iterable, got {type(storage_transformers).__name__}" + ) + + # --- extra fields: must be AllowedExtraField (mapping with must_understand=False) --- + _known_keys = _ARRAY_METADATA_KNOWN_KEYS + data_map = cast(Mapping[str, object], data) + extra_keys = set(data_map.keys()) - _known_keys + disallowed = [k for k in extra_keys if not check_allowed_extra_field(data_map[k])] + if disallowed: + errors.append( + f"Disallowed extra fields: {sorted(disallowed)}. " + 'Extra fields must be a mapping with "must_understand" set to False.' + ) + + if errors: + raise TypeError( + "Cannot interpret input as Zarr v3 array metadata:\n" + + "\n".join(f" - {e}" for e in errors) + ) + + return cast(ArrayMetadataJSONLike_V3, data) + + +class ArrayMetadataJSONLike_V3(TypedDict, extra_items=AllowedExtraField): # type: ignore[call-arg] + """ + A typed dictionary model of JSON-like input that can be used to create ArrayV3Metadata + """ + + zarr_format: NotRequired[Literal[3]] + node_type: NotRequired[Literal["array"]] + shape: ShapeLike + data_type: ZDTypeLike + chunk_grid: ChunkGridLike + chunk_key_encoding: ChunkKeyEncodingLike + codecs: Iterable[CodecLike] + fill_value: object + attributes: NotRequired[dict[str, JSON]] + dimension_names: NotRequired[DimensionNamesLike] + storage_transformers: NotRequired[Iterable[dict[str, JSON]]] + @dataclass(frozen=True, kw_only=True) -class ArrayV3Metadata(Metadata): +class ArrayV3Metadata(Metadata, JSONSerializable[ArrayMetadataJSONLike_V3, ArrayMetadataJSON_V3]): shape: tuple[int, ...] data_type: ZDType[TBaseDType, TBaseScalar] chunk_grid: ChunkGrid @@ -213,8 +381,8 @@ class ArrayV3Metadata(Metadata): def __init__( self, *, - shape: Iterable[int], - data_type: ZDType[TBaseDType, TBaseScalar], + shape: ShapeLike, + data_type: ZDTypeLike, chunk_grid: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any], chunk_key_encoding: ChunkKeyEncodingLike, fill_value: object, @@ -229,27 +397,27 @@ def __init__( """ shape_parsed = parse_shapelike(shape) + data_type_parsed = parse_dtype(data_type, zarr_format=3) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - # Note: relying on a type method is numpy-specific - fill_value_parsed = data_type.cast_scalar(fill_value) + fill_value_parsed = data_type_parsed.cast_scalar(fill_value) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) extra_fields_parsed = parse_extra_fields(extra_fields) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type, + dtype=data_type_parsed, fill_value=fill_value_parsed, config=ArrayConfig.from_dict({}), # TODO: config is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. ) codecs_parsed = tuple(c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial) - validate_codecs(codecs_parsed_partial, data_type) + validate_codecs(codecs_parsed_partial, data_type_parsed) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "data_type", data_type) + object.__setattr__(self, "data_type", data_type_parsed) object.__setattr__(self, "chunk_grid", chunk_grid_parsed) object.__setattr__(self, "chunk_key_encoding", chunk_key_encoding_parsed) object.__setattr__(self, "codecs", codecs_parsed) @@ -410,7 +578,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: storage_transformers=_data_typed.get("storage_transformers", ()), # type: ignore[arg-type] ) - def to_dict(self) -> dict[str, JSON]: + def to_json(self) -> ArrayMetadataJSON_V3: out_dict = super().to_dict() extra_fields = out_dict.pop("extra_fields") out_dict = out_dict | extra_fields # type: ignore[operator] @@ -426,14 +594,35 @@ def to_dict(self) -> dict[str, JSON]: if out_dict["dimension_names"] is None: out_dict.pop("dimension_names") - # TODO: replace the `to_dict` / `from_dict` on the `Metadata`` class with - # to_json, from_json, and have ZDType inherit from `Metadata` - # until then, we have this hack here, which relies on the fact that to_dict will pass through - # any non-`Metadata` fields as-is. + # TODO: have ZDType inherit from JSONSerializable so we can remove this hack dtype_meta = out_dict["data_type"] if isinstance(dtype_meta, ZDType): out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable] - return out_dict + return cast(ArrayMetadataJSON_V3, out_dict) + + def to_dict(self) -> dict[str, JSON]: + return dict(self.to_json()) # type: ignore[arg-type] + + @classmethod + def from_json(cls, obj: ArrayMetadataJSONLike_V3) -> Self: + """ + Construct from a trusted, typed input. No validation of the input structure + is performed beyond what `__init__` already does. + """ + _known_keys = _ARRAY_METADATA_KNOWN_KEYS + extra_fields = {k: v for k, v in obj.items() if k not in _known_keys} + return cls( + shape=obj["shape"], + data_type=obj["data_type"], + chunk_grid=obj["chunk_grid"], + chunk_key_encoding=obj["chunk_key_encoding"], + codecs=obj["codecs"], + fill_value=obj["fill_value"], + attributes=obj.get("attributes"), + dimension_names=obj.get("dimension_names"), + storage_transformers=obj.get("storage_transformers"), + extra_fields=extra_fields or None, # type: ignore[arg-type] + ) def update_shape(self, shape: tuple[int, ...]) -> Self: return replace(self, shape=shape) diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 01ed921053..81dbede166 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -2,7 +2,8 @@ import json import re -from typing import TYPE_CHECKING, Literal +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, TypeVar import numpy as np import pytest @@ -10,6 +11,7 @@ from zarr import consolidate_metadata, create_group from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import default_buffer_prototype +from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.config import config from zarr.core.dtype import UInt8, get_data_type_from_native_dtype @@ -19,6 +21,8 @@ from zarr.core.metadata.v3 import ( ArrayMetadataJSON_V3, ArrayV3Metadata, + ChunkGridLike, + check_array_metadata_like, parse_codecs, parse_dimension_names, parse_zarr_format, @@ -31,12 +35,14 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from typing import Any from zarr.core.types import JSON from zarr.abc.codec import Codec + from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike + from zarr.core.common import DimensionNamesLike from zarr.core.metadata.v3 import ( @@ -461,3 +467,156 @@ def test_group_to_dict(use_consolidated: bool, attributes: None | dict[str, Any] expect = {"node_type": "group", "zarr_format": 3, "attributes": expect_attributes} assert meta.to_dict() == expect + + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +class _Unset: + """Sentinel for 'key not provided in the input dict'.""" + + +UNSET = _Unset() + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """An (input, expected output) pair for parametrized tests.""" + + input: TIn + expected: TOut + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """An (input, expected exception, msg) triplet for parametrized tests.""" + + input: TIn + err_cls: type[Exception] + msg: str + + +@pytest.mark.parametrize("shape", [Expect((10,), (10,)), Expect([5], (5,))]) +@pytest.mark.parametrize("data_type", [Expect(UInt8(), UInt8())]) +@pytest.mark.parametrize( + "chunk_grid", + [ + Expect( + {"name": "regular", "configuration": {"chunk_shape": (10,)}}, + RegularChunkGrid(chunk_shape=(10,)), + ), + Expect(RegularChunkGrid(chunk_shape=(10,)), RegularChunkGrid(chunk_shape=(10,))), + ], +) +@pytest.mark.parametrize( + "chunk_key_encoding", + [ + Expect( + {"name": "default", "configuration": {"separator": "/"}}, + DefaultChunkKeyEncoding(separator="/"), + ), + Expect(DefaultChunkKeyEncoding(separator="."), DefaultChunkKeyEncoding(separator=".")), + ], +) +@pytest.mark.parametrize("codecs", [Expect((BytesCodec(),), (BytesCodec(endian=None),))]) +@pytest.mark.parametrize( + "fill_value", + [Expect(0, np.uint8(0)), Expect(42, np.uint8(42))], +) +@pytest.mark.parametrize( + "attributes", + [Expect({"key": "val"}, {"key": "val"}), Expect(None, {}), Expect(UNSET, {})], +) +@pytest.mark.parametrize( + "dimension_names", + [Expect(("x",), ("x",)), Expect(None, None), Expect(UNSET, None)], +) +@pytest.mark.parametrize( + "extra_fields", + [ + Expect(UNSET, {}), + Expect({"ext": {"must_understand": False}}, {"ext": {"must_understand": False}}), + ], +) +def test_from_json( + shape: Expect[Iterable[int], tuple[int, ...]], + data_type: Expect[UInt8, UInt8], + chunk_grid: Expect[ChunkGridLike, RegularChunkGrid], + chunk_key_encoding: Expect[ChunkKeyEncodingLike, DefaultChunkKeyEncoding], + codecs: Expect[tuple[Codec, ...], tuple[Codec, ...]], + fill_value: Expect[object, np.uint8], + attributes: Expect[dict[str, JSON] | _Unset, dict[str, JSON]], + dimension_names: Expect[DimensionNamesLike | _Unset, tuple[str | None, ...] | None], + extra_fields: Expect[dict[str, object] | _Unset, dict[str, object]], +) -> None: + """ + Test that ArrayV3Metadata.from_json correctly parses each field. + """ + data: dict[str, object] = { + "shape": shape.input, + "data_type": data_type.input, + "chunk_grid": chunk_grid.input, + "chunk_key_encoding": chunk_key_encoding.input, + "codecs": codecs.input, + "fill_value": fill_value.input, + } + if not isinstance(attributes.input, _Unset): + data["attributes"] = attributes.input + if not isinstance(dimension_names.input, _Unset): + data["dimension_names"] = dimension_names.input + if not isinstance(extra_fields.input, _Unset): + data.update(extra_fields.input) + + result = ArrayV3Metadata.from_json(data) # type: ignore[arg-type] + assert result.shape == shape.expected + assert result.data_type == data_type.expected + assert result.chunk_grid == chunk_grid.expected + assert result.chunk_key_encoding == chunk_key_encoding.expected + assert result.codecs == codecs.expected + assert result.fill_value == fill_value.expected + assert result.attributes == attributes.expected + assert result.dimension_names == dimension_names.expected + assert result.extra_fields == extra_fields.expected + + # narrow + from_json should produce the same result for the same input + assert ArrayV3Metadata.from_json(check_array_metadata_like(data)) == result + + +_VALID_TRY_FROM_JSON_INPUT: dict[str, object] = { + "shape": (10,), + "data_type": "uint8", + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (10,)}}, + "chunk_key_encoding": {"name": "default", "configuration": {"separator": "/"}}, + "codecs": ({"name": "bytes"},), + "fill_value": 0, +} + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr("not a dict", TypeError, "Expected a mapping"), + ExpectErr({}, TypeError, "Missing required keys"), + ExpectErr({**_VALID_TRY_FROM_JSON_INPUT, "shape": "not a shape"}, TypeError, "shape"), + ExpectErr({**_VALID_TRY_FROM_JSON_INPUT, "data_type": 12345}, TypeError, "data_type"), + ExpectErr( + {**_VALID_TRY_FROM_JSON_INPUT, "chunk_grid": "not a mapping"}, TypeError, "chunk_grid" + ), + ExpectErr( + {**_VALID_TRY_FROM_JSON_INPUT, "chunk_key_encoding": "not a mapping"}, + TypeError, + "chunk_key_encoding", + ), + ExpectErr({**_VALID_TRY_FROM_JSON_INPUT, "codecs": "not iterable"}, TypeError, "codecs"), + ExpectErr( + {**_VALID_TRY_FROM_JSON_INPUT, "unknown_field": "value"}, + TypeError, + "Disallowed extra fields", + ), + ], +) +def test_check_array_metadata_like_invalid(case: ExpectErr[object]) -> None: + """Structurally invalid input is rejected with TypeError.""" + with pytest.raises(case.err_cls, match=case.msg): + check_array_metadata_like(case.input)