diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 3f75edeec..8bcdaab48 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -639,12 +639,46 @@ def fetch( # Handle specific attributes requested if attrs: + # Check for special 'KEY' attribute + def is_key(attr): + return attr == "KEY" + + has_key = any(is_key(a) for a in attrs) + + # Handle fetch('KEY') alone - return list of primary key dicts + if has_key and len(attrs) == 1: + return list(self.keys(order_by=order_by, limit=limit, offset=offset)) + if as_dict is True: # fetch('col1', 'col2', as_dict=True) -> list of dicts - return self.proj(*attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) + # Replace KEY with primary key columns + proj_attrs = [] + for attr in attrs: + if is_key(attr): + proj_attrs.extend(self.primary_key) + else: + proj_attrs.append(attr) + return self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) else: # fetch('col1', 'col2') or fetch('col1', 'col2', as_dict=False) -> tuple of arrays # This matches DJ 1.x behavior where fetch('col') returns array(['alpha', 'beta']) + if has_key: + # Need to handle KEY specially - it returns list of dicts, not array + proj_attrs = [] + for attr in attrs: + if is_key(attr): + proj_attrs.extend(self.primary_key) + else: + proj_attrs.append(attr) + dicts = self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) + # Build result, with KEY returning list of dicts + results = [] + for attr in attrs: + if is_key(attr): + results.append([{k: d[k] for k in self.primary_key} for d in dicts]) + else: + results.append(np.array([d[attr] for d in dicts])) + return results[0] if len(attrs) == 1 else tuple(results) return self.to_arrays(*attrs, order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) # Handle as_dict=True -> to_dicts() diff --git a/tests/integration/test_fetch.py b/tests/integration/test_fetch.py index dd556ff70..695e02984 100644 --- a/tests/integration/test_fetch.py +++ b/tests/integration/test_fetch.py @@ -457,3 +457,52 @@ def test_to_arrays_inhomogeneous_shapes_second_axis(schema_any): assert data[0].shape == (100,) assert data[1].shape == (1, 100) assert data[2].shape == (2, 100) + + +def test_fetch_KEY(lang, languages): + """Test fetch('KEY') returns list of primary key dicts. + + Regression test for https://github.com/datajoint/datajoint-python/issues/1381 + """ + import warnings + + # Suppress deprecation warning for fetch + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + # fetch('KEY') should return list of primary key dicts + keys = lang.fetch("KEY") + assert isinstance(keys, list) + assert len(keys) == len(languages) + assert all(isinstance(k, dict) for k in keys) + # Primary key is (name, language) + assert all(set(k.keys()) == {"name", "language"} for k in keys) + + +def test_fetch1_KEY(lang): + """Test fetch1('KEY') returns primary key dict. + + Regression test for https://github.com/datajoint/datajoint-python/issues/1381 + """ + key = {"name": "Edgar", "language": "Japanese"} + result = (lang & key).fetch1("KEY") + assert isinstance(result, dict) + assert result == key + + +def test_fetch_KEY_with_other_attrs(lang): + """Test fetch('KEY', 'name') returns (keys_list, name_array). + + Regression test for https://github.com/datajoint/datajoint-python/issues/1381 + """ + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + # fetch('KEY', 'name') should return tuple of (list of dicts, array) + keys, names = lang.fetch("KEY", "name") + assert isinstance(keys, list) + assert all(isinstance(k, dict) for k in keys) + assert isinstance(names, np.ndarray) + assert len(keys) == len(names)