Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions tractor/lsqr_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def _optimize_forcedphot_core(
for um, dd in zip(umods, derivs):
if um is None:
continue
dd.append((um * scale, tim))
dd.append((um , scale, tim)) # When you do um * scale,
# Tractor’s Patch.__mul__ allocates a new numpy array, increasing memory
#logverb('forced phot: derivs', Time() - t0)
if sky:
# Sky derivatives are part of the image derivatives, so go
Expand Down Expand Up @@ -369,7 +370,11 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
imgoffs = {}
nextrow = 0
for param in allderivs:
for deriv, img in param:
for item in param:
if len(item) == 3:
deriv, deriv_scale, img = item
else:
deriv, img = item
if img in imgoffs:
continue
imgoffs[img] = nextrow
Expand All @@ -389,7 +394,14 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
RR = []
VV = []
WW = []
for (deriv, img) in param:
for item in param:

if len(item) == 3:
deriv, deriv_scale, img = item
else:
deriv, img = item
deriv_scale = 1.0
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add an inline comment here reminding that deriv_scale is unity so the vals calculation just multiplies by a constant


inverrs = img.getInvError()
(H, W) = img.shape
row0 = imgoffs[img]
Expand All @@ -409,7 +421,7 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
continue
rows = row0 + pix[nz]
#print('Adding derivative', deriv.getName(), 'for image', img.name)
vals = dimg.flat[nz]
vals = dimg.flat[nz] * deriv_scale
w = inverrs[deriv.getSlice(img)].flat[nz]
assert(vals.shape == w.shape)
# if not scales_only:
Expand Down Expand Up @@ -669,7 +681,6 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
return X, 1./np.array(var)

return X

# def getParameterScales(self):
# print(self.getName()+': Finding derivs...')
# allderivs = self.getDerivs()
Expand Down
27 changes: 25 additions & 2 deletions tractor/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,28 @@
from astrometry.util.ttime import Time
from tractor.engine import logverb, OptResult, logmsg


import numba

@numba.njit(fastmath=True, nogil=True)
def fast_add_to(mod_img, patch_data, counts, x0, y0):
img_h, img_w = mod_img.shape
patch_h, patch_w = patch_data.shape

# 1. Equivalent to get_overlapping_region for Y
y_start = max(0, -y0)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring clarifying what the input parameters represent. I understand it is trying to find the overlap between mod_img and patch_data, but it's unclear which is the "reference" vs "desired" size.

y_end = min(patch_h, img_h - y0)

# 2. Equivalent to get_overlapping_region for X
x_start = max(0, -x0)
x_end = min(patch_w, img_w - x0)

# 3. Add to image (avoids empty list checks, if start >= end, loop just doesn't run)
for y in range(y_start, y_end):
for x in range(x_start, x_end):
# mod_img[y0 + y, x0 + x] is the 'out' coordinate
# patch_data[y, x] is the 'in' coordinate
mod_img[y0 + y, x0 + x] += patch_data[y, x] * counts

class Optimizer(object):
def optimize(self, tractor, alphas=None, damp=0, priors=True,
scale_columns=True, shared_params=True, variance=False,
Expand Down Expand Up @@ -221,6 +242,7 @@ def _get_umodels(self, tractor, srcs, imgs, minsb, rois, **kwargs):
umodels.append(umods)
return umodels, umodtosource, umodsforsource


def _optimize_forcedphot_core(
self, tractor,
result, umodels, imlist, mod0, scales, skyderivs, minFlux,
Expand Down Expand Up @@ -531,7 +553,8 @@ def _getims(self, fluxes, imgs, umodels, mod0, scales, sky, minFlux, rois):
assert(np.isfinite(counts))
assert(np.all(np.isfinite(um.patch)))
# print 'Adding umod', um, 'with counts', counts, 'to mod', mod.shape
(um * counts).addTo(mod)
# (um * counts).addTo(mod)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can either remove this (since it's commented out), or keep it in as reference for which original function we have substituted/replaced. Which do you think makes more sense?

fast_add_to(mod, um.patch, counts, um.x0, um.y0)

ie = img.getInvError()
im = img.getImage()
Expand Down