Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3907.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add protocols for stores that support byte-range-writes. This is necessary to support in-place writes of sharded arrays.
18 changes: 18 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"Store",
"SupportsDeleteSync",
"SupportsGetSync",
"SupportsSetRange",
"SupportsSetSync",
"SupportsSyncStore",
"set_or_delete",
Expand Down Expand Up @@ -709,6 +710,23 @@ async def delete(self) -> None: ...
async def set_if_not_exists(self, default: Buffer) -> None: ...


@runtime_checkable
class SupportsSetRange(Protocol):
"""Protocol for stores that support writing to a byte range within an existing value.

Overwrites ``len(value)`` bytes starting at byte offset ``start`` within the
existing stored value for ``key``. The key must already exist and the write
must fit within the existing value (i.e., ``start + len(value) <= len(existing)``).

Behavior when the write extends past the end of the existing value is
implementation-specific and should not be relied upon.
"""

async def set_range(self, key: str, value: Buffer, start: int) -> None: ...

def set_range_sync(self, key: str, value: Buffer, start: int) -> None: ...


@runtime_checkable
class SupportsGetSync(Protocol):
def get_sync(
Expand Down
23 changes: 22 additions & 1 deletion src/zarr/storage/_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
RangeByteRequest,
Store,
SuffixByteRequest,
SupportsSetRange,
)
from zarr.core.buffer import Buffer
from zarr.core.buffer.core import default_buffer_prototype
Expand Down Expand Up @@ -77,6 +78,13 @@ def _atomic_write(
raise


def _put_range(path: Path, value: Buffer, start: int) -> None:
"""Write bytes at a specific offset within an existing file."""
with path.open("r+b") as f:
f.seek(start)
f.write(value.as_numpy_array().tobytes())


def _put(path: Path, value: Buffer, exclusive: bool = False) -> int:
path.parent.mkdir(parents=True, exist_ok=True)
# write takes any object supporting the buffer protocol
Expand All @@ -85,7 +93,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int:
return f.write(view)


class LocalStore(Store):
class LocalStore(Store, SupportsSetRange):
"""
Store for the local file system.

Expand Down Expand Up @@ -292,6 +300,19 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
path = self.root / key
await asyncio.to_thread(_put, path, value, exclusive=exclusive)

async def set_range(self, key: str, value: Buffer, start: int) -> None:
if not self._is_open:
await self._open()
self._check_writable()
path = self.root / key
await asyncio.to_thread(_put_range, path, value, start)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
self._ensure_open_sync()
self._check_writable()
path = self.root / key
_put_range(path, value, start)

async def delete(self, key: str) -> None:
"""
Remove a key from the store.
Expand Down
24 changes: 22 additions & 2 deletions src/zarr/storage/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any, Self

from zarr.abc.store import ByteRequest, Store
from zarr.abc.store import ByteRequest, Store, SupportsSetRange
from zarr.core.buffer import Buffer, gpu
from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.common import concurrent_map
Expand All @@ -18,7 +18,7 @@
logger = getLogger(__name__)


class MemoryStore(Store):
class MemoryStore(Store, SupportsSetRange):
"""
Store for local memory.

Expand Down Expand Up @@ -186,6 +186,26 @@ async def delete(self, key: str) -> None:
except KeyError:
logger.debug("Key %s does not exist.", key)

def _set_range_impl(self, key: str, value: Buffer, start: int) -> None:
buf = self._store_dict[key]
target = buf.as_numpy_array()
if not target.flags.writeable:
target = target.copy()
self._store_dict[key] = buf.__class__(target)
source = value.as_numpy_array()
target[start : start + len(source)] = source

async def set_range(self, key: str, value: Buffer, start: int) -> None:
self._check_writable()
await self._ensure_open()
self._set_range_impl(key, value, start)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
self._check_writable()
if not self._is_open:
self._is_open = True
self._set_range_impl(key, value, start)

async def list(self) -> AsyncIterator[str]:
# docstring inherited
for key in self._store_dict:
Expand Down
49 changes: 49 additions & 0 deletions tests/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import zarr
from zarr import create_array
from zarr.abc.store import SupportsSetRange
from zarr.core.buffer import Buffer, cpu
from zarr.core.sync import sync
from zarr.storage import LocalStore
Expand Down Expand Up @@ -162,6 +163,54 @@ def test_get_json_sync_with_prototype_none(
result = store._get_json_sync(key, prototype=buffer_cls)
assert result == data

def test_supports_set_range(self, store: LocalStore) -> None:
"""LocalStore should implement SupportsSetRange."""
assert isinstance(store, SupportsSetRange)

@pytest.mark.parametrize(
("start", "patch", "expected"),
[
(0, b"XX", b"XXAAAAAAAA"),
(3, b"XX", b"AAAXXAAAAA"),
(8, b"XX", b"AAAAAAAAXX"),
(0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
(5, b"B", b"AAAAABAAAA"),
(0, b"BCDE", b"BCDEAAAAAA"),
],
ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
async def test_set_range(
self, store: LocalStore, start: int, patch: bytes, expected: bytes
) -> None:
"""set_range should overwrite bytes at the given offset."""
await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start)
result = await store.get("test/key", prototype=cpu.buffer_prototype)
assert result is not None
assert result.to_bytes() == expected

@pytest.mark.parametrize(
("start", "patch", "expected"),
[
(0, b"XX", b"XXAAAAAAAA"),
(3, b"XX", b"AAAXXAAAAA"),
(8, b"XX", b"AAAAAAAAXX"),
(0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
(5, b"B", b"AAAAABAAAA"),
(0, b"BCDE", b"BCDEAAAAAA"),
],
ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
def test_set_range_sync(
self, store: LocalStore, start: int, patch: bytes, expected: bytes
) -> None:
"""set_range_sync should overwrite bytes at the given offset."""
sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start)
result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
assert result is not None
assert result.to_bytes() == expected


@pytest.mark.parametrize("exclusive", [True, False])
def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import zarr
from zarr.abc.store import SupportsSetRange
from zarr.core.buffer import Buffer, cpu, gpu
from zarr.core.sync import sync
from zarr.errors import ZarrUserWarning
Expand Down Expand Up @@ -127,6 +128,55 @@ def test_get_json_sync_with_prototype_none(
result = store._get_json_sync(key, prototype=buffer_cls)
assert result == data

def test_supports_set_range(self, store: MemoryStore) -> None:
"""MemoryStore should implement SupportsSetRange."""
assert isinstance(store, SupportsSetRange)

@pytest.mark.parametrize(
("start", "patch", "expected"),
[
(0, b"XX", b"XXAAAAAAAA"),
(3, b"XX", b"AAAXXAAAAA"),
(8, b"XX", b"AAAAAAAAXX"),
(0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
(5, b"B", b"AAAAABAAAA"),
(0, b"BCDE", b"BCDEAAAAAA"),
],
ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
async def test_set_range(
self, store: MemoryStore, start: int, patch: bytes, expected: bytes
) -> None:
"""set_range should overwrite bytes at the given offset."""
await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start)
result = await store.get("test/key", prototype=cpu.buffer_prototype)
assert result is not None
assert result.to_bytes() == expected

@pytest.mark.parametrize(
("start", "patch", "expected"),
[
(0, b"XX", b"XXAAAAAAAA"),
(3, b"XX", b"AAAXXAAAAA"),
(8, b"XX", b"AAAAAAAAXX"),
(0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
(5, b"B", b"AAAAABAAAA"),
(0, b"BCDE", b"BCDEAAAAAA"),
],
ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
def test_set_range_sync(
self, store: MemoryStore, start: int, patch: bytes, expected: bytes
) -> None:
"""set_range_sync should overwrite bytes at the given offset."""
store._is_open = True
store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA")
store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start)
result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
assert result is not None
assert result.to_bytes() == expected


# TODO: fix this warning
@pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")
Expand Down
Loading