Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,46 @@ def fetch(

# Handle specific attributes requested
if attrs:
# Check for special 'KEY' attribute
def is_key(attr):
return attr == "KEY"

has_key = any(is_key(a) for a in attrs)

# Handle fetch('KEY') alone - return list of primary key dicts
if has_key and len(attrs) == 1:
return list(self.keys(order_by=order_by, limit=limit, offset=offset))

if as_dict is True:
# fetch('col1', 'col2', as_dict=True) -> list of dicts
return self.proj(*attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
# Replace KEY with primary key columns
proj_attrs = []
for attr in attrs:
if is_key(attr):
proj_attrs.extend(self.primary_key)
else:
proj_attrs.append(attr)
return self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
else:
# fetch('col1', 'col2') or fetch('col1', 'col2', as_dict=False) -> tuple of arrays
# This matches DJ 1.x behavior where fetch('col') returns array(['alpha', 'beta'])
if has_key:
# Need to handle KEY specially - it returns list of dicts, not array
proj_attrs = []
for attr in attrs:
if is_key(attr):
proj_attrs.extend(self.primary_key)
else:
proj_attrs.append(attr)
dicts = self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
# Build result, with KEY returning list of dicts
results = []
for attr in attrs:
if is_key(attr):
results.append([{k: d[k] for k in self.primary_key} for d in dicts])
else:
results.append(np.array([d[attr] for d in dicts]))
return results[0] if len(attrs) == 1 else tuple(results)
return self.to_arrays(*attrs, order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)

# Handle as_dict=True -> to_dicts()
Expand Down
49 changes: 49 additions & 0 deletions tests/integration/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,52 @@ def test_to_arrays_inhomogeneous_shapes_second_axis(schema_any):
assert data[0].shape == (100,)
assert data[1].shape == (1, 100)
assert data[2].shape == (2, 100)


def test_fetch_KEY(lang, languages):
"""Test fetch('KEY') returns list of primary key dicts.

Regression test for https://github.com/datajoint/datajoint-python/issues/1381
"""
import warnings

# Suppress deprecation warning for fetch
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)

# fetch('KEY') should return list of primary key dicts
keys = lang.fetch("KEY")
assert isinstance(keys, list)
assert len(keys) == len(languages)
assert all(isinstance(k, dict) for k in keys)
# Primary key is (name, language)
assert all(set(k.keys()) == {"name", "language"} for k in keys)


def test_fetch1_KEY(lang):
"""Test fetch1('KEY') returns primary key dict.

Regression test for https://github.com/datajoint/datajoint-python/issues/1381
"""
key = {"name": "Edgar", "language": "Japanese"}
result = (lang & key).fetch1("KEY")
assert isinstance(result, dict)
assert result == key


def test_fetch_KEY_with_other_attrs(lang):
"""Test fetch('KEY', 'name') returns (keys_list, name_array).

Regression test for https://github.com/datajoint/datajoint-python/issues/1381
"""
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)

# fetch('KEY', 'name') should return tuple of (list of dicts, array)
keys, names = lang.fetch("KEY", "name")
assert isinstance(keys, list)
assert all(isinstance(k, dict) for k in keys)
assert isinstance(names, np.ndarray)
assert len(keys) == len(names)
Loading