Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/zarr-v2-compat.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: zarr v2 compatibility

on:
pull_request:
push:
branches: [main, test]

jobs:
test:
name: zarr v2 / Python 3.11 / ubuntu
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5

- name: Install dependencies
run: uv sync --python 3.11 --group test

- name: Downgrade zarr to v2
run: uv pip install "zarr>=2.18,<3"

- name: Run tests
run: uv run pytest
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"numpy",
"humanize",
"tskit>=1.0.0",
"zarr>=3.1",
"zarr>=2.18",
]
dynamic = ["version"]

Expand Down
6 changes: 3 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
import numpy as np
import pytest
import tskit
import zarr

import tszip
import tszip.cli as cli
from tszip import _zarr_compat


def get_stdout_for_pytest():
Expand Down Expand Up @@ -265,8 +265,8 @@ def test_chunk_size(self):
assert outpath.exists()
ts = tszip.decompress(outpath)
assert ts.tables == self.ts.tables
store = zarr.storage.ZipStore(str(outpath), mode="r")
root = zarr.open_group(store=store, zarr_format=2, mode="r")
store = _zarr_compat.open_zip_store(outpath, mode="r")
root = _zarr_compat.open_group_for_read(store)
for _, g in root.groups():
for _, a in g.arrays():
assert a.chunks == (20,)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
import numpy as np
import pytest
import tskit
import zarr

import tszip
import tszip.compression as compression
import tszip.exceptions as exceptions
import tszip.provenance as provenance
from tszip import _zarr_compat


class TestMinimalDtype:
Expand Down Expand Up @@ -295,17 +295,17 @@ def setup(self, tmp_path):
def test_format_written(self):
ts = msprime.simulate(10, random_seed=1)
tszip.compress(ts, self.path)
with zarr.storage.ZipStore(str(self.path), mode="r") as store:
root = zarr.open_group(store=store, zarr_format=2, mode="r")
with _zarr_compat.open_zip_store(self.path, mode="r") as store:
root = _zarr_compat.open_group_for_read(store)
assert root.attrs["format_name"] == compression.FORMAT_NAME
assert root.attrs["format_version"] == compression.FORMAT_VERSION

def test_provenance(self):
ts = msprime.simulate(10, random_seed=1)
for variants_only in [True, False]:
tszip.compress(ts, self.path, variants_only=variants_only)
with zarr.storage.ZipStore(str(self.path), mode="r") as store:
root = zarr.open_group(store=store, zarr_format=2, mode="r")
with _zarr_compat.open_zip_store(self.path, mode="r") as store:
root = _zarr_compat.open_group_for_read(store)
assert root.attrs["provenance"] == provenance.get_provenance_dict(
{
"variants_only": variants_only,
Expand All @@ -314,8 +314,8 @@ def test_provenance(self):
)

def write_file(self, attrs, path):
with zarr.storage.ZipStore(str(path), mode="w") as store:
root = zarr.open_group(store=store, zarr_format=2, mode="a")
with _zarr_compat.open_zip_store(path, mode="w") as store:
root = _zarr_compat.open_group_for_write(store)
root.attrs.update(attrs)

def test_missing_format_keys(self):
Expand Down Expand Up @@ -538,8 +538,8 @@ def test_good_chunks(self, tmpdir, chunk_size):
ts2 = tszip.decompress(path)
assert ts1 == ts2

store = zarr.storage.ZipStore(str(path), mode="r")
root = zarr.open_group(store=store, zarr_format=2, mode="r")
store = _zarr_compat.open_zip_store(path, mode="r")
root = _zarr_compat.open_group_for_read(store)
for _, g in root.groups():
for _, a in g.arrays():
assert a.chunks == (chunk_size,)
Expand Down
47 changes: 47 additions & 0 deletions tszip/_zarr_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import zarr

_ZARR_V3 = int(zarr.__version__.split(".")[0]) >= 3


def open_zip_store(path, mode):
"""Open a ZipStore compatible with zarr v2 and v3."""
return zarr.storage.ZipStore(str(path), mode=mode)


def open_group_for_read(store):
"""Open a zarr group for reading in zarr v2 format."""
if _ZARR_V3:
return zarr.open_group(store=store, zarr_format=2, mode="r")
else:
return zarr.open_group(store=store, mode="r")


def open_group_for_write(store):
"""Open a zarr group for writing in zarr v2 format."""
if _ZARR_V3:
return zarr.open_group(store=store, zarr_format=2, mode="a")
else:
return zarr.open_group(store=store, mode="a")


def empty_array(root, name, shape, dtype, chunks, filters, compressor):
"""Create an empty zarr array in zarr v2 format."""
if _ZARR_V3:
return root.empty(
name=name,
shape=shape,
dtype=dtype,
chunks=chunks,
zarr_format=2,
filters=filters,
compressor=compressor,
)
else:
return root.empty(
name=name,
shape=shape,
dtype=dtype,
chunks=chunks,
filters=filters,
compressor=compressor,
)
15 changes: 7 additions & 8 deletions tszip/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@
import numpy as np
import tskit
import zarr
from zarr.storage import ZipStore

from . import exceptions, provenance
from . import _zarr_compat, exceptions, provenance

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,8 +105,8 @@ def compress(ts, destination, variants_only=False, *, chunk_size=None):
with tempfile.TemporaryDirectory(dir=destdir, prefix=".tszip_work_") as tmpdir:
filename = pathlib.Path(tmpdir, "tmp.trees.tgz")
logging.debug(f"Writing to temporary file {filename}")
with ZipStore(filename, mode="w") as store:
root = zarr.open_group(store=store, zarr_format=2, mode="a")
with _zarr_compat.open_zip_store(filename, mode="w") as store:
root = _zarr_compat.open_group_for_write(store)
compress_zarr(ts, root, variants_only=variants_only, chunk_size=chunk_size)
if is_path:
os.replace(filename, destination)
Expand Down Expand Up @@ -151,12 +150,12 @@ def compress(self, root, compressor):
filters = None
if self.delta_filter:
filters = [numcodecs.Delta(dtype=dtype)]
compressed_array = root.empty(
compressed_array = _zarr_compat.empty_array(
root,
name=self.name,
shape=shape,
dtype=dtype,
chunks=self.chunks,
zarr_format=2,
filters=filters,
compressor=compressor,
)
Expand Down Expand Up @@ -296,8 +295,8 @@ def check_format(root):
def load_zarr(path):
path = str(path)
try:
store = ZipStore(path, mode="r")
root = zarr.open_group(store=store, zarr_format=2, mode="r")
store = _zarr_compat.open_zip_store(path, mode="r")
root = _zarr_compat.open_group_for_read(store)
except zipfile.BadZipFile as bzf:
raise exceptions.FileFormatError("File is not in tszip format") from bzf

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.