Skip to content

Commit ef6830d

Browse files
author
Aaron Meyer
committed
Fix convergence check
1 parent a63e469 commit ef6830d

1 file changed

Lines changed: 42 additions & 18 deletions

File tree

tensorly/utils/jointdiag.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,42 @@
88
# License: BSD 3 clause
99

1010

11+
def deviation_from_normality(matrices_tensor):
12+
"""
13+
Calculates the total deviation from normality for a set of matrices.
14+
15+
Metric: Sum of || A @ A.T - A.T @ A ||_F^2 for all matrices in the tensor.
16+
17+
Args:
18+
matrices_tensor (Tensor): Dimension (k, k, n)
19+
20+
Returns:
21+
float: The total deviation error.
22+
"""
23+
n_matrices = matrices_tensor.shape[2]
24+
total_deviation = 0.0
25+
26+
for i in range(n_matrices):
27+
A = matrices_tensor[:, :, i]
28+
# Calculate A Transpose
29+
A_t = tl.transpose(A)
30+
31+
# Calculate Commutator: (A * A^T) - (A^T * A)
32+
# Note: Depending on the backend, tl.dot might act differently on 2D matrices.
33+
# Using explicit matrix multiplication is safer if available,
34+
# but here is the standard dot approach for 2D slices:
35+
commutator = tl.dot(A, A_t) - tl.dot(A_t, A)
36+
37+
# Add the squared norm of the commutator
38+
total_deviation += tl.norm(commutator) ** 2
39+
40+
return total_deviation
41+
42+
1143
def joint_matrix_diagonalization(
1244
matrices_tensor,
1345
max_n_iter: int = 50,
14-
threshold: float = 1e-10,
46+
threshold: float = 1e-8,
1547
verbose: bool = False,
1648
):
1749
"""
@@ -46,7 +78,7 @@ def joint_matrix_diagonalization(
4678
Args:
4779
X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
4880
max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
49-
threshold (float, optional): Threshold for decrease in error indicating convergence. Defaults to 1e-10.
81+
threshold (float, optional): Threshold for decrease in deviation indicating convergence. Defaults to 1e-8.
5082
verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
5183
5284
Raises:
@@ -67,15 +99,11 @@ def joint_matrix_diagonalization(
6799
assert tl.ndim(matrices_tensor) == 3, "Input must be a 3D tensor"
68100
assert matrix_dimension == matrices_tensor.shape[1], "All matrices must be square."
69101

70-
# Initial error calculation
71-
# Transpose is because np.tril operates on the last two dimensions
72-
error = (
73-
tl.norm(matrices_tensor) ** 2.0
74-
- tl.norm(tl.diagonal(matrices_tensor, axis1=1, axis2=2)) ** 2.0
75-
)
102+
# Deviation from normality is strictly decreasing
103+
deviation = deviation_from_normality(matrices_tensor)
76104

77105
if verbose:
78-
print(f"Sweep # 0: e = {error:.3e}")
106+
print(f"Sweep # 0: dev = {deviation:.3e}")
79107

80108
# Initialize transformation matrix as identity
81109
transform_P = tl.eye(matrix_dimension)
@@ -205,18 +233,14 @@ def joint_matrix_diagonalization(
205233
pvec * tl.sin(theta_k) + transform_P[:, q] * tl.cos(theta_k),
206234
)
207235

208-
# Error computation, check if loop needed...
209-
old_error = error
210-
error = (
211-
tl.norm(matrices_tensor) ** 2.0
212-
- tl.norm(tl.diagonal(matrices_tensor, axis1=1, axis2=2)) ** 2.0
213-
)
236+
# Update deviation from normality
237+
old_deviation = deviation
238+
deviation = deviation_from_normality(matrices_tensor)
214239

215240
if verbose:
216-
print(f"Sweep # {k + 1}: e = {error:.3e}")
241+
print(f"Sweep # {k + 1}: dev = {deviation:.3e}")
217242

218-
# TODO: Strangely the error increases on the first iteration
219-
if old_error - error < threshold and k > 2:
243+
if (old_deviation - deviation < threshold) or (deviation < threshold):
220244
break
221245

222246
return matrices_tensor, transform_P

0 commit comments

Comments
 (0)