diff --git a/README.md b/README.md
index 3b8dec5..52e54a6 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,11 @@ with open("data.im", "wb") as f:
 # You can skip computing or checking CRCs, e.g. if your
 # embedded object already contains CRCs
 mb = MapBuffer(..., check_crc=False, compute_crc=False)
+
+# If your access pattern is such that the index and the
+# download are similar in size (e.g. watershed meshes)
+# you can cache the index.
+mb = MapBuffer(..., index_cache="/tmp/helloworld.mbi")
 ```
 
 ## Installation
diff --git a/automated_test.py b/automated_test.py
index 2fac50d..d9d7b6c 100644
--- a/automated_test.py
+++ b/automated_test.py
@@ -3,11 +3,14 @@
 import mmap
 import os
 import random
+from unittest.mock import patch
 
 import numpy as np
 
 from mapbuffer import ValidationError, IntMap, MapBuffer, HEADER_LENGTH
 
+CACHE_PATH = "./test_index_cache.mbi"
+
 @pytest.mark.parametrize("compress", (None, "gzip", "br", "zstd", "lzma"))
 def test_empty(compress):
   mbuf = MapBuffer({}, compress=compress)
@@ -248,6 +251,96 @@ def test_set_object_intmap():
   except KeyError:
     pass
 
 
 
 
+@pytest.fixture(autouse=True)
+def cleanup_cache():
+  """Ensure cache file is removed before and after each test."""
+  if os.path.exists(CACHE_PATH):
+    os.remove(CACHE_PATH)
+  yield
+  if os.path.exists(CACHE_PATH):
+    os.remove(CACHE_PATH)
+
+
+def make_mapbuffer(data=None, **kwargs):
+  data = data or {1: b"hello", 2: b"world"}
+  return MapBuffer(data, index_cache=CACHE_PATH, **kwargs)
+
+
+def test_index_cache_file_is_created():
+  """Cache file should be written after first access."""
+  mbuf = make_mapbuffer()
+  mbuf.index()
+  assert os.path.exists(CACHE_PATH)
+
+
+def test_index_cache_header_and_index_written():
+  """Cache file should contain header + full index bytes."""
+  mbuf = make_mapbuffer()
+  index = mbuf.index()
+
+  with open(CACHE_PATH, "rb") as f:
+    cached = f.read()
+
+  assert len(cached) == HEADER_LENGTH + index.nbytes
+
+
+def test_index_cache_is_loaded_from_disk():
+  """Second MapBuffer with same cache should read index from disk, not buffer."""
+  mbuf = make_mapbuffer()
+  original_index = mbuf.index().copy()
+
+  # Reload — this time the cache exists, so index should come from disk
+  mbuf2 = make_mapbuffer()
+  mbuf2._index = None  # ensure not inherited
+
+  with patch.object(np, "frombuffer", wraps=np.frombuffer) as mock_frombuffer:
+    loaded_index = mbuf2.index()
+    # np.frombuffer should NOT be called on the main buffer for the index
+    for call in mock_frombuffer.call_args_list:
+      args, kwargs = call
+      # Ensure we're not reading index from the primary buffer
+      assert kwargs.get("offset") != HEADER_LENGTH, \
+        "Index was re-read from buffer instead of cache"
+
+  np.testing.assert_array_equal(loaded_index, original_index)
+
+
+def test_index_cache_values_correct():
+  """Values retrieved using cache should match those from a non-cached buffer."""
+  mbuf_cached = make_mapbuffer()
+  mbuf_plain = MapBuffer({1: b"hello", 2: b"world"})
+
+  for key in [1, 2]:
+    assert mbuf_cached[key] == mbuf_plain[key]
+
+
+def test_crc_error_raised_despite_cache():
+  """CRC validation should still catch corruption even when cache exists."""
+  data = {1: b"hello", 2: b"world"}
+  mbuf = make_mapbuffer(data)
+  mbuf.index()  # populate cache
+
+  # Corrupt the data region in the buffer
+  buf = bytearray(mbuf.buffer)
+  idx = bytes(buf).index(b"hello")
+  buf[idx] = ord(b"H")
+  mbuf.buffer = bytes(buf)
+  mbuf._index = None  # force re-read so cache is used but data is still corrupt
+
+  with pytest.raises(ValidationError):
+    mbuf[1]
+
+
+def test_index_cache_not_rewritten_if_already_complete():
+  """Cache file should not be overwritten on second load."""
+  mbuf = make_mapbuffer()
+  mbuf.index()
+  mtime_after_first = os.path.getmtime(CACHE_PATH)
+  mbuf2 = make_mapbuffer()
+  mbuf2.index()
+  mtime_after_second = os.path.getmtime(CACHE_PATH)
+  assert mtime_after_first == mtime_after_second, \
+    "Cache file was unexpectedly rewritten on second access"
\ No newline at end of file
diff --git a/mapbuffer/mapbuffer.py b/mapbuffer/mapbuffer.py
index e7e7f5f..46fcad7 100644
--- a/mapbuffer/mapbuffer.py
+++ b/mapbuffer/mapbuffer.py
@@ -1,5 +1,6 @@
 from typing import Optional, Any, Union, Literal
 from collections.abc import Callable
+import os
 import mmap
 import io
 
@@ -9,6 +10,7 @@ from . import compression
 
 import crc32c
+import fasteners
 import numpy as np
 
 import mapbufferaccel
 
@@ -21,8 +23,8 @@ class MapBuffer:
   """Represents a usable int->bytes dictionary as a byte string."""
   __slots__ = (
     "data", "tobytesfn", "frombytesfn",
-    "dtype", "buffer", "check_crc", "compute_crc",
-    "_header", "_index", "_compress"
+    "dtype", "buffer", "check_crc", "compute_crc", "index_cache",
+    "_header", "_index", "_compress", "_lock"
   )
 
   def __init__(
@@ -32,6 +34,7 @@ def __init__(
     frombytesfn:Optional[Callable[[bytes], Any]] = None,
     check_crc:bool = True,
     compute_crc:bool = True,
+    index_cache:Optional[str] = None,
   ):
     """
     data: dict (int->byte serializable object) or bytes
@@ -52,10 +55,14 @@ def __init__(
     self.buffer = None
     self.check_crc = check_crc
     self.compute_crc = compute_crc
+    self.index_cache = index_cache
 
     self._header = None
     self._index = None
     self._compress = None
+    self._lock = None
+    if self.index_cache is not None:
+      self._lock = fasteners.InterProcessReaderWriterLock(self.index_cache)
 
     if isinstance(data, dict):
       self.buffer = self.dict2buf(data, compress)
@@ -102,9 +109,29 @@ def header(self):
     if self._header is not None:
       return self._header
 
+    if self.index_cache is not None:
+      if os.path.exists(self.index_cache):
+        with self._lock.read_lock():
+          with open(self.index_cache, "rb") as f:
+            self._header = f.read(HEADER_LENGTH)
+
+        if len(self._header) == HEADER_LENGTH:
+          return self._header
+
     # seems dumb, buf if self.buffer is an object that
     # requires network access, this is a valuable cache
     self._header = self.buffer[:HEADER_LENGTH]
+
+    if self.index_cache is not None:
+      with self._lock.write_lock():
+        try:
+          if os.path.getsize(self.index_cache) < HEADER_LENGTH:
+            with open(self.index_cache, "wb") as f:
+              f.write(self._header)
+        except FileNotFoundError:
+          with open(self.index_cache, "wb") as f:
+            f.write(self._header)
+
     return self._header
 
   def index(self):
@@ -115,6 +142,18 @@ def index(self):
     N = len(self)
     index_length = 2 * N
 
+    if self.index_cache is not None:
+      try:
+        if os.path.getsize(self.index_cache) >= HEADER_LENGTH + index_length * 8:
+          with self._lock.read_lock():
+            with open(self.index_cache, "rb") as f:
+              f.seek(HEADER_LENGTH)
+              index = f.read(index_length * 8)
+          self._index = np.frombuffer(index, dtype=np.uint64).reshape((N,2))
+          return self._index
+      except FileNotFoundError:
+        pass
+
     if isinstance(self.buffer, (bytes,bytearray,np.ndarray,mmap.mmap)):
       self._index = np.frombuffer(
         self.buffer,
@@ -127,6 +166,15 @@ def index(self):
       index = self.buffer[HEADER_LENGTH:index_length+HEADER_LENGTH]
       self._index = np.frombuffer(index, dtype=np.uint64).reshape((N,2))
 
+    if self.index_cache is not None:
+      with self._lock.write_lock():
+        try:
+          if os.path.getsize(self.index_cache) == HEADER_LENGTH:
+            with open(self.index_cache, "ab") as f:
+              f.write(self._index.tobytes('C'))
+        except FileNotFoundError:
+          pass
+
     return self._index
 
   def keys(self):
diff --git a/requirements.txt b/requirements.txt
index f3af510..8c9461d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 brotli
 crc32c
 deflate>=0.2.0
+fasteners
 numpy
 tqdm
 zstandard
\ No newline at end of file