From bb2c1e088531ff0aa7cd937c397b4abfc2e570a3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Mar 2026 14:33:59 -0500 Subject: [PATCH 01/15] fix: do not allow in-place operations on images --- jax_galsim/image.py | 228 +++++++++++++++++++++++++------------------- 1 file changed, 129 insertions(+), 99 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f6f0f518..01d3b0bd 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1261,18 +1261,21 @@ def Image_add(self, other): def Image_iadd(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array + a - else: - self._array = (self.array + a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype: + # self._array = self.array + a + # else: + # self._array = (self.array + a).astype(self.array.dtype) + # return self def Image_sub(self, other): @@ -1289,18 +1292,21 @@ def Image_rsub(self, other): def Image_isub(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array - a - else: - self._array = (self.array - a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype: + # self._array = self.array - a + # else: + # self._array = (self.array - a).astype(self.array.dtype) + # return self def Image_mul(self, other): @@ -1313,18 +1319,21 @@ def Image_mul(self, other): def Image_imul(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array * a - else: - self._array = (self.array * a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype: + # self._array = self.array * a + # else: + # self._array = (self.array * a).astype(self.array.dtype) + # return self def Image_div(self, other): @@ -1341,20 +1350,23 @@ def Image_rdiv(self, other): def Image_idiv(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype and not self.isinteger: - # if dtype is an integer type, then numpy doesn't allow true division /= to assign - # back to an integer array. So for integers (or mixed types), don't use /=. - self._array = self.array / a - else: - self._array = (self.array / a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype and not self.isinteger: + # # if dtype is an integer type, then numpy doesn't allow true division /= to assign + # # back to an integer array. So for integers (or mixed types), don't use /=. + # self._array = self.array / a + # else: + # self._array = (self.array / a).astype(self.array.dtype) + # return self def Image_floordiv(self, other): @@ -1372,18 +1384,21 @@ def Image_rfloordiv(self, other): def Image_ifloordiv(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array // a - else: - self._array = (self.array // a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other, integer=True) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype: + # self._array = self.array // a + # else: + # self._array = (self.array // a).astype(self.array.dtype) + # return self def Image_mod(self, other): @@ -1401,18 +1416,21 @@ def Image_rmod(self, other): def Image_imod(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array % a - else: - self._array = (self.array % a).astype(self.array.dtype) - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other, integer=True) + # try: + # a = other.array + # dt = a.dtype + # except AttributeError: + # a = other + # dt = type(a) + # if dt == self.array.dtype: + # self._array = self.array % a + # else: + # self._array = (self.array % a).astype(self.array.dtype) + # return self def Image_pow(self, other): @@ -1420,10 +1438,13 @@ def Image_pow(self, other): def Image_ipow(self, other): - if not isinstance(other, int) and not isinstance(other, float): - raise TypeError("Can only raise an image to a float or int power!") - self._array = self.array**other - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # if not isinstance(other, int) and not isinstance(other, float): + # raise TypeError("Can only raise an image to a float or int power!") + # self._array = self.array**other + # return self def Image_neg(self): @@ -1443,13 +1464,16 @@ def Image_and(self, other): def Image_iand(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array & a - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other, integer=True) + # try: + # a = other.array + # except AttributeError: + # a = other + # self._array = self.array & a + # return self def Image_xor(self, other): @@ -1462,13 +1486,16 @@ def Image_xor(self, other): def Image_ixor(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array ^ a - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other, integer=True) + # try: + # a = other.array + # except AttributeError: + # a = other + # self._array = self.array ^ a + # return self def Image_or(self, other): @@ -1481,13 +1508,16 @@ def Image_or(self, other): def Image_ior(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array | a - return self + raise RuntimeError( + "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." + ) + # check_image_consistency(self, other, integer=True) + # try: + # a = other.array + # except AttributeError: + # a = other + # self._array = self.array | a + # return self # inject the arithmetic operators as methods of the Image class: From 6ea960fa46ff027cdca27bc292613df165f725bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Mar 2026 16:31:25 -0500 Subject: [PATCH 02/15] fix: use actual in place ops --- jax_galsim/convolve.py | 4 +- jax_galsim/gsobject.py | 2 +- jax_galsim/image.py | 233 ++++++++++------------ jax_galsim/noise.py | 2 +- jax_galsim/sum.py | 8 +- tests/jax/test_interpolatedimage_utils.py | 2 +- 6 files changed, 115 insertions(+), 136 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..98717eae 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -316,7 +316,9 @@ def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image *= obj._drawKImage(image, jac) + image._array = image._array.at[...].multiply( + obj._drawKImage(image, jac)._array + ) return image def tree_flatten(self): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index b687e175..0e09c57a 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1407,7 +1407,7 @@ def _draw_phot_while_loop_shoot( im1 = ImageD(bounds=image.bounds) added_flux += sensor.accumulate(photons, im1, orig_center) - image += im1 + image._array = image._array.at[...].add(im1) return _DrawPhotReturnTuple( photons, rng, added_flux, image, photon_ops, sensor, resume diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 01d3b0bd..d9259732 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -713,7 +713,7 @@ def calculate_fft(self): jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 ) - out *= dx * dx + out._array = out._array.at[...].multiply(dx * dx) out.setOrigin(0, -No2) return out @@ -769,7 +769,7 @@ def calculate_inverse_fft(self): ) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) - out *= (dk * No2 / jnp.pi) ** 2 + out._array = out._array.at[:].multiply((dk * No2 / jnp.pi) ** 2) out.setCenter(0, 0) return out @@ -1261,21 +1261,18 @@ def Image_add(self, other): def Image_iadd(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype: - # self._array = self.array + a - # else: - # self._array = (self.array + a).astype(self.array.dtype) - # return self + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].add(a) + else: + self._array = (self.array + a).astype(self.array.dtype) + return self def Image_sub(self, other): @@ -1292,21 +1289,18 @@ def Image_rsub(self, other): def Image_isub(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype: - # self._array = self.array - a - # else: - # self._array = (self.array - a).astype(self.array.dtype) - # return self + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].subtract(a) + else: + self._array = (self.array - a).astype(self.array.dtype) + return self def Image_mul(self, other): @@ -1319,21 +1313,18 @@ def Image_mul(self, other): def Image_imul(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype: - # self._array = self.array * a - # else: - # self._array = (self.array * a).astype(self.array.dtype) - # return self + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].multiply(a) + else: + self._array = (self.array * a).astype(self.array.dtype) + return self def Image_div(self, other): @@ -1350,23 +1341,20 @@ def Image_rdiv(self, other): def Image_idiv(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype and not self.isinteger: - # # if dtype is an integer type, then numpy doesn't allow true division /= to assign - # # back to an integer array. So for integers (or mixed types), don't use /=. - # self._array = self.array / a - # else: - # self._array = (self.array / a).astype(self.array.dtype) - # return self + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype and not self.isinteger: + # if dtype is an integer type, then numpy doesn't allow true division /= to assign + # back to an integer array. So for integers (or mixed types), don't use /=. + self._array = self.array.at[...].divide(a) + else: + self._array = (self.array / a).astype(self.array.dtype) + return self def Image_floordiv(self, other): @@ -1375,30 +1363,27 @@ def Image_floordiv(self, other): a = other.array except AttributeError: a = other - return Image(self.array // a, bounds=self.bounds, wcs=self.wcs) + return Image(self._array // a, bounds=self.bounds, wcs=self.wcs) def Image_rfloordiv(self, other): check_image_consistency(self, other, integer=True) - return Image(other // self.array, bounds=self.bounds, wcs=self.wcs) + return Image(other // self._array, bounds=self.bounds, wcs=self.wcs) def Image_ifloordiv(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other, integer=True) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype: - # self._array = self.array // a - # else: - # self._array = (self.array // a).astype(self.array.dtype) - # return self + check_image_consistency(self, other, integer=True) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array // a + else: + self._array = (self.array // a).astype(self.array.dtype) + return self def Image_mod(self, other): @@ -1416,21 +1401,18 @@ def Image_rmod(self, other): def Image_imod(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other, integer=True) - # try: - # a = other.array - # dt = a.dtype - # except AttributeError: - # a = other - # dt = type(a) - # if dt == self.array.dtype: - # self._array = self.array % a - # else: - # self._array = (self.array % a).astype(self.array.dtype) - # return self + check_image_consistency(self, other, integer=True) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array % a + else: + self._array = (self.array % a).astype(self.array.dtype) + return self def Image_pow(self, other): @@ -1441,10 +1423,10 @@ def Image_ipow(self, other): raise RuntimeError( "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." ) - # if not isinstance(other, int) and not isinstance(other, float): - # raise TypeError("Can only raise an image to a float or int power!") - # self._array = self.array**other - # return self + if not isinstance(other, int) and not isinstance(other, float): + raise TypeError("Can only raise an image to a float or int power!") + self._array = self.array.at[...].power(other) + return self def Image_neg(self): @@ -1464,16 +1446,13 @@ def Image_and(self, other): def Image_iand(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other, integer=True) - # try: - # a = other.array - # except AttributeError: - # a = other - # self._array = self.array & a - # return self + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array & a + return self def Image_xor(self, other): @@ -1486,16 +1465,13 @@ def Image_xor(self, other): def Image_ixor(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other, integer=True) - # try: - # a = other.array - # except AttributeError: - # a = other - # self._array = self.array ^ a - # return self + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array ^ a + return self def Image_or(self, other): @@ -1508,16 +1484,13 @@ def Image_or(self, other): def Image_ior(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) - # check_image_consistency(self, other, integer=True) - # try: - # a = other.array - # except AttributeError: - # a = other - # self._array = self.array | a - # return self + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array | a + return self # inject the arithmetic operators as methods of the Image class: diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 51325ea3..7964a1c4 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -28,7 +28,7 @@ def addNoiseSNR(self, noise, snr, preserve_flux=False): else: sn_meas = jnp.sqrt(sumsq / noise_var) flux = snr / sn_meas - self *= flux + self._array = self._array.at[...].multiply(flux) self.addNoise(noise) return noise_var diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 958e6bfa..84cec25c 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -164,14 +164,18 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image += obj._drawReal(image, jac, offset, flux_scaling) + image._array = image._array.at[...].add( + obj._drawReal(image, jac, offset, flux_scaling)._array + ) return image def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image += obj._drawKImage(image, jac) + image._array = image._array.at[...].add( + obj._drawKImage(image, jac)._array + ) return image @property diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 44c98f13..0274c726 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -290,7 +290,7 @@ def _compute_fft_with_numpy_jax_galsim(im): out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) - out *= dx * dx + out._array *= dx * dx out.setOrigin(0, -No2) return out From 951e5732bb92c8929bf8c1a48878ae4a97bb918b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Mar 2026 16:32:08 -0500 Subject: [PATCH 03/15] fix: missed an exception --- jax_galsim/image.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index d9259732..11aef541 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1420,9 +1420,6 @@ def Image_pow(self, other): def Image_ipow(self, other): - raise RuntimeError( - "In-place operations (e.g., `+=`) are not allowed on JAX-GalSim images." - ) if not isinstance(other, int) and not isinstance(other, float): raise TypeError("Can only raise an image to a float or int power!") self._array = self.array.at[...].power(other) From a54105d622011af1ca46a14f3833e3ac639099e2 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 23 Mar 2026 16:33:26 -0500 Subject: [PATCH 04/15] fix: do not need these changes Co-authored-by: Matthew R. Becker --- jax_galsim/image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 11aef541..983817cb 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1363,12 +1363,12 @@ def Image_floordiv(self, other): a = other.array except AttributeError: a = other - return Image(self._array // a, bounds=self.bounds, wcs=self.wcs) + return Image(self.array // a, bounds=self.bounds, wcs=self.wcs) def Image_rfloordiv(self, other): check_image_consistency(self, other, integer=True) - return Image(other // self._array, bounds=self.bounds, wcs=self.wcs) + return Image(other // self.array, bounds=self.bounds, wcs=self.wcs) def Image_ifloordiv(self, other): From a00883214ae9e8503d5772c2ba3b381a26e663b6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Mar 2026 16:49:58 -0500 Subject: [PATCH 05/15] fix: need array attribute --- jax_galsim/gsobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 0e09c57a..816f247d 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1407,7 +1407,7 @@ def _draw_phot_while_loop_shoot( im1 = ImageD(bounds=image.bounds) added_flux += sensor.accumulate(photons, im1, orig_center) - image._array = image._array.at[...].add(im1) + image._array = image._array.at[...].add(im1._array) return _DrawPhotReturnTuple( photons, rng, added_flux, image, photon_ops, sensor, resume From 8f8d097ece62787815c8b934e68193628a97aa78 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Mar 2026 16:52:06 -0500 Subject: [PATCH 06/15] fix: make sure to convert to jax array --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 983817cb..762f6fc3 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1100,7 +1100,7 @@ def from_galsim(cls, galsim_image): else None ) im = cls( - array=galsim_image.array, + array=jnp.asarray(galsim_image.array), wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) From 8321a1e250505dd65f81610597a0a1e1a0a44eb9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 05:06:57 -0500 Subject: [PATCH 07/15] feat: add .at method for images --- jax_galsim/core/index.py | 275 +++++++++++++++++++++++++++++++++++++++ jax_galsim/image.py | 185 +++++++++----------------- tests/jax/test_api.py | 46 +++++++ 3 files changed, 384 insertions(+), 122 deletions(-) create mode 100644 jax_galsim/core/index.py diff --git a/jax_galsim/core/index.py b/jax_galsim/core/index.py new file mode 100644 index 00000000..c48bdf66 --- /dev/null +++ b/jax_galsim/core/index.py @@ -0,0 +1,275 @@ +import jax +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + + +@register_pytree_node_class +class ImageIndexer: + def __init__(self, image): + self.image = image + + def tree_flatten(self): + """Flatten the image into a list of values.""" + children = (self.image,) + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + obj = object.__new__(cls) + obj.image = children[0] + return obj + + def __getitem__(self, *args): + from jax_galsim import BoundsI, PositionI + + if len(args) == 1: + args = args[0] + else: + raise TypeError("`image.at[index]` got unknown args: %r" % (args,)) + + if isinstance(args, BoundsI): + return ImageIndex(self.image, args) + elif isinstance(args, PositionI): + return ImageIndex(self.image, args) + elif args is Ellipsis: + return ImageIndex(self.image, self.image.bounds) + elif isinstance(args, tuple): + if ( + isinstance(args[0], slice) + and isinstance(args[1], slice) + and args[0] == slice(None) + and args[1] == slice(None) + ): + return ImageIndex(self.image, self.image.bounds) + else: + return ImageIndex(self.image, PositionI(*args)) + else: + raise TypeError( + "`image.at[index]` only accepts BoundsI, PositionI, " + "tuples, `...`, `:, :`, or `x, y` for the index." + ) + + +@register_pytree_node_class +class ImageIndex: + def __init__(self, image, index): + self.image = image + self.index = index + + def tree_flatten(self): + """Flatten the image into a list of values.""" + children = (self.image, self.index) + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + obj = object.__new__(cls) + obj.image = children[0] + obj.index = children[1] + return obj + + def set(self, value): + import galsim as _galsim + + from jax_galsim import BoundsI, PositionI + + if self.image.isconst: + raise _galsim.GalSimImmutableError( + "Cannot modify an immutable Image", self.image + ) + + if not self.image.bounds.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "Attempt to set value of an undefined image" + ) + + if isinstance(self.index, PositionI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to set position not in bounds of image", + self.index, + self.image.bounds, + ) + self.image._setValue(self.index.x, self.index.y, value) + elif isinstance(self.index, BoundsI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to access subImage not (fully) in image", + self.index, + self.image.bounds, + ) + if ( + hasattr(value, "bounds") + and self.index.numpyShape() != value.bounds.numpyShape() + ): + raise _galsim.GalSimIncompatibleValuesError( + "Trying to copy images that are not the same shape", + self_image=self.image, + rhs=value, + ) + start_inds = ( + self.index.ymin - self.image.ymin, + self.index.xmin - self.image.xmin, + ) + self.image._array = jax.lax.dynamic_update_slice( + self.image.array, + value.array + if hasattr(value, "array") + else jnp.broadcast_to(value, self.index.numpyShape()), + start_inds, + ) + else: + raise TypeError( + "This error should never be raised. " + "image.at[index] only accepts BoundsI or PositionI for the index" + ) + + return self.image + + def get(self): + import galsim as _galsim + + from jax_galsim import BoundsI, PositionI + + if not self.image.bounds.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "Attempt to get value of an undefined image" + ) + + if isinstance(self.index, PositionI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to access position not in bounds of image.", + self.index, + self.image.bounds, + ) + return self.image._getValue(self.index.x, self.index.y) + elif isinstance(self.index, BoundsI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to access subImage not (fully) in image", + self.index, + self.image.bounds, + ) + start_inds = ( + self.index.ymin - self.image.ymin, + self.index.xmin - self.image.xmin, + ) + shape = self.index.numpyShape() + arr = jax.lax.dynamic_slice(self.image.array, start_inds, shape) + return self.image.__class__(arr, bounds=self.index, wcs=self.image.wcs) + else: + raise TypeError( + "This error should never be raised. " + "image.at[index] only accepts BoundsI or PositionI for the index" + ) + + def _op(self, value, func, check_integer=False): + import galsim as _galsim + + from jax_galsim import BoundsI, Image, PositionI + + if check_integer and not self.image.isinteger: + raise _galsim.GalSimValueError( + "Image must have integer values.", self.image + ) + + if check_integer and isinstance(value, Image) and not value.isinteger: + raise _galsim.GalSimValueError( + "Image must have integer values.", self.image + ) + + if self.image.isconst: + raise _galsim.GalSimImmutableError( + "Cannot modify an immutable Image", self.image + ) + + if not self.image.bounds.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "Attempt to modify to an undefined image" + ) + + if isinstance(self.index, PositionI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to modify position not in bounds of image.", + self.index, + self.image.bounds, + ) + subim = self.image._getValue(self.index.x, self.index.y) + self.image._setValue(self.index.x, self.index.y, func(subim, value)) + elif isinstance(self.index, BoundsI): + if not self.image.bounds.includes(self.index): + raise _galsim.GalSimBoundsError( + "Attempt to access subImage not (fully) in image", + self.index, + self.image.bounds, + ) + if ( + hasattr(value, "bounds") + and self.index.numpyShape() != value.bounds.numpyShape() + ): + raise _galsim.GalSimIncompatibleValuesError( + "Trying to copy images that are not the same shape", + self_image=self.image, + rhs=value, + ) + + start_inds = ( + self.index.ymin - self.image.ymin, + self.index.xmin - self.image.xmin, + ) + shape = self.index.numpyShape() + subim = jax.lax.dynamic_slice(self.image.array, start_inds, shape) + + self.image._array = jax.lax.dynamic_update_slice( + self.image.array, + func( + subim, + value.array + if hasattr(value, "array") + else jnp.broadcast_to(value, self.index.numpyShape()), + ), + start_inds, + ) + else: + raise TypeError( + "This error should never be raised. " + "image.at[index] only accepts BoundsI or PositionI for the index" + ) + + return self.image + + def add(self, value): + return self._op(value, lambda x, y: x + y) + + def subtract(self, value): + return self._op(value, lambda x, y: x - y) + + def multiply(self, value): + return self._op(value, lambda x, y: x * y) + + def divide(self, value): + return self._op(value, lambda x, y: x / y) + + def power(self, value): + return self._op(value, lambda x, y: x**y) + + def mod(self, value): + return self._op(value, lambda x, y: x % y, check_integer=True) + + def floor_divide(self, value): + return self._op(value, lambda x, y: x // y, check_integer=True) + + def bitwise_and(self, value): + return self._op(value, lambda x, y: x & y, check_integer=True) + + def bitwise_xor(self, value): + return self._op(value, lambda x, y: x ^ y, check_integer=True) + + def bitwise_or(self, value): + return self._op(value, lambda x, y: x | y, check_integer=True) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 762f6fc3..b05db52a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -4,6 +4,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.bounds import Bounds, BoundsD, BoundsI +from jax_galsim.core.index import ImageIndexer from jax_galsim.core.utils import ensure_hashable, implements from jax_galsim.errors import GalSimImmutableError from jax_galsim.position import PositionI @@ -556,29 +557,11 @@ def __getitem__(self, *args): raise TypeError("image[..] requires either 1 or 2 args") def __setitem__(self, *args): - """Set either a subimage or a single pixel to new values. - - For example,:: - - >>> im[galsim.BoundsI(3,7,3,7)] = im2 - >>> im[galsim.PositionI(5,5)] = 17. - >>> im[5,5] = 17. - """ - if len(args) == 2: - if isinstance(args[0], BoundsI): - self.setSubImage(*args) - elif isinstance(args[0], PositionI): - self.setValue(*args) - elif isinstance(args[0], tuple): - self.setValue(*args) - else: - raise TypeError( - "image[index] only accepts BoundsI or PositionI for the index" - ) - elif len(args) == 3: - return self.setValue(*args) - else: - raise TypeError("image[..] requires either 1 or 2 args") + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) @implements(_galsim.Image.wrap) def wrap(self, bounds, hermitian=False): @@ -1133,6 +1116,13 @@ def FindAdaptiveMom(self, *args, **kwargs): gs_image = self.to_galsim() return gs_image.FindAdaptiveMom(*args_, **kwargs_) + @property + def at(self): + """ + TODO: write docs + """ + return ImageIndexer(self) + @implements( _galsim._Image, @@ -1261,18 +1251,11 @@ def Image_add(self, other): def Image_iadd(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array.at[...].add(a) - else: - self._array = (self.array + a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_sub(self, other): @@ -1289,18 +1272,11 @@ def Image_rsub(self, other): def Image_isub(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array.at[...].subtract(a) - else: - self._array = (self.array - a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_mul(self, other): @@ -1313,18 +1289,11 @@ def Image_mul(self, other): def Image_imul(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array.at[...].multiply(a) - else: - self._array = (self.array * a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_div(self, other): @@ -1341,20 +1310,11 @@ def Image_rdiv(self, other): def Image_idiv(self, other): - check_image_consistency(self, other) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype and not self.isinteger: - # if dtype is an integer type, then numpy doesn't allow true division /= to assign - # back to an integer array. So for integers (or mixed types), don't use /=. - self._array = self.array.at[...].divide(a) - else: - self._array = (self.array / a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_floordiv(self, other): @@ -1372,18 +1332,11 @@ def Image_rfloordiv(self, other): def Image_ifloordiv(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array // a - else: - self._array = (self.array // a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_mod(self, other): @@ -1401,18 +1354,11 @@ def Image_rmod(self, other): def Image_imod(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - dt = a.dtype - except AttributeError: - a = other - dt = type(a) - if dt == self.array.dtype: - self._array = self.array % a - else: - self._array = (self.array % a).astype(self.array.dtype) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_pow(self, other): @@ -1420,10 +1366,11 @@ def Image_pow(self, other): def Image_ipow(self, other): - if not isinstance(other, int) and not isinstance(other, float): - raise TypeError("Can only raise an image to a float or int power!") - self._array = self.array.at[...].power(other) - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_neg(self): @@ -1443,13 +1390,11 @@ def Image_and(self, other): def Image_iand(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array & a - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_xor(self, other): @@ -1462,13 +1407,11 @@ def Image_xor(self, other): def Image_ixor(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array ^ a - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) def Image_or(self, other): @@ -1481,13 +1424,11 @@ def Image_or(self, other): def Image_ior(self, other): - check_image_consistency(self, other, integer=True) - try: - a = other.array - except AttributeError: - a = other - self._array = self.array | a - return self + raise RuntimeError( + "JAX-GalSim images do not support inplace operations via " + "operators like `img +=` or via `img[index] = ...`. " + "Use the `.at` syntax instead." + ) # inject the arithmetic operators as methods of the Image class: diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e79f320c..8b54968f 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -1106,3 +1106,49 @@ def test_api_gsparams(): assert getattr(jgsp, k) == v assert getattr(gsp, k) == v assert getattr(jjgsp, k) == v + + +def test_api_image_at_with_position(): + rng = np.random.default_rng(seed=10) + arr = rng.normal(size=(13, 19)) + img = jax_galsim.ImageD(jnp.asarray(arr), bounds=jax_galsim.BoundsI(3, 21, 7, 19)) + assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] + assert img.at[jax_galsim.PositionI(8, 7)].get() == arr[0, 5] + assert img.at[jax_galsim.PositionI(3, 12)].get() == arr[5, 0] + assert img.at[jax_galsim.PositionI(4, 9)].get() == arr[2, 1] + assert img.at[(4, 9)].get() == arr[2, 1] + assert img.at[4, 9].get() == arr[2, 1] + + img = img.at[4, 9].set(0.1) + assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] + assert img.at[jax_galsim.PositionI(8, 7)].get() == arr[0, 5] + assert img.at[jax_galsim.PositionI(3, 12)].get() == arr[5, 0] + assert img.at[jax_galsim.PositionI(4, 9)].get() == 0.1 + + img = img.at[4, 9].add(0.23) + assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] + assert img.at[jax_galsim.PositionI(8, 7)].get() == arr[0, 5] + assert img.at[jax_galsim.PositionI(3, 12)].get() == arr[5, 0] + assert img.at[jax_galsim.PositionI(4, 9)].get() == 0.33 + + +def test_api_image_at_with_bounds(): + rng = np.random.default_rng(seed=10) + arr = rng.normal(size=(13, 19)) + img = jax_galsim.ImageD(jnp.asarray(arr), bounds=jax_galsim.BoundsI(3, 21, 7, 19)) + bnds = jax_galsim.BoundsI(4, 6, 11, 15) + + assert img.at[...].get().array.shape == (13, 19) + np.testing.assert_array_equal(img.at[...].get().array, arr) + + assert img.at[:, :].get().array.shape == (13, 19) + np.testing.assert_array_equal(img.at[:, :].get().array, arr) + + assert img.at[bnds].get().array.shape == (5, 3) + np.testing.assert_array_equal(img.at[bnds].get().array, arr[4:9, 1:4]) + + img = img.at[bnds].set(0.0) + assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] + + img = img.at[bnds].add(0.10) + assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] From 95fb7aa7e7c8f77c97790bd03ecd6dd200b055b1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 05:15:21 -0500 Subject: [PATCH 08/15] fix: adopt more of the new api --- jax_galsim/convolve.py | 4 +--- jax_galsim/gsobject.py | 2 +- jax_galsim/image.py | 8 ++++---- jax_galsim/sum.py | 8 +++----- tests/jax/test_interpolatedimage_utils.py | 2 +- 5 files changed, 10 insertions(+), 14 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 98717eae..fd5cd44b 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -316,9 +316,7 @@ def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image._array = image._array.at[...].multiply( - obj._drawKImage(image, jac)._array - ) + image = image.at[...].multiply(obj._drawKImage(image, jac)) return image def tree_flatten(self): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 816f247d..063a1eba 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1407,7 +1407,7 @@ def _draw_phot_while_loop_shoot( im1 = ImageD(bounds=image.bounds) added_flux += sensor.accumulate(photons, im1, orig_center) - image._array = image._array.at[...].add(im1._array) + image = image.at[...].add(im1) return _DrawPhotReturnTuple( photons, rng, added_flux, image, photon_ops, sensor, resume diff --git a/jax_galsim/image.py b/jax_galsim/image.py index b05db52a..13b45799 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -683,7 +683,7 @@ def calculate_fft(self): else: # Then we pad out with zeros ximage = Image(full_bounds, dtype=self.dtype, init_value=0) - ximage[self.bounds] = self[self.bounds] + ximage = ximage.at[self.bounds].set(self[self.bounds]) dx = self.scale # dk = 2pi / (N dk) @@ -696,7 +696,7 @@ def calculate_fft(self): jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 ) - out._array = out._array.at[...].multiply(dx * dx) + out = out.at[...].multiply(dx * dx) out.setOrigin(0, -No2) return out @@ -737,7 +737,7 @@ def calculate_inverse_fft(self): posx_bounds = BoundsI( 0, self.bounds.xmax, self.bounds.ymin, self.bounds.ymax ) - kimage[posx_bounds] = self[posx_bounds] + kimage = kimage.at[posx_bounds].set(self[posx_bounds]) kimage = kimage.wrap(target_bounds, hermitian="x") dk = self.scale @@ -752,7 +752,7 @@ def calculate_inverse_fft(self): ) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) - out._array = out._array.at[:].multiply((dk * No2 / jnp.pi) ** 2) + out = out.at[...].multiply((dk * No2 / jnp.pi) ** 2) out.setCenter(0, 0) return out diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 84cec25c..ee8a7a13 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -164,8 +164,8 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image._array = image._array.at[...].add( - obj._drawReal(image, jac, offset, flux_scaling)._array + image = image.at[...].add( + obj._drawReal(image, jac, offset, flux_scaling) ) return image @@ -173,9 +173,7 @@ def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image._array = image._array.at[...].add( - obj._drawKImage(image, jac)._array - ) + image = image.at[...].add(obj._drawKImage(image, jac)) return image @property diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 0274c726..9605f90f 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -282,7 +282,7 @@ def _compute_fft_with_numpy_jax_galsim(im): else: # Then we pad out with zeros ximage = Image(full_bounds, dtype=im.dtype, init_value=0) - ximage[im.bounds] = im[im.bounds] + ximage = ximage.at[im.bounds].set(im[im.bounds]) dx = im.scale # dk = 2pi / (N dk) From 4e02345ac67799ab87ceef6b91de22b0a49f6a7a Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 05:23:29 -0500 Subject: [PATCH 09/15] fix: put these back again --- jax_galsim/image.py | 149 +++++++++++++++++++++++++++++--------------- 1 file changed, 99 insertions(+), 50 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 13b45799..28a625a8 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1251,11 +1251,18 @@ def Image_add(self, other): def Image_iadd(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].add(a) + else: + self._array = (self.array + a).astype(self.array.dtype) + return self def Image_sub(self, other): @@ -1272,11 +1279,18 @@ def Image_rsub(self, other): def Image_isub(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].subtract(a) + else: + self._array = (self.array - a).astype(self.array.dtype) + return self def Image_mul(self, other): @@ -1289,11 +1303,18 @@ def Image_mul(self, other): def Image_imul(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array.at[...].multiply(a) + else: + self._array = (self.array * a).astype(self.array.dtype) + return self def Image_div(self, other): @@ -1310,11 +1331,20 @@ def Image_rdiv(self, other): def Image_idiv(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype and not self.isinteger: + # if dtype is an integer type, then numpy doesn't allow true division /= to assign + # back to an integer array. So for integers (or mixed types), don't use /=. + self._array = self.array.at[...].divide(a) + else: + self._array = (self.array / a).astype(self.array.dtype) + return self def Image_floordiv(self, other): @@ -1332,11 +1362,18 @@ def Image_rfloordiv(self, other): def Image_ifloordiv(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other, integer=True) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array // a + else: + self._array = (self.array // a).astype(self.array.dtype) + return self def Image_mod(self, other): @@ -1354,11 +1391,18 @@ def Image_rmod(self, other): def Image_imod(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other, integer=True) + try: + a = other.array + dt = a.dtype + except AttributeError: + a = other + dt = type(a) + if dt == self.array.dtype: + self._array = self.array % a + else: + self._array = (self.array % a).astype(self.array.dtype) + return self def Image_pow(self, other): @@ -1366,11 +1410,10 @@ def Image_pow(self, other): def Image_ipow(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + if not isinstance(other, int) and not isinstance(other, float): + raise TypeError("Can only raise an image to a float or int power!") + self._array = self.array.at[...].power(other) + return self def Image_neg(self): @@ -1390,11 +1433,13 @@ def Image_and(self, other): def Image_iand(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array & a + return self def Image_xor(self, other): @@ -1407,11 +1452,13 @@ def Image_xor(self, other): def Image_ixor(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array ^ a + return self def Image_or(self, other): @@ -1424,11 +1471,13 @@ def Image_or(self, other): def Image_ior(self, other): - raise RuntimeError( - "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " - "Use the `.at` syntax instead." - ) + check_image_consistency(self, other, integer=True) + try: + a = other.array + except AttributeError: + a = other + self._array = self.array | a + return self # inject the arithmetic operators as methods of the Image class: From c9f92d1a00d9d6cfdbdfb39416ca394c87b5a131 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 24 Mar 2026 05:37:36 -0500 Subject: [PATCH 10/15] fix: put these changes back as well Co-authored-by: Matthew R. Becker --- jax_galsim/convolve.py | 2 +- jax_galsim/gsobject.py | 2 +- jax_galsim/image.py | 6 +++--- jax_galsim/noise.py | 2 +- jax_galsim/sum.py | 6 ++---- tests/jax/test_interpolatedimage_utils.py | 2 +- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index fd5cd44b..6961ad39 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -316,7 +316,7 @@ def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image = image.at[...].multiply(obj._drawKImage(image, jac)) + image *= obj._drawKImage(image, jac) return image def tree_flatten(self): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 063a1eba..b687e175 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1407,7 +1407,7 @@ def _draw_phot_while_loop_shoot( im1 = ImageD(bounds=image.bounds) added_flux += sensor.accumulate(photons, im1, orig_center) - image = image.at[...].add(im1) + image += im1 return _DrawPhotReturnTuple( photons, rng, added_flux, image, photon_ops, sensor, resume diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 28a625a8..7ae7700c 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -559,7 +559,7 @@ def __getitem__(self, *args): def __setitem__(self, *args): raise RuntimeError( "JAX-GalSim images do not support inplace operations via " - "operators like `img +=` or via `img[index] = ...`. " + "`img[index] = ...`. " "Use the `.at` syntax instead." ) @@ -696,7 +696,7 @@ def calculate_fft(self): jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 ) - out = out.at[...].multiply(dx * dx) + out *= dx * dx out.setOrigin(0, -No2) return out @@ -752,7 +752,7 @@ def calculate_inverse_fft(self): ) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) - out = out.at[...].multiply((dk * No2 / jnp.pi) ** 2) + out *= (dk * No2 / jnp.pi) ** 2 out.setCenter(0, 0) return out diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 7964a1c4..51325ea3 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -28,7 +28,7 @@ def addNoiseSNR(self, noise, snr, preserve_flux=False): else: sn_meas = jnp.sqrt(sumsq / noise_var) flux = snr / sn_meas - self._array = self._array.at[...].multiply(flux) + self *= flux self.addNoise(noise) return noise_var diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index ee8a7a13..958e6bfa 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -164,16 +164,14 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image = image.at[...].add( - obj._drawReal(image, jac, offset, flux_scaling) - ) + image += obj._drawReal(image, jac, offset, flux_scaling) return image def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image = image.at[...].add(obj._drawKImage(image, jac)) + image += obj._drawKImage(image, jac) return image @property diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 9605f90f..7b84734e 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -290,7 +290,7 @@ def _compute_fft_with_numpy_jax_galsim(im): out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) - out._array *= dx * dx + out *= dx * dx out.setOrigin(0, -No2) return out From 2667010af8f92a5ceb086680f0a201b5f3f547d2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 05:44:16 -0500 Subject: [PATCH 11/15] test: added test for setitem usage --- tests/jax/test_api.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 8b54968f..6e15c8bf 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -1152,3 +1152,17 @@ def test_api_image_at_with_bounds(): img = img.at[bnds].add(0.10) assert img.at[jax_galsim.PositionI(3, 7)].get() == arr[0, 0] + + +def test_api_image_raise_on_setitem(): + rng = np.random.default_rng(seed=10) + arr = rng.normal(size=(13, 19)) + img = jax_galsim.ImageD(jnp.asarray(arr), bounds=jax_galsim.BoundsI(3, 21, 7, 19)) + + with pytest.raises(RuntimeError) as e: + img[4, 11] = 10 + assert "JAX-GalSim images do not support inplace operations" in str(e.value) + + with pytest.raises(RuntimeError) as e: + img[4, 11] += 10 + assert "JAX-GalSim images do not support inplace operations" in str(e.value) From e3979910fb41362577783f9a5d2a18be77006e05 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 05:52:25 -0500 Subject: [PATCH 12/15] fix: use `.at` syntax for interp images --- jax_galsim/interpolatedimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 2baeb6e6..c46e256b 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -665,7 +665,7 @@ def _xim(self): xim.wcs = PixelScale(1.0) # Now place the given image in the center of the padding image: - xim[self._image.bounds] = self._image + xim = xim.at[self._image.bounds].set(self._image) return xim From c152147bcd03997dbca08bd0e005980b47b9826e Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 06:08:05 -0500 Subject: [PATCH 13/15] fix: more api fies --- tests/jax/test_api.py | 1 + tests/jax/test_image_wrapping.py | 12 ++++++------ tests/jax/test_interpolatedimage_utils.py | 6 +++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 6e15c8bf..85508bce 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -347,6 +347,7 @@ def _reg_fun(p): "tree_unflatten", "from_galsim", "to_galsim", + "at", ]: # this deprecated method doesn't have consistent doc strings in galsim if ( diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 63136a4b..118d26a7 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -39,11 +39,11 @@ def test_image_wrapping_expand_contract(): ) # val = 2*(i-j)**2 + 3j*(i+j) - im[i, j] = val + im = im.at[i, j].set(val) if j >= 0: - im2[i, j] = val + im2 = im2.at[i, j].set(val) if i >= 0: - im3[i, j] = val + im3 = im3.at[i, j].set(val) # print("im = ",im.array) @@ -112,11 +112,11 @@ def test_image_wrapping_autodiff(func, K, L): ) # val = 2*(i-j)**2 + 3j*(i+j) - im[i, j] = val + im = im.at[i, j].set(val) if j >= 0: - im2[i, j] = val + im2 = im2.at[i, j].set(val) if i >= 0: - im3[i, j] = val + im3 = im3.at[i, j].set(val) ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 7b84734e..2d97bb6a 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -94,12 +94,12 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): for y in range(kimherm.bounds.ymin, kimherm.bounds.ymax + 1): for x in range(kimherm.bounds.xmin, kimherm.bounds.xmax + 1): if x >= 0: - kimherm[x, y] = kim[x, y] + kimherm = kimherm.at[x, y].set(kim[x, y]) else: if y == minherm: - kimherm[x, y] = kim[-x, y].conj() + kimherm = kimherm.at[x, y].set(kim[-x, y].conj()) else: - kimherm[x, y] = kim[-x, -y].conj() + kimherm = kimherm.at[x, y].set(kim[-x, -y].conj()) for x in range(nherm): for y in range(nherm): np.testing.assert_allclose( From fdbef71990b4c486dafe9571d4eb6c3e84df4579 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 06:56:23 -0500 Subject: [PATCH 14/15] tests: get tests to pass with .at syntax --- jax_galsim/core/index.py | 22 ++++++++++++++-------- jax_galsim/image.py | 2 ++ tests/Coord | 2 +- tests/GalSim | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/jax_galsim/core/index.py b/jax_galsim/core/index.py index c48bdf66..038c4810 100644 --- a/jax_galsim/core/index.py +++ b/jax_galsim/core/index.py @@ -117,9 +117,12 @@ def set(self, value): ) self.image._array = jax.lax.dynamic_update_slice( self.image.array, - value.array - if hasattr(value, "array") - else jnp.broadcast_to(value, self.index.numpyShape()), + jnp.astype( + value.array + if hasattr(value, "array") + else jnp.broadcast_to(value, self.index.numpyShape()), + self.image.array.dtype, + ), start_inds, ) else: @@ -228,11 +231,14 @@ def _op(self, value, func, check_integer=False): self.image._array = jax.lax.dynamic_update_slice( self.image.array, - func( - subim, - value.array - if hasattr(value, "array") - else jnp.broadcast_to(value, self.index.numpyShape()), + jnp.astype( + func( + subim, + value.array + if hasattr(value, "array") + else jnp.broadcast_to(value, self.index.numpyShape()), + ), + self.image.array.dtype, ), start_inds, ) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 7ae7700c..e873d051 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -557,6 +557,8 @@ def __getitem__(self, *args): raise TypeError("image[..] requires either 1 or 2 args") def __setitem__(self, *args): + if self.isconst: + raise _galsim.GalSimImmutableError("Cannot modify an immutable Image", self) raise RuntimeError( "JAX-GalSim images do not support inplace operations via " "`img[index] = ...`. " diff --git a/tests/Coord b/tests/Coord index f8120841..ac093efd 160000 --- a/tests/Coord +++ b/tests/Coord @@ -1 +1 @@ -Subproject commit f812084154af7cc44b6d63eca49b2fa515e87fe2 +Subproject commit ac093efd5d33018162b78e0b32f372c661422a6e diff --git a/tests/GalSim b/tests/GalSim index 04918b11..6f40ae96 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 04918b118926eafc01ec9403b8afed29fb918d51 +Subproject commit 6f40ae9696e298acd8fa5d6aa24b89510956ee3b From b1bd2028b965cb4745a046ce38994ef071e3990e Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Mar 2026 09:30:44 -0500 Subject: [PATCH 15/15] fix: more inplace ops --- jax_galsim/image.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index e873d051..e849a4a2 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1263,7 +1263,9 @@ def Image_iadd(self, other): if dt == self.array.dtype: self._array = self.array.at[...].add(a) else: - self._array = (self.array + a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array + a).astype(self.array.dtype) + ) return self @@ -1291,7 +1293,9 @@ def Image_isub(self, other): if dt == self.array.dtype: self._array = self.array.at[...].subtract(a) else: - self._array = (self.array - a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array - a).astype(self.array.dtype) + ) return self @@ -1315,7 +1319,9 @@ def Image_imul(self, other): if dt == self.array.dtype: self._array = self.array.at[...].multiply(a) else: - self._array = (self.array * a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array * a).astype(self.array.dtype) + ) return self @@ -1345,7 +1351,9 @@ def Image_idiv(self, other): # back to an integer array. So for integers (or mixed types), don't use /=. self._array = self.array.at[...].divide(a) else: - self._array = (self.array / a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array / a).astype(self.array.dtype) + ) return self @@ -1372,9 +1380,11 @@ def Image_ifloordiv(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array // a + self._array = self._array.at[...].set(self.array // a) else: - self._array = (self.array // a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array // a).astype(self.array.dtype) + ) return self @@ -1401,9 +1411,11 @@ def Image_imod(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array % a + self._array = self.array.at[...].set(self.array % a) else: - self._array = (self.array % a).astype(self.array.dtype) + self._array = self.array.at[...].set( + (self.array % a).astype(self.array.dtype) + ) return self @@ -1440,7 +1452,7 @@ def Image_iand(self, other): a = other.array except AttributeError: a = other - self._array = self.array & a + self._array = self.array.at[...].set(self.array & a) return self @@ -1459,7 +1471,7 @@ def Image_ixor(self, other): a = other.array except AttributeError: a = other - self._array = self.array ^ a + self._array = self.array.at[...].set(self.array ^ a) return self @@ -1478,7 +1490,7 @@ def Image_ior(self, other): a = other.array except AttributeError: a = other - self._array = self.array | a + self._array = self.array.at[...].set(self.array | a) return self