88# License: BSD 3 clause
99
1010
11+ def deviation_from_normality (matrices_tensor ):
12+ """
13+ Calculates the total deviation from normality for a set of matrices.
14+
15+ Metric: Sum of || A @ A.T - A.T @ A ||_F^2 for all matrices in the tensor.
16+
17+ Args:
18+ matrices_tensor (Tensor): Dimension (k, k, n)
19+
20+ Returns:
21+ float: The total deviation error.
22+ """
23+ n_matrices = matrices_tensor .shape [2 ]
24+ total_deviation = 0.0
25+
26+ for i in range (n_matrices ):
27+ A = matrices_tensor [:, :, i ]
28+ # Calculate A Transpose
29+ A_t = tl .transpose (A )
30+
31+ # Calculate Commutator: (A * A^T) - (A^T * A)
32+ # Note: Depending on the backend, tl.dot might act differently on 2D matrices.
33+ # Using explicit matrix multiplication is safer if available,
34+ # but here is the standard dot approach for 2D slices:
35+ commutator = tl .dot (A , A_t ) - tl .dot (A_t , A )
36+
37+ # Add the squared norm of the commutator
38+ total_deviation += tl .norm (commutator ) ** 2
39+
40+ return total_deviation
41+
42+
1143def joint_matrix_diagonalization (
1244 matrices_tensor ,
1345 max_n_iter : int = 50 ,
14- threshold : float = 1e-10 ,
46+ threshold : float = 1e-8 ,
1547 verbose : bool = False ,
1648):
1749 """
@@ -46,7 +78,7 @@ def joint_matrix_diagonalization(
4678 Args:
4779 X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
4880 max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
49- threshold (float, optional): Threshold for decrease in error indicating convergence. Defaults to 1e-10 .
81+ threshold (float, optional): Threshold for decrease in deviation indicating convergence. Defaults to 1e-8 .
5082 verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
5183
5284 Raises:
@@ -67,15 +99,11 @@ def joint_matrix_diagonalization(
6799 assert tl .ndim (matrices_tensor ) == 3 , "Input must be a 3D tensor"
68100 assert matrix_dimension == matrices_tensor .shape [1 ], "All matrices must be square."
69101
70- # Initial error calculation
71- # Transpose is because np.tril operates on the last two dimensions
72- error = (
73- tl .norm (matrices_tensor ) ** 2.0
74- - tl .norm (tl .diagonal (matrices_tensor , axis1 = 1 , axis2 = 2 )) ** 2.0
75- )
102+ # Deviation from normality is strictly decreasing
103+ deviation = deviation_from_normality (matrices_tensor )
76104
77105 if verbose :
78- print (f"Sweep # 0: e = { error :.3e} " )
106+ print (f"Sweep # 0: dev = { deviation :.3e} " )
79107
80108 # Initialize transformation matrix as identity
81109 transform_P = tl .eye (matrix_dimension )
@@ -205,18 +233,14 @@ def joint_matrix_diagonalization(
205233 pvec * tl .sin (theta_k ) + transform_P [:, q ] * tl .cos (theta_k ),
206234 )
207235
208- # Error computation, check if loop needed...
209- old_error = error
210- error = (
211- tl .norm (matrices_tensor ) ** 2.0
212- - tl .norm (tl .diagonal (matrices_tensor , axis1 = 1 , axis2 = 2 )) ** 2.0
213- )
236+ # Update deviation from normality
237+ old_deviation = deviation
238+ deviation = deviation_from_normality (matrices_tensor )
214239
215240 if verbose :
216- print (f"Sweep # { k + 1 } : e = { error :.3e} " )
241+ print (f"Sweep # { k + 1 } : dev = { deviation :.3e} " )
217242
218- # TODO: Strangely the error increases on the first iteration
219- if old_error - error < threshold and k > 2 :
243+ if (old_deviation - deviation < threshold ) or (deviation < threshold ):
220244 break
221245
222246 return matrices_tensor , transform_P
0 commit comments