diff --git a/changes/3004.feature.md b/changes/3004.feature.md new file mode 100644 index 0000000000..f7289ced67 --- /dev/null +++ b/changes/3004.feature.md @@ -0,0 +1,4 @@ +Optimizes reading multiple chunks from a shard. +Serial calls to `.get()` in the sharding codec have been replaced with +a single call to `.get_partial_values()` which stores may optimize by making +concurrent requests and/or coalescing nearby requests to the same shard. diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index d2ab353d43..427733f890 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -695,6 +695,12 @@ async def get( self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: ... + async def get_partial_values( + self, + prototype: BufferPrototype, + byte_ranges: Iterable[ByteRequest | None], + ) -> list[Buffer | None]: ... + @runtime_checkable class ByteSetter(Protocol): @@ -702,6 +708,12 @@ async def get( self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: ... + async def get_partial_values( + self, + prototype: BufferPrototype, + byte_ranges: Iterable[ByteRequest | None], + ) -> list[Buffer | None]: ... + async def set(self, value: Buffer) -> None: ... async def delete(self) -> None: ... diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..59705de6a1 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -100,6 +100,14 @@ async def get( start, stop = _normalize_byte_range_index(value, byte_range) return value[start:stop] + async def get_partial_values( + self, + prototype: BufferPrototype, + byte_ranges: Iterable[ByteRequest | None], + ) -> list[Buffer | None]: + # Concurrency not needed, each .get() is a slice of an in-memory buffer. + return [await self.get(prototype, byte_range) for byte_range in byte_ranges] + @dataclass(frozen=True) class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): @@ -475,7 +483,7 @@ async def _decode_partial_single( all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} # reading bytes of all requested chunks - shard_dict: ShardMapping = {} + shard_dict_maybe: ShardMapping | None if self._is_total_shard(all_chunk_coords, chunks_per_shard): # read entire shard shard_dict_maybe = await self._load_full_shard_maybe( @@ -483,24 +491,18 @@ async def _decode_partial_single( prototype=chunk_spec.prototype, chunks_per_shard=chunks_per_shard, ) - if shard_dict_maybe is None: - return None - shard_dict = shard_dict_maybe else: # read some chunks within the shard - shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) - if shard_index is None: - return None - shard_dict = {} - for chunk_coords in all_chunk_coords: - chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) - if chunk_byte_slice: - chunk_bytes = await byte_getter.get( - prototype=chunk_spec.prototype, - byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), - ) - if chunk_bytes: - shard_dict[chunk_coords] = chunk_bytes + shard_dict_maybe = await self._load_partial_shard_maybe( + byte_getter, + chunk_spec.prototype, + chunks_per_shard, + all_chunk_coords, + ) + + if shard_dict_maybe is None: + return None + shard_dict = shard_dict_maybe # decoding chunks and writing them into the output buffer await self.codec_pipeline.read( @@ -770,6 +772,43 @@ async def _load_full_shard_maybe( else None ) + async def _load_partial_shard_maybe( + self, + byte_getter: ByteGetter, + prototype: BufferPrototype, + chunks_per_shard: tuple[int, ...], + all_chunk_coords: set[tuple[int, ...]], + ) -> ShardMapping | None: + """ + Read chunks from `byte_getter` for the case where the read is less than a full shard. + Returns a mapping of chunk coordinates to bytes or None. + """ + shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) + if shard_index is None: + return None # shard index read failure, the ByteGetter returned None + + # Build parallel lists of chunk coordinates and byte ranges for non-empty chunks + chunk_coords_list: list[tuple[int, ...]] = [] + byte_ranges: list[RangeByteRequest] = [] + for chunk_coords in all_chunk_coords: + chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) + if chunk_byte_slice is not None: + chunk_coords_list.append(chunk_coords) + byte_ranges.append(RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1])) + + if not byte_ranges: + return {} + + # Fetch all chunk byte ranges via get_partial_values + buffers = await byte_getter.get_partial_values(prototype, byte_ranges) + + shard_dict: ShardMutableMapping = {} + for chunk_coords, buf in zip(chunk_coords_list, buffers, strict=True): + if buf is not None: + shard_dict[chunk_coords] = buf + + return shard_dict + def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 08c05864aa..74028164da 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -33,6 +33,8 @@ FSMap = None if TYPE_CHECKING: + from collections.abc import Iterable + from zarr.core.buffer import BufferPrototype @@ -173,6 +175,30 @@ async def get( prototype = default_buffer_prototype() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) + async def get_partial_values( + self, + prototype: BufferPrototype, + byte_ranges: Iterable[ByteRequest | None], + ) -> list[Buffer | None]: + """ + Read multiple byte ranges from the store. + + Parameters + ---------- + prototype : BufferPrototype + The buffer prototype to use when reading the bytes. + byte_ranges : Iterable[ByteRequest | None] + The byte ranges to read. + + Returns + ------- + list of Buffer or None + The read bytes for each range, or None for missing keys. + """ + return await self.store.get_partial_values( + prototype, [(self.path, byte_range) for byte_range in byte_ranges] + ) + async def set(self, value: Buffer) -> None: """ Write bytes to the store. diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index d7cbeb5bdb..a22244c95a 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -1,6 +1,7 @@ import pickle import re from typing import Any +from unittest.mock import AsyncMock import numpy as np import numpy.typing as npt @@ -198,6 +199,267 @@ def test_sharding_partial_read( assert np.all(read_data == 1) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_multiple_chunks_partial_shard_read( + store: Store, + index_location: ShardingCodecIndexLocation, +) -> None: + array_shape = (16, 64) + shard_shape = (8, 32) + chunk_shape = (2, 4) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Reads 3 (2 full, 1 partial) chunks each from 2 shards (a subset of both shards) + # for a total of 6 chunks accessed + assert np.allclose(a[0, 22:42], np.arange(22, 42, dtype="float32")) + + # 2 shard index reads via store.get() + 2 get_partial_values calls (one per shard) + assert store_mock.get.call_count == 2 + assert store_mock.get_partial_values.call_count == 2 + + store_mock.reset_mock() + + # Reads 4 chunks from both shards along dimension 0 for a total of 8 chunks accessed + assert np.allclose(a[:, 0], np.arange(0, data.size, array_shape[1], dtype="float32")) + + # 2 shard index reads via store.get() + 2 get_partial_values calls (one per shard) + assert store_mock.get.call_count == 2 + assert store_mock.get_partial_values.call_count == 2 + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_duplicate_read_indexes( + store: Store, + index_location: ShardingCodecIndexLocation, +) -> None: + """ + Check that duplicate index reads are handled correctly when + using get_partial_values for chunk data. + """ + array_shape = (15,) + shard_shape = (8,) + chunk_shape = (2,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=-1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Read the same index multiple times from two chunks + indexer = [8, 8, 12, 12] + np.array_equal(a[indexer], data[indexer]) + + # 1 shard index read via store.get() + 1 get_partial_values call + assert store_mock.get.call_count == 1 + assert store_mock.get_partial_values.call_count == 1 + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - some, but not all, chunks in the last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have fewer entries than chunks in the shard + """ + # array with mixed empty and non-empty chunks in second shard + data = np.array([ + # shard 0. full 8 elements, all chunks have some non-fill data + 0, 1, 2, 3, 4, 5, 6, 7, + # shard 1. 6 elements (< shard shape) + 2, 0, # chunk 0, written + -9, -9, # chunk 1, all fill, not written + 4, 5 # chunk 2, written + ], dtype="int32") # fmt: off + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (8,), "index_location": index_location}, + dtype="int32", + fill_value=-9, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - all chunks in last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have no entries + """ + fill_value = -99 + shard_size = 8 + data = np.arange(14, dtype="int32") + data[shard_size:] = fill_value # 2nd shard is all fill value + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (shard_size,), "index_location": index_location}, + dtype="int32", + fill_value=fill_value, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load the bytes of the shard's chunk index fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + # loading the index is the first call to .get() so returning None will simulate an index load failure + store_mock.get.return_value = None + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_chunk_slice_fails( + store: Store, + index_location: ShardingCodecIndexLocation, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test fill value is returned when looking up a chunk's byte slice within a shard fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + monkeypatch.setattr( + "zarr.codecs.sharding._ShardIndex.get_chunk_slice", + lambda self, chunk_coords: None, + ) + + a = zarr.create_array( + StorePath(store), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__chunk_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load a chunk's bytes fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Set up store mock after array creation to simulate chunk load failure. + # Index loads still succeed (via store.get), but chunk data loads fail + # (via store.get_partial_values returning None for each range). + store_mock.reset_mock() + + async def fail_chunk_reads(prototype: Any, key_ranges: Any, **kwargs: Any) -> list[None]: + return [None] * len(list(key_ranges)) + + store_mock.get_partial_values.side_effect = fail_chunk_reads + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + @pytest.mark.parametrize( "array_fixture", [ diff --git a/tests/test_codecs/test_sharding_unit.py b/tests/test_codecs/test_sharding_unit.py new file mode 100644 index 0000000000..f282a9310d --- /dev/null +++ b/tests/test_codecs/test_sharding_unit.py @@ -0,0 +1,625 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from typing import cast + +import numpy as np +import pytest + +from zarr.abc.store import ByteGetter, ByteRequest +from zarr.codecs.sharding import ( + MAX_UINT_64, + ShardingCodec, + _ShardIndex, + _ShardReader, +) +from zarr.core.buffer import BufferPrototype, default_buffer_prototype +from zarr.core.buffer.cpu import Buffer + +# ============================================================================ +# _ShardIndex tests +# ============================================================================ + + +def test_shard_index_create_empty() -> None: + """Test that create_empty creates an index filled with MAX_UINT_64.""" + chunks_per_shard = (2, 3) + index = _ShardIndex.create_empty(chunks_per_shard) + + assert index.chunks_per_shard == chunks_per_shard + assert index.offsets_and_lengths.shape == (2, 3, 2) + assert index.offsets_and_lengths.dtype == np.dtype(" None: + """Test create_empty with 1D chunks_per_shard.""" + chunks_per_shard = (4,) + index = _ShardIndex.create_empty(chunks_per_shard) + + assert index.chunks_per_shard == chunks_per_shard + assert index.offsets_and_lengths.shape == (4, 2) + + +def test_shard_index_is_all_empty_true() -> None: + """Test is_all_empty returns True for a freshly created empty index.""" + index = _ShardIndex.create_empty((2, 2)) + assert index.is_all_empty() is True + + +def test_shard_index_is_all_empty_false() -> None: + """Test is_all_empty returns False when at least one chunk is set.""" + index = _ShardIndex.create_empty((2, 2)) + index.set_chunk_slice((0, 0), slice(0, 100)) + assert index.is_all_empty() is False + + +def test_shard_index_get_chunk_slice_empty() -> None: + """Test get_chunk_slice returns None for empty chunks.""" + index = _ShardIndex.create_empty((2, 2)) + assert index.get_chunk_slice((0, 0)) is None + assert index.get_chunk_slice((1, 1)) is None + + +def test_shard_index_get_chunk_slice_set() -> None: + """Test get_chunk_slice returns correct (start, end) tuple after setting.""" + index = _ShardIndex.create_empty((2, 2)) + index.set_chunk_slice((0, 1), slice(100, 200)) + + result = index.get_chunk_slice((0, 1)) + assert result == (100, 200) + + +def test_shard_index_set_chunk_slice() -> None: + """Test set_chunk_slice correctly sets offset and length.""" + index = _ShardIndex.create_empty((3, 3)) + + # Set a chunk slice + index.set_chunk_slice((1, 2), slice(50, 150)) + + # Verify the underlying array + assert index.offsets_and_lengths[1, 2, 0] == 50 # offset + assert index.offsets_and_lengths[1, 2, 1] == 100 # length (150 - 50) + + +def test_shard_index_set_chunk_slice_none() -> None: + """Test set_chunk_slice with None marks chunk as empty.""" + index = _ShardIndex.create_empty((2, 2)) + + # First set a value + index.set_chunk_slice((0, 0), slice(0, 100)) + assert index.get_chunk_slice((0, 0)) == (0, 100) + + # Then clear it + index.set_chunk_slice((0, 0), None) + assert index.get_chunk_slice((0, 0)) is None + assert index.offsets_and_lengths[0, 0, 0] == MAX_UINT_64 + assert index.offsets_and_lengths[0, 0, 1] == MAX_UINT_64 + + +def test_shard_index_get_full_chunk_map() -> None: + """Test get_full_chunk_map returns correct boolean array.""" + index = _ShardIndex.create_empty((2, 3)) + + # Set some chunks + index.set_chunk_slice((0, 0), slice(0, 10)) + index.set_chunk_slice((1, 2), slice(10, 20)) + + chunk_map = index.get_full_chunk_map() + + assert chunk_map.shape == (2, 3) + assert chunk_map.dtype == np.bool_ + assert chunk_map[0, 0] is np.True_ + assert chunk_map[0, 1] is np.False_ + assert chunk_map[0, 2] is np.False_ + assert chunk_map[1, 0] is np.False_ + assert chunk_map[1, 1] is np.False_ + assert chunk_map[1, 2] is np.True_ + + +def test_shard_index_localize_chunk() -> None: + """Test _localize_chunk maps global coords to local shard coords via modulo.""" + index = _ShardIndex.create_empty((2, 3)) + + # Within bounds - should return same coords + assert index._localize_chunk((0, 0)) == (0, 0) + assert index._localize_chunk((1, 2)) == (1, 2) + + # Out of bounds - should wrap via modulo + assert index._localize_chunk((2, 0)) == (0, 0) # 2 % 2 = 0 + assert index._localize_chunk((3, 5)) == (1, 2) # 3 % 2 = 1, 5 % 3 = 2 + assert index._localize_chunk((4, 6)) == (0, 0) # 4 % 2 = 0, 6 % 3 = 0 + + +def test_shard_index_is_dense_true() -> None: + """Test is_dense returns True when chunks are contiguously packed.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set chunks contiguously: [0-100), [100-200) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(100, 200)) + + assert index.is_dense(chunk_byte_length) is True + + +def test_shard_index_is_dense_false_duplicate_offsets() -> None: + """Test is_dense returns False when chunks have duplicate offsets.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set both chunks to same offset (duplicate) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(0, 100)) + + assert index.is_dense(chunk_byte_length) is False + + +def test_shard_index_is_dense_false_wrong_alignment() -> None: + """Test is_dense returns False when chunks are not aligned to chunk_byte_length.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set chunks not aligned: [0-100), [150-250) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(150, 250)) + + assert index.is_dense(chunk_byte_length) is False + + +def test_shard_index_is_dense_with_empty_chunks() -> None: + """Test is_dense handles empty chunks correctly.""" + index = _ShardIndex.create_empty((3,)) + chunk_byte_length = 100 + + # Only set first and third chunk, skip middle + index.set_chunk_slice((0,), slice(0, 100)) + # (1,) is empty + index.set_chunk_slice((2,), slice(100, 200)) + + # Should still be dense since only non-empty chunks are considered + assert index.is_dense(chunk_byte_length) is True + + +# ============================================================================ +# _ShardingByteGetter.get_partial_values tests +# ============================================================================ + + +async def test_sharding_byte_getter_get_partial_values_returns_slices() -> None: + """Test that get_partial_values returns correct slices from the shard dict.""" + from zarr.codecs.sharding import _ShardingByteGetter + + chunk_data = Buffer.from_bytes(b"AAAABBBB") + shard_dict: dict[tuple[int, ...], Buffer | None] = {(0,): chunk_data} + getter = _ShardingByteGetter(shard_dict, (0,)) + + from zarr.abc.store import RangeByteRequest + + results = await getter.get_partial_values( + default_buffer_prototype(), + [RangeByteRequest(0, 4), RangeByteRequest(4, 8)], + ) + + assert len(results) == 2 + assert results[0] is not None + assert results[0].as_numpy_array().tobytes() == b"AAAA" + assert results[1] is not None + assert results[1].as_numpy_array().tobytes() == b"BBBB" + + +async def test_sharding_byte_getter_get_partial_values_missing_chunk() -> None: + """Test that get_partial_values returns None for a missing chunk.""" + from zarr.codecs.sharding import _ShardingByteGetter + + shard_dict: dict[tuple[int, ...], Buffer | None] = {} + getter = _ShardingByteGetter(shard_dict, (0,)) + + from zarr.abc.store import RangeByteRequest + + results = await getter.get_partial_values( + default_buffer_prototype(), + [RangeByteRequest(0, 10)], + ) + + assert results == [None] + + +# ============================================================================ +# StorePath.get_partial_values tests +# ============================================================================ + + +async def test_store_path_get_partial_values() -> None: + """Test that StorePath.get_partial_values delegates to Store.get_partial_values.""" + from zarr.abc.store import RangeByteRequest + from zarr.storage._common import StorePath + from zarr.storage._memory import MemoryStore + + store = MemoryStore() + await store.set("key", Buffer.from_bytes(b"0123456789")) + path = StorePath(store, "key") + + results = await path.get_partial_values( + default_buffer_prototype(), + [RangeByteRequest(0, 3), RangeByteRequest(7, 10)], + ) + + assert len(results) == 2 + assert results[0] is not None + assert results[0].as_numpy_array().tobytes() == b"012" + assert results[1] is not None + assert results[1].as_numpy_array().tobytes() == b"789" + + +async def test_store_path_get_partial_values_missing_key() -> None: + """Test that StorePath.get_partial_values returns None for a missing key.""" + from zarr.abc.store import RangeByteRequest + from zarr.storage._common import StorePath + from zarr.storage._memory import MemoryStore + + store = MemoryStore() + path = StorePath(store, "nonexistent") + + results = await path.get_partial_values( + default_buffer_prototype(), + [RangeByteRequest(0, 10)], + ) + + assert results == [None] + + +# ============================================================================ +# Mock ByteGetter for _load_partial_shard_maybe tests +# ============================================================================ + + +@dataclass +class MockByteGetter: + """Mock ByteGetter for testing.""" + + data: bytes + return_none: bool = False + get_call_count: int = 0 + get_partial_values_call_count: int = 0 + + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + self.get_call_count += 1 + if self.return_none: + return None + if byte_range is None: + return Buffer.from_bytes(self.data) + # For RangeByteRequest, extract start and end + start = getattr(byte_range, "start", 0) + end = getattr(byte_range, "end", len(self.data)) + return Buffer.from_bytes(self.data[start:end]) + + async def get_partial_values( + self, prototype: BufferPrototype, byte_ranges: Iterable[ByteRequest | None] + ) -> list[Buffer | None]: + self.get_partial_values_call_count += 1 + return [await self.get(prototype, br) for br in byte_ranges] + + +@dataclass +class MockByteGetterWithIndex: + """Mock ByteGetter that returns index on first get() and chunk data on get_partial_values().""" + + index_data: bytes | None + chunk_data: bytes | None + get_call_count: int = 0 + get_partial_values_call_count: int = 0 + return_none_for_chunks: bool = False + + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + self.get_call_count += 1 + if self.index_data is None: + return None + return Buffer.from_bytes(self.index_data) + + async def get_partial_values( + self, prototype: BufferPrototype, byte_ranges: Iterable[ByteRequest | None] + ) -> list[Buffer | None]: + self.get_partial_values_call_count += 1 + if self.return_none_for_chunks or self.chunk_data is None: + return [None for _ in byte_ranges] + results: list[Buffer | None] = [] + for br in byte_ranges: + if br is None: + results.append(Buffer.from_bytes(self.chunk_data)) + else: + start = getattr(br, "start", 0) + end = getattr(br, "end", len(self.chunk_data)) + results.append(Buffer.from_bytes(self.chunk_data[start:end])) + return results + + +# ============================================================================ +# _load_partial_shard_maybe tests +# ============================================================================ + + +async def test_load_partial_shard_maybe_index_load_fails() -> None: + """Test _load_partial_shard_maybe returns None when index load fails.""" + codec = ShardingCodec(chunk_shape=(8,)) + byte_getter = cast(ByteGetter, MockByteGetterWithIndex(index_data=None, chunk_data=None)) + + chunks_per_shard = (2,) + all_chunk_coords: set[tuple[int, ...]] = {(0,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is None + + +async def test_load_partial_shard_maybe_with_empty_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test _load_partial_shard_maybe skips chunks where get_chunk_slice returns None.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + # Create an index where chunk (1,) is empty (returns None from get_chunk_slice) + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice((0,), slice(0, 100)) + # (1,) is intentionally left empty + index.set_chunk_slice((2,), slice(100, 200)) + index.set_chunk_slice((3,), slice(200, 300)) + + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + chunk_data = b"x" * 300 + byte_getter = cast(ByteGetter, MockByteGetter(data=chunk_data)) + + # Request chunks including the empty one + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is not None + # Only chunks (0,) and (2,) should be in result, (1,) is empty and skipped + assert (0,) in result + assert (1,) not in result # Empty chunk should be skipped + assert (2,) in result + + +async def test_load_partial_shard_maybe_all_chunks_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test _load_partial_shard_maybe returns empty dict when all requested chunks are empty.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + # Create an empty index (all chunks empty) + index = _ShardIndex.create_empty(chunks_per_shard) + + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + byte_getter = cast(ByteGetter, MockByteGetter(data=b"")) + + # Request some chunks - all will be empty + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is not None + assert result == {} # All chunks were empty, so result is empty dict + + +async def test_load_partial_shard_uses_get_partial_values( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that _load_partial_shard_maybe uses get_partial_values for chunk reads.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(100, 200)) + + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + chunk_data = b"A" * 100 + b"B" * 100 + mock_getter = MockByteGetter(data=chunk_data) + byte_getter = cast(ByteGetter, mock_getter) + + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is not None + assert len(result) == 2 + assert (0,) in result + assert (1,) in result + + # get_partial_values should have been called exactly once + assert mock_getter.get_partial_values_call_count == 1 + + +async def test_load_partial_shard_single_chunk_read( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test single chunk read (most common case for single element access).""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice((1,), slice(100, 200)) + + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + chunk_data = b"\x00" * 100 + b"E" * 100 + byte_getter = cast(ByteGetter, MockByteGetter(data=chunk_data)) + + all_chunk_coords: set[tuple[int, ...]] = {(1,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is not None + assert (1,) in result + assert len(result) == 1 + + chunk1 = result[(1,)] + assert chunk1 is not None + assert chunk1.as_numpy_array().tobytes() == b"E" * 100 + + +async def test_load_partial_shard_chunk_load_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that chunks are omitted from result when get_partial_values returns None.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice((0,), slice(0, 100)) + + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetterWithIndex, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + byte_getter = cast( + ByteGetter, + MockByteGetterWithIndex(index_data=b"", chunk_data=None, return_none_for_chunks=True), + ) + + all_chunk_coords: set[tuple[int, ...]] = {(0,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + ) + + assert result is not None + assert len(result) == 0 + + +# ============================================================================ +# Supporting class tests (_ShardReader, _is_total_shard) +# ============================================================================ + + +def test_shard_reader_create_empty() -> None: + """Test _ShardReader.create_empty creates reader with empty index.""" + chunks_per_shard = (2, 3) + reader = _ShardReader.create_empty(chunks_per_shard) + + assert reader.index.is_all_empty() + assert len(reader.buf) == 0 + assert len(reader) == 2 * 3 + + +def test_shard_reader_iteration() -> None: + """Test _ShardReader iteration yields all chunk coordinates.""" + chunks_per_shard = (2, 2) + reader = _ShardReader.create_empty(chunks_per_shard) + + coords = list(reader) + + assert len(coords) == 4 + assert (0, 0) in coords + assert (0, 1) in coords + assert (1, 0) in coords + assert (1, 1) in coords + + +def test_shard_reader_getitem_raises_for_empty() -> None: + """Test _ShardReader.__getitem__ raises KeyError for empty chunks.""" + chunks_per_shard = (2,) + reader = _ShardReader.create_empty(chunks_per_shard) + + with pytest.raises(KeyError): + _ = reader[(0,)] + + +def test_is_total_shard_full() -> None: + """Test _is_total_shard returns True when all chunk coords are present.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (0, 1), (1, 0), (1, 1)} + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True + + +def test_is_total_shard_partial() -> None: + """Test _is_total_shard returns False for partial chunk coords.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (1, 1)} # Missing (0, 1) and (1, 0) + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is False + + +def test_is_total_shard_empty() -> None: + """Test _is_total_shard returns False for empty chunk coords.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = set() + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is False + + +def test_is_total_shard_1d() -> None: + """Test _is_total_shard works with 1D shards.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,), (3,)} + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True + + # Partial + partial_coords: set[tuple[int, ...]] = {(0,), (2,)} + assert codec._is_total_shard(partial_coords, chunks_per_shard) is False