diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index b380f08..2ed03f6 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -243,9 +243,13 @@ def coord_to_sortable(self, X: JAXArray) -> JAXArray: return self.kernel1.coord_to_sortable(X) def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray: - if self.use_block: - return Block(m1, m2) - return jsp_block_diag(m1, m2) + if not self.use_block: + return jsp_block_diag(m1, m2) + + # Ensure we don't nest Block objects to fix Issue #265 + blocks1 = m1.blocks if isinstance(m1, Block) else (m1,) + blocks2 = m2.blocks if isinstance(m2, Block) else (m2,) + return Block(*blocks1, *blocks2) def design_matrix(self) -> JAXArray: return self._block_or_dense( diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py index 6e102cd..3bcb0bc 100644 --- a/src/tinygp/solvers/quasisep/block.py +++ b/src/tinygp/solvers/quasisep/block.py @@ -55,6 +55,9 @@ def to_dense(self) -> JAXArray: def __mul__(self, other: Any) -> "Block": return Block(*(b * other for b in self.blocks)) + def __rmul__(self, other: Any) -> "Block": + return self.__mul__(other) + @jax.jit def __add__(self, other: Any) -> Any: if isinstance(other, Block):