Skip to content

Commit 0f59d6c

Browse files
CopilotlappalainenjCopilot
authored
Fix NotFittedError in umap_embedding when fewer than 2 rows have nonzero variance (#19)
* Initial plan * Fix umap_embedding edge case: return early when fewer than 2 non-constant rows exist Co-authored-by: lappalainenj <34949352+lappalainenj@users.noreply.github.com> * Update tests/test_clustering.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix return type annotation: use Optional[UMAP] and clarify mask semantics in docstring Co-authored-by: lappalainenj <34949352+lappalainenj@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: lappalainenj <34949352+lappalainenj@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7882efe commit 0f59d6c

2 files changed

Lines changed: 64 additions & 8 deletions

File tree

flyvis/analysis/clustering.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def umap_embedding(
439439
metric: str = "correlation",
440440
n_epochs: int = 1500,
441441
**kwargs,
442-
) -> Tuple[np.ndarray, np.ndarray, UMAP]:
442+
) -> Tuple[np.ndarray, np.ndarray, Optional[UMAP]]:
443443
"""
444444
Perform UMAP embedding on input data.
445445
@@ -456,9 +456,14 @@ def umap_embedding(
456456
457457
Returns:
458458
A tuple containing:
459-
- embedding: The UMAP embedding.
460-
- mask: Boolean mask for valid samples.
461-
- reducer: The fitted UMAP object.
459+
- embedding: The UMAP embedding (n_samples, n_components). May be NaN
460+
if insufficient data.
461+
- mask: Boolean mask (length n_samples). When reducer is not None,
462+
True indicates rows with nonzero variance that were also connected
463+
in the UMAP graph. When reducer is None (insufficient data), True
464+
indicates only rows with nonzero variance.
465+
- reducer: The fitted UMAP object or None if fewer than 2 rows had
466+
nonzero variance.
462467
463468
Raises:
464469
ValueError: If n_components is too large relative to sample size.
@@ -481,10 +486,16 @@ def umap_embedding(
481486
X = X.reshape(X.shape[0], -1)
482487
logging.info("reshaped X from %s to %s", shape, X.shape)
483488

484-
embedding = np.ones([X.shape[0], n_components]) * np.nan
485-
# umap doesn't like contant rows
489+
n_samples = X.shape[0]
490+
embedding = np.ones([n_samples, n_components]) * np.nan
491+
# umap doesn't like constant rows
486492
mask = ~np.isclose(X.std(axis=1), 0)
487-
X = X[mask]
493+
X_nonconst = X[mask]
494+
495+
# If fewer than 2 rows remain, skip UMAP and return embedding of NaNs.
496+
if X_nonconst.shape[0] < 2:
497+
return embedding, mask, None
498+
488499
reducer = UMAP(
489500
n_neighbors=n_neighbors,
490501
min_dist=min_dist,
@@ -495,7 +506,7 @@ def umap_embedding(
495506
n_epochs=n_epochs,
496507
**kwargs,
497508
)
498-
_embedding = reducer.fit_transform(X)
509+
_embedding = reducer.fit_transform(X_nonconst)
499510

500511
# gaussian mixture doesn't like nans through disconnected vertices in umap
501512
connected_vertices_mask = ~disconnected_vertices(reducer)

tests/test_clustering.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
3+
from flyvis.analysis.clustering import umap_embedding
4+
5+
6+
def test_umap_embedding_single_nonzero_variance_row():
7+
"""Test that umap_embedding handles the edge case where only one row has
8+
nonzero variance (all others are constant). UMAP should not be fitted and
9+
the function should return NaN embedding with None reducer."""
10+
rng = np.random.default_rng(0)
11+
# One row with variance, four constant rows
12+
X = np.zeros((5, 10))
13+
X[2] = rng.random(10)
14+
15+
embedding, mask, reducer = umap_embedding(X)
16+
17+
assert reducer is None
18+
assert np.all(np.isnan(embedding))
19+
# Only the one non-constant row should be True in the mask
20+
expected_mask = np.array([False, False, True, False, False])
21+
np.testing.assert_array_equal(mask, expected_mask)
22+
23+
24+
def test_umap_embedding_all_zero_variance_rows():
25+
"""Test that umap_embedding handles all-constant rows gracefully."""
26+
X = np.ones((5, 10))
27+
28+
embedding, mask, reducer = umap_embedding(X)
29+
30+
assert reducer is None
31+
assert np.all(np.isnan(embedding))
32+
assert not np.any(mask)
33+
34+
35+
def test_umap_embedding_returns_none_reducer_when_insufficient_data():
36+
"""Test that reducer is None when fewer than 2 rows have nonzero variance."""
37+
X = np.zeros((4, 8))
38+
# Only one non-constant row
39+
X[0] = np.arange(8, dtype=float)
40+
41+
embedding, mask, reducer = umap_embedding(X)
42+
43+
assert reducer is None
44+
assert embedding.shape == (4, 2)
45+
assert np.all(np.isnan(embedding))

0 commit comments

Comments
 (0)