Skip to content

Commit 5a9de01

Browse files
authored
Merge pull request #34 from Leona-LYT/main
fix bug problems about coef_ and add more tests
2 parents 7e733db + 686a3cb commit 5a9de01

2 files changed

Lines changed: 322 additions & 8 deletions

File tree

rehline/_sklearn_mixin.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def _fit_multiclass(self, X_aug, y, sample_weight=None):
358358
class_pairs = []
359359
for cls_i, cls_j in combinations(self.classes_, 2):
360360
mask = np.isin(y, [cls_i, cls_j])
361-
y_pm = np.where(y[mask] == cls_i, 1, -1).astype(np.float64)
361+
y_pm = np.where(y[mask] == cls_j, 1, -1).astype(np.float64)
362362
sw_sub = sample_weight[mask] if sample_weight is not None else None
363363
tasks.append((X_aug[mask], y_pm, sw_sub))
364364
class_pairs.append((cls_i, cls_j))
@@ -455,12 +455,12 @@ def predict(self, X):
455455

456456
# discrete vote: score > 0 favors cls_j, score <= 0 favors cls_i
457457
pred = (scores[:, k] > 0).astype(int)
458-
votes[:, i] += pred
459-
votes[:, j] += 1 - pred
458+
votes[:, j] += pred
459+
votes[:, i] += 1 - pred
460460

461461
# continuous confidence: score > 0 means cls_j is more confident
462-
sum_of_confidences[:, i] += scores[:, k]
463-
sum_of_confidences[:, j] -= scores[:, k]
462+
sum_of_confidences[:, j] += scores[:, k]
463+
sum_of_confidences[:, i] -= scores[:, k]
464464

465465
# Monotonically transform to (-1/3, 1/3) to break ties without
466466
# overriding any decision made by a difference of >= 1 vote

tests/_test_multiclass.py

Lines changed: 317 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,309 @@ def test_decision_function_shapes():
378378
print("\n✓ All decision_function shape tests passed!")
379379

380380

381+
def test_ovo_coef_sign_convention():
    """
    Test 5: OvO coefficient sign convention (regression test for the sign bug).

    The previous bug assigned cls_i -> +1 and cls_j -> -1 in each OvO subproblem,
    which is opposite to sklearn's LabelEncoder convention (cls_i -> -1, cls_j -> +1)
    since combinations() always yields sorted pairs (cls_i < cls_j).
    This caused every subproblem's coef_ to be fully negated (diff ≈ 2 * |β|).

    This test directly checks the sign direction of each OvO subproblem's coef_
    via dot product, rather than relying solely on accuracy, so the bug cannot
    silently reappear.
    """
    print("\n" + "="*60)
    print("Test 5: OvO Coefficient Sign Convention")
    print("="*60)

    np.random.seed(0)
    n_samples = 2000
    n_features = 6
    n_classes = 3
    C = 1.0

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=4,
        n_redundant=1,
        n_classes=n_classes,
        class_sep=2.0,
        random_state=0
    )
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # sklearn OvO reference
    base_clf = LinearSVC(C=C, loss='hinge', fit_intercept=True,
                         max_iter=1000000, tol=1e-5, random_state=0)
    clf_skl = OneVsOneClassifier(base_clf)
    clf_skl.fit(X, y)

    # rehline OvO
    clf_reh = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf_reh.fit(X, y)

    # Was computed but unused: turn it into explicit sanity checks on the
    # number of pairwise subproblems produced by both implementations.
    n_estimators = n_classes * (n_classes - 1) // 2
    assert len(clf_skl.estimators_) == n_estimators, \
        f"sklearn OvO produced {len(clf_skl.estimators_)} estimators, expected {n_estimators}"
    assert clf_reh.coef_.shape[0] == n_estimators, \
        f"rehline OvO produced {clf_reh.coef_.shape[0]} subproblems, expected {n_estimators}"

    print(f"\n{'Estimator':^12} {'dot(skl,reh)':^16} {'||skl||':^12} {'||reh||':^12} {'sign OK':^10}")
    print("-" * 65)

    all_positive_dot = True
    for k, est in enumerate(clf_skl.estimators_):
        coef_skl = est.coef_.flatten()
        coef_reh = clf_reh.coef_[k]
        dot = np.dot(coef_skl, coef_reh)
        norm_skl = np.linalg.norm(coef_skl)
        norm_reh = np.linalg.norm(coef_reh)
        # If signs agree the dot product is positive; if reversed it is negative.
        sign_ok = dot > 0
        all_positive_dot = all_positive_dot and sign_ok
        print(f"{k:^12d} {dot:^16.4f} {norm_skl:^12.4f} {norm_reh:^12.4f} {'✓' if sign_ok else '❌':^10}")

    assert all_positive_dot, \
        "OvO coef_ sign convention mismatch: at least one subproblem has reversed sign. " \
        "This is the sign-convention bug (cls_i/cls_j label encoding mismatch)."

    print("\n✓ OvO sign convention test passed!")
450+
451+
452+
def test_ovo_predict_consistency():
    """
    Test 6: OvO predict / decision_function consistency.

    Verifies that predict() produces exactly the same result as manually
    reconstructing predictions from decision_function() using the voting logic,
    ensuring the sign convention in fit and predict are perfectly aligned.
    """
    print("\n" + "="*60)
    print("Test 6: OvO predict / decision_function Consistency")
    print("="*60)

    np.random.seed(7)
    n_samples = 1500

    X, y = make_classification(
        n_samples=n_samples,
        n_features=5,
        n_informative=4,
        n_redundant=0,
        n_classes=4,
        class_sep=1.5,
        random_state=7
    )
    X = StandardScaler().fit_transform(X)

    clf = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=1.0, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf.fit(X, y)

    # Predictions straight from the public API.
    y_pred = clf.predict(X)

    # Rebuild the same predictions by hand from decision_function
    # (mirrors the voting logic inside predict).
    dec_scores = clf.decision_function(X)
    num_classes = len(clf.classes_)
    vote_tally = np.zeros((n_samples, num_classes))
    conf_total = np.zeros((n_samples, num_classes))
    for k, (_, _, cls_i, cls_j) in enumerate(clf.estimators_):
        idx_i = np.where(clf.classes_ == cls_i)[0][0]
        idx_j = np.where(clf.classes_ == cls_j)[0][0]
        j_wins = (dec_scores[:, k] > 0).astype(int)
        vote_tally[:, idx_j] += j_wins
        vote_tally[:, idx_i] += 1 - j_wins
        conf_total[:, idx_j] += dec_scores[:, k]
        conf_total[:, idx_i] -= dec_scores[:, k]
    # Monotone squash of confidences into (-1/3, 1/3): breaks ties
    # without ever overturning a difference of >= 1 vote.
    tie_break = conf_total / (3 * (np.abs(conf_total) + 1))
    y_rebuilt = clf.classes_[np.argmax(vote_tally + tie_break, axis=1)]

    n_disagree = np.sum(y_pred != y_rebuilt)
    print(f"Disagreements between predict() and manual reconstruction: {n_disagree}")

    assert n_disagree == 0, \
        f"predict() and decision_function() are inconsistent: {n_disagree} samples disagree. " \
        "This indicates a mismatch between the sign convention in fit and predict."

    print("✓ OvO predict / decision_function consistency test passed!")
515+
516+
517+
def test_ovo_fit_intercept_false():
    """
    Test 7: OvO with fit_intercept=False — correct coef_ shape and accuracy.

    Ensures that disabling the intercept still produces the correct coef_ shape,
    sets intercept_ to all zeros, and matches sklearn's solution.
    """
    print("\n" + "="*60)
    print("Test 7: OvO with fit_intercept=False")
    print("="*60)

    np.random.seed(13)
    num_samples, num_features, num_classes = 2000, 6, 3
    reg_C = 1.0

    X, y = make_classification(
        n_samples=num_samples,
        n_features=num_features,
        n_informative=4,
        n_redundant=1,
        n_classes=num_classes,
        class_sep=2.0,
        random_state=13
    )
    X = StandardScaler().fit_transform(X)

    # Reference: sklearn OvO without an intercept term.
    svc_ref = LinearSVC(C=reg_C, loss='hinge', fit_intercept=False,
                        max_iter=1000000, tol=1e-5, random_state=13)
    ovo_ref = OneVsOneClassifier(svc_ref)
    ovo_ref.fit(X, y)

    # rehline OvO without an intercept term.
    model = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=reg_C, multi_class='ovo',
        fit_intercept=False, max_iter=1000000, tol=1e-5, verbose=0
    )
    model.fit(X, y)

    num_estimators = num_classes * (num_classes - 1) // 2

    # Shape checks
    assert model.coef_.shape == (num_estimators, num_features), \
        f"Expected coef_ shape ({num_estimators}, {num_features}), got {model.coef_.shape}"
    assert np.all(model.intercept_ == 0.0), \
        "intercept_ should be all zeros when fit_intercept=False"

    # Accuracy checks against the sklearn reference, one subproblem at a time.
    worst = 0.0
    for k, est in enumerate(ovo_ref.estimators_):
        gap = np.max(np.abs(est.coef_.flatten() - model.coef_[k]))
        worst = max(worst, gap)
        print(f"Estimator {k}: max coef diff = {gap:.6e}")

    print(f"Overall max coef diff: {worst:.6e}")
    assert worst <= 1e-3, \
        f"fit_intercept=False OvO coef_ diff {worst:.6e} > 1e-3"

    print("✓ OvO fit_intercept=False test passed!")
579+
580+
581+
def test_multiclass_invalid_multi_class():
    """
    Test 8: Invalid multi_class parameter should raise ValueError.

    Ensures that passing an unrecognised multi_class value causes fit() to raise
    a clear ValueError rather than silently failing or producing wrong results.
    """
    print("\n" + "="*60)
    print("Test 8: Invalid multi_class Parameter")
    print("="*60)

    np.random.seed(42)
    X = np.random.randn(200, 4)
    y = np.random.randint(0, 3, 200)

    model = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=1.0, multi_class='invalid_option'
    )

    # EAFP: fit() itself is expected to reject the bogus option.
    got_value_error = False
    try:
        model.fit(X, y)
    except ValueError as err:
        got_value_error = True
        print(f"ValueError raised as expected: {err}")

    assert got_value_error, "Expected ValueError for invalid multi_class parameter, but none was raised."
    print("✓ Invalid multi_class parameter test passed!")
609+
610+
611+
def test_ovo_more_classes():
    """
    Test 9: OvO correctness with 5 classes (10 subproblems).

    Verifies that the number of subproblems, coef_ shape, and coefficient
    accuracy are all correct when the number of classes grows, guarding against
    errors in the combinatorial subproblem construction logic.

    Returns
    -------
    tuple of (float, float, float)
        (sklearn accuracy, rehline accuracy, max coefficient difference),
        consumed by the summary table in __main__.
    """
    print("\n" + "="*60)
    print("Test 9: OvO with 5 Classes (10 subproblems)")
    print("="*60)

    np.random.seed(99)
    n_samples = 3000
    n_features = 8
    n_classes = 5
    C = 1.0
    n_estimators = n_classes * (n_classes - 1) // 2  # 10

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=6,
        n_redundant=1,
        n_classes=n_classes,
        class_sep=1.5,
        random_state=99
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=99, stratify=y
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # sklearn reference
    base_clf = LinearSVC(C=C, loss='hinge', fit_intercept=True,
                         max_iter=1000000, tol=1e-5, random_state=99)
    clf_skl = OneVsOneClassifier(base_clf)
    clf_skl.fit(X_train, y_train)
    acc_skl = accuracy_score(y_test, clf_skl.predict(X_test))

    # rehline
    clf_reh = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf_reh.fit(X_train, y_train)
    acc_reh = accuracy_score(y_test, clf_reh.predict(X_test))

    # Shape checks
    assert clf_reh.coef_.shape == (n_estimators, n_features), \
        f"Expected coef_ shape ({n_estimators}, {n_features}), got {clf_reh.coef_.shape}"
    assert clf_reh.intercept_.shape == (n_estimators,), \
        f"Expected intercept_ shape ({n_estimators},), got {clf_reh.intercept_.shape}"
    assert len(clf_reh.estimators_) == n_estimators, \
        f"Expected {n_estimators} estimators, got {len(clf_reh.estimators_)}"

    # Accuracy checks
    max_diff = 0.0
    for k, est in enumerate(clf_skl.estimators_):
        diff = np.max(np.abs(est.coef_.flatten() - clf_reh.coef_[k]))
        max_diff = max(max_diff, diff)
    print(f"5-class OvO: {n_estimators} subproblems, max coef diff = {max_diff:.6e}")
    print(f"Accuracy: sklearn={acc_skl:.4f}, rehline={acc_reh:.4f}")

    assert max_diff <= 1e-3, \
        f"5-class OvO coef_ diff {max_diff:.6e} > 1e-3"

    print("✓ OvO 5-class test passed!")
    return acc_skl, acc_reh, max_diff
682+
683+
381684
if __name__ == "__main__":
382685
print("\n" + "="*70)
383686
print("MULTI-CLASS CLASSIFICATION TEST SUITE")
@@ -391,14 +694,25 @@ def test_decision_function_shapes():
391694
acc_skl_ovr, acc_reh_ovr, diff_ovr = test_multiclass_ovr_vs_sklearn()
392695
acc_skl_ovo, acc_reh_ovo, diff_ovo = test_multiclass_ovo_vs_sklearn()
393696
test_decision_function_shapes()
394-
697+
test_ovo_coef_sign_convention()
698+
test_ovo_predict_consistency()
699+
test_ovo_fit_intercept_false()
700+
test_multiclass_invalid_multi_class()
701+
acc_skl_ovo5, acc_reh_ovo5, diff_ovo5 = test_ovo_more_classes()
702+
395703
print("\n" + "="*70)
396704
print("TEST SUMMARY")
397705
print("="*70)
398706
print(f"{'Test':^30} {'sklearn acc':^12} {'rehline acc':^12} {'max coef diff':^15}")
399707
print("-" * 70)
400708
print(f"{'Binary Classification':^30} {acc_skl_bin:^12.4f} {acc_reh_bin:^12.4f} {diff_bin:^15.2e}")
401709
print(f"{'OvR Multi-class':^30} {acc_skl_ovr:^12.4f} {acc_reh_ovr:^12.4f} {diff_ovr:^15.2e}")
402-
print(f"{'OvO Multi-class':^30} {acc_skl_ovo:^12.4f} {acc_reh_ovo:^12.4f} {diff_ovo:^15.2e}")
710+
print(f"{'OvO Multi-class (3cls)':^30} {acc_skl_ovo:^12.4f} {acc_reh_ovo:^12.4f} {diff_ovo:^15.2e}")
711+
print(f"{'OvO Multi-class (5cls)':^30} {acc_skl_ovo5:^12.4f} {acc_reh_ovo5:^12.4f} {diff_ovo5:^15.2e}")
712+
print(f"{'Decision Func Shapes':^30} {'—':^12} {'—':^12} {'—':^15}")
713+
print(f"{'OvO Sign Convention':^30} {'—':^12} {'—':^12} {'—':^15}")
714+
print(f"{'OvO Predict Consistency':^30} {'—':^12} {'—':^12} {'—':^15}")
715+
print(f"{'OvO No Intercept':^30} {'—':^12} {'—':^12} {'—':^15}")
716+
print(f"{'Invalid multi_class':^30} {'—':^12} {'—':^12} {'—':^15}")
403717
print("="*70)
404-
print("\n✓ All tests passed successfully!")
718+
print("\n✓ All 9 tests passed successfully!")

0 commit comments

Comments
 (0)