diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py index a382153..b19c031 100644 --- a/src/diffpy/snmf/snmf_class.py +++ b/src/diffpy/snmf/snmf_class.py @@ -10,23 +10,24 @@ class SNMFOptimizer: """An implementation of stretched NMF (sNMF), including sparse stretched NMF. - Instantiating the SNMFOptimizer class prepares initial guesses and sets up the - optimization. It can then be run using fit(). + Instantiating the SNMFOptimizer class prepares initial guesses and + sets up the optimization. It can then be run using fit(). The results matrices can be accessed as instance attributes of the class (components_, weights_, and stretch_). For more information on sNMF, please reference: - Gu, R., Rakita, Y., Lan, L. et al. Stretched non-negative matrix factorization. - npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5 + Gu, R., Rakita, Y., Lan, L. et al. + Stretched non-negative matrix factorization. + npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5 Attributes ---------- source_matrix : ndarray - The original, unmodified data to be decomposed and later, compared against. - Shape is (length_of_signal, number_of_signals). + The original, unmodified data to be decomposed and later, + compared against. Shape is (length_of_signal, number_of_signals). stretch_ : ndarray - The best guess (or while running, the current guess) for the stretching - factor matrix. + The best guess (or while running, the current guess) for the + stretching factor matrix. components_ : ndarray The best guess (or while running, the current guess) for the matrix of component intensities. @@ -34,34 +35,38 @@ class SNMFOptimizer: The best guess (or while running, the current guess) for the matrix of component weights. rho : float - The stretching factor that influences the decomposition. Zero corresponds to no - stretching present. Relatively insensitive and typically adjusted in powers of 10. 
+ The stretching factor that influences the decomposition. Zero + corresponds to no stretching present. Relatively insensitive and + typically adjusted in powers of 10. eta : float - The sparsity factor that influences the decomposition. Should be set to zero for - non-sparse data such as PDF. Can be used to improve results for sparse data such - as XRD, but due to instability, should be used only after first selecting the - best value for rho. Suggested adjustment is by powers of 2. + The sparsity factor that influences the decomposition. Should be set + to zero for non-sparse data such as PDF. Can be used to improve + results for sparse data such as XRD, but due to instability, should + be used only after first selecting the best value for rho. Suggested + adjustment is by powers of 2. max_iter : int - The maximum number of times to update each of stretch, components, and weights before stopping - the optimization. + The maximum number of times to update each of stretch, components, + and weights before stopping the optimization. min_iter : int - The minimum number of times to update each of stretch, components, and weights before terminating - the optimization due to low/no improvement. + The minimum number of times to update each of stretch, components, and + weights before terminating the optimization due to low/no improvement. tol : float - The convergence threshold. This is the minimum fractional improvement in the - objective function to allow without terminating the optimization. + The convergence threshold. This is the minimum fractional improvement + in the objective function to allow without terminating the + optimization. n_components : int - The number of components to extract from source_matrix. Must be provided when and only when - init_weights is not provided. + The number of components to extract from source_matrix. Must be + provided when and only when init_weights is not provided. 
random_state : int - The seed for the initial guesses at the matrices (stretch, components, and weights) created by - the decomposition. + The seed for the initial guesses at the matrices (stretch, components, + and weights) created by the decomposition. num_updates : int - The total number of times that any of (stretch, components, and weights) have had their values changed. - If not terminated by other means, this value is used to stop when reaching max_iter. + The total number of times that any of (stretch, components, + and weights) have had their values changed. If not terminated by + other means, this value is used to stop when reaching max_iter. objective_difference : float - The change in the objective function value since the last update. A negative value - means that the result improved. + The change in the objective function value since the last update. + A negative value means that the result improved. """ def __init__( @@ -82,31 +87,37 @@ def __init__( Parameters ---------- source_matrix : ndarray - The data to be decomposed. Shape is (length_of_signal, number_of_conditions). - init_weights : ndarray Optional Default = rng.beta(a=2.0, b=2.0, size=(n_components, n_signals)) - The initial guesses for the component weights at each stretching condition. - Shape is (number_of_components, number_of_signals) Must provide exactly one - of this or n_components. - init_components : ndarray Optional Default = rng.random((self.signal_length, self.n_components)) + The data to be decomposed. Shape is (length_of_signal, + number_of_conditions). + init_weights : ndarray Optional Default = rng.beta(a=2.0, b=2.0, + size=(n_components, n_signals)) + The initial guesses for the component weights at each stretching + condition. Shape is (number_of_components, number_of_signals) + Must provide exactly one of this or n_components. 
+ init_components : ndarray Optional Default = rng.random( + (self.signal_length, self.n_components)) The initial guesses for the intensities of each component per - row/sample/angle. Shape is (length_of_signal, number_of_components). - init_stretch : ndarray Optional Default = np.ones((self.n_components, self.n_signals)) + self._rng.normal( + row/sample/angle. Shape is (length_of_signal, number_of_components) + init_stretch : ndarray Optional Default = np.ones((self.n_components, + self.n_signals)) + self._rng.normal( 0, 1e-3, size=(self.n_components, self.n_signals) - The initial guesses for the stretching factor for each component, at each - condition (for each signal). Shape is (number_of_components, number_of_signals). + The initial guesses for the stretching factor for each component, + at each condition (for each signal). + Shape is (number_of_components, number_of_signals). max_iter : int Optional Default = 500 - The maximum number of times to update each of A, X, and Y before stopping - the optimization. + The maximum number of times to update each of A, X, and Y before + stopping the optimization. tol : float Optional Default = 5e-7 - The convergence threshold. This is the minimum fractional improvement in the - objective function to allow without terminating the optimization. Note that - a minimum of 20 updates are run before this parameter is checked. + The convergence threshold. This is the minimum fractional + improvement in the objective function to allow without terminating + the optimization. Note that a minimum of 20 updates are run before + this parameter is checked. n_components : int Optional Default = None - The number of components to extract from source_matrix. Must be provided when and only when - Y0 is not provided. + The number of components to extract from source_matrix. Must be + provided when and only when Y0 is not provided. 
random_state : int Optional Default = None - The seed for the initial guesses at the matrices (A, X, and Y) created by - the decomposition. + The seed for the initial guesses at the matrices (A, X, and Y) + created by the decomposition. show_plots : boolean Optional Default = False Enables plotting at each step of the decomposition. """ @@ -126,8 +137,8 @@ def __init__( n_components is not None and init_weights is not None ): raise ValueError( - "Conflicting source for n_components. Must provide either init_weights or n_components " - "directly, but not both." + "Conflicting source for n_components. Must provide either " + "init_weights or n_components directly, but not both." ) # Initialize weights and determine number of components @@ -167,7 +178,7 @@ def __init__( self.init_weights = self.weights_.copy() self.init_stretch = self.stretch_.copy() - # Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals) + # Second-order spline: Tridiagonal (-2 on diag, 1 on sub/superdiags) self._spline_smooth_operator = 0.25 * diags( [1, -2, 1], offsets=[0, 1, 2], @@ -181,17 +192,20 @@ def fit(self, rho=0, eta=0, reset=True): Parameters ---------- rho : float Optional Default = 0 - The stretching factor that influences the decomposition. Zero corresponds to no - stretching present. Relatively insensitive and typically adjusted in powers of 10. + The stretching factor that influences the decomposition. Zero + corresponds to no stretching present. Relatively insensitive and + typically adjusted in powers of 10. eta : int Optional Default = 0 - The sparsity factor that influences the decomposition. Should be set to zero for - non-sparse data such as PDF. Can be used to improve results for sparse data such - as XRD, but due to instability, should be used only after first selecting the + The sparsity factor that influences the decomposition. Should be + set to zero for non-sparse data such as PDF. 
Can be used to + improve results for sparse data such as XRD, but due to + instability, should be used only after first selecting the best value for rho. Suggested adjustment is by powers of 2. reset : boolean Optional Default = True - Whether to return to the initial set of components_, weights_, and stretch_ before - running the optimization. When set to False, sequential calls to fit() will use the - output of the previous fit() as their input. + Whether to return to the initial set of components_, weights_, + and stretch_ before running the optimization. When set to False, + sequential calls to fit() will use the output of the previous + fit() as their input. """ if reset: @@ -232,7 +246,8 @@ def fit(self, rho=0, eta=0, reset=True): ) # Square root penalty print( f"Start, Objective function: {self.objective_function:.5e}" - f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}" + f", Obj - reg/sparse: " + f"{self.objective_function - regularization_term - sparsity_term:.5e}" ) # Main optimization loop @@ -253,11 +268,13 @@ def fit(self, rho=0, eta=0, reset=True): ) # Square root penalty print( f"Obj fun: {self.objective_function:.5e}, " - f"Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}, " + f"Obj - reg/sparse: " + f"{self.objective_function - regularization_term - sparsity_term:.5e}" + f", " f"Iter: {self.outiter}" ) - # Convergence check: Stop if diffun is small and at least min_iter iterations have passed + # Conv. check: Stop if diffun small and min_iter iterations passed print( "Checking if ", self.objective_difference, @@ -286,7 +303,7 @@ def normalize_results(self): stretch_row_max = np.max(self.stretch_, axis=1, keepdims=True) self.stretch_ = self.stretch_ / stretch_row_max - # effectively just re-running with component updates only vs normalized weights/stretch + # just re-running with component updates only vs norm. 
weights/stretch self._grad_components = np.zeros_like( self.components_ ) # Gradient of X (zeros for now) @@ -306,7 +323,8 @@ def normalize_results(self): self.residuals = self.get_residual_matrix() self.objective_function = self.get_objective_function() print( - f"Objective function after normalize_components: {self.objective_function:.5e}" + f"Objective function after normalize_components: " + f"{self.objective_function:.5e}" ) self._objective_history.append(self.objective_function) self.objective_difference = ( @@ -333,7 +351,8 @@ def outer_loop(self): self.residuals = self.get_residual_matrix() self.objective_function = self.get_objective_function() print( - f"Objective function after update_components: {self.objective_function:.5e}" + f"Objective function after update_components: " + f"{self.objective_function:.5e}" ) self._objective_history.append(self.objective_function) self.objective_difference = ( @@ -358,7 +377,8 @@ def outer_loop(self): self.residuals = self.get_residual_matrix() self.objective_function = self.get_objective_function() print( - f"Objective function after update_weights: {self.objective_function:.5e}" + f"Objective function after update_weights: " + f"{self.objective_function:.5e}" ) self._objective_history.append(self.objective_function) self.objective_difference = ( @@ -392,7 +412,8 @@ def outer_loop(self): self.residuals = self.get_residual_matrix() self.objective_function = self.get_objective_function() print( - f"Objective function after update_stretch: {self.objective_function:.5e}" + f"Objective function after update_stretch: " + f"{self.objective_function:.5e}" ) self._objective_history.append(self.objective_function) self.objective_difference = ( @@ -486,11 +507,14 @@ def compute_stretched_components( Outputs ------- - stretched_components : array, shape (signal_len, n_components * n_signals) + stretched_components : array, + shape (signal_len, n_components * n_signals) Interpolated and weighted components. 
- d_stretched_components : array, shape (signal_len, n_components * n_signals) + d_stretched_components : array, shape + (signal_len, n_components * n_signals) First derivatives with respect to stretch. - dd_stretched_components : array, shape (signal_len, n_components * n_signals) + dd_stretched_components : array, shape + (signal_len, n_components * n_signals) Second derivatives with respect to stretch. """ @@ -512,18 +536,18 @@ def compute_stretched_components( stretch = np.clip(stretch, eps, None) stretch_inv = 1.0 / stretch - # Apply stretching to the original sample indices, represented as a "time-stretch" + # Apply stretching to original sample indices as a "time-stretch" t = ( np.arange(signal_len, dtype=float)[:, None, None] * stretch_inv[None, :, :] ) # has shape (signal_len, n_components, n_signals) - # For each stretched coordinate, find its prior integer (original) index and their difference + # For each str coordinate, find prior integer index and their diff i0 = np.floor(t).astype(np.int64) # prior original index alpha = t - i0.astype(float) # fractional distance between left/right - # Clip indices to valid range (0, signal_len - 1) to maintain original size + # Clip to valid range (0, signal_len - 1) to maintain original size max_idx = signal_len - 1 i0 = np.clip(i0, 0, max_idx) i1 = np.clip(i0 + 1, 0, max_idx) @@ -531,7 +555,7 @@ def compute_stretched_components( # Gather sample values comps_3d = components[ :, :, None - ] # expand components by a dimension for broadcasting across n_signals + ] # expand components by a dimension for broadcast across n_signals c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values @@ -551,7 +575,7 @@ def compute_stretched_components( d_weighted = d_unweighted * weights[None, :, :] dd_weighted = dd_unweighted * weights[None, :, :] - # Flatten back to expected shape (signal_len, n_components * n_signals) + # Flatten to expected shape 
(signal_len, n_components * n_signals) return ( interp_weighted.reshape(signal_len, n_components * n_signals), d_weighted.reshape(signal_len, n_components * n_signals), @@ -641,7 +665,8 @@ def solve_quadratic_program(self, t, m): Parameters: - t: (N, k) ndarray - - source_matrix_col: (N,) column of source_matrix for the corresponding m + - source_matrix_col: (N,) column of source_matrix for the + corresponding m Returns: - y: (k,) optimal solution @@ -768,7 +793,8 @@ def update_weights(self): for signal in range(self.n_signals): # Stretch factors for this signal across components: this_stretch = self.stretch_[:, signal] - # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp] + # Build stretched_comps[:, k] by interpolating component + # at frac. pos. index / this_stretch[comp] stretched_comps = np.empty( (self.signal_length, self.n_components), dtype=self.components_.dtype, @@ -827,10 +853,13 @@ def update_stretch(self): """Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).""" - # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input) + # Flatten stretch for compatibility with the optimizer + # (since SciPy expects 1D input) stretch_flat_initial = self.stretch_.flatten() # Define the optimization function + cache = {"x": None, "fun": None, "grad": None} + def objective(stretch_vec): stretch_matrix = stretch_vec.reshape( self.stretch_.shape @@ -839,6 +868,26 @@ def objective(stretch_vec): gra = gra.flatten() return fun, gra + def fun_only(stretch_vec): + if cache["x"] is None or not np.array_equal( + stretch_vec, cache["x"] + ): + fun, grad = objective(stretch_vec) + cache["x"] = stretch_vec.copy() + cache["fun"] = fun + cache["grad"] = grad + return cache["fun"] + + def jac_only(stretch_vec): + if cache["x"] is None or not np.array_equal( + stretch_vec, cache["x"] + ): + fun, grad = objective(stretch_vec) + cache["x"] = stretch_vec.copy() + cache["fun"] 
= fun + cache["grad"] = grad + return cache["grad"] + # Optimization constraints: lower bound 0.1, no upper bound bounds = [ (0.1, None) @@ -846,10 +895,10 @@ def objective(stretch_vec): # Solve optimization problem (equivalent to fmincon) result = minimize( - fun=lambda stretch_vec: objective(stretch_vec)[0], + fun=fun_only, x0=stretch_flat_initial, method="trust-constr", # Substitute for 'trust-region-reflective' - jac=lambda stretch_vec: objective(stretch_vec)[1], # Gradient + jac=jac_only, # Gradient bounds=bounds, ) @@ -870,13 +919,15 @@ def _compute_objective_function( residuals : ndarray Difference between reconstructed and observed data. stretch : ndarray - Stretching factors :math:`A` applied to each component across samples. + Stretching factors :math:`A` applied to each component + across samples. rho : float Regularization parameter enforcing smooth variation in :math:`A`. eta : float Sparsity-promoting regularization parameter applied to :math:`X`. spline_smooth_operator : ndarray - Linear operator :math:`L` penalizing non-smooth changes in :math:`A`. + Linear operator :math:`L` penalizing non-smooth changes + in :math:`A`. Returns ------- @@ -894,13 +945,14 @@ def _compute_objective_function( + \tfrac{\rho}{2} \lVert L A \rVert_F^2 + \eta \sum_{i,j} \sqrt{X_{ij}} \,, - where :math:`Z` is the data matrix, :math:`Y` contains the non-negative - weights, :math:`S(A)` denotes the spline-interpolated stretching operator, - and :math:`\lVert \cdot \rVert_F` is the Frobenius norm. + where :math:`Z` is the data matrix, :math:`Y` contains + the non-negative weights, :math:`S(A)` denotes the spline-interpolated + stretching operator, and :math:`\lVert \cdot \rVert_F` is the + Frobenius norm. Special cases ------------- - - :math:`\rho = 0` — no smoothness regularization on stretching factors. + - :math:`\rho = 0` — no smoothness regularization on stretch factors - :math:`\eta = 0` — no sparsity promotion on components. 
- :math:`\rho = \eta = 0` — reduces to the classical NMF least-squares objective :math:`\tfrac{1}{2} \lVert Z - YX \rVert_F^2`.