@@ -32,7 +32,7 @@ def __init__(self, weights_0=None, method='Full', initializer='identity',
3232 implements a quasi Newton method
3333 """
3434 if method not in ['Full' , 'Diag' , 'FixDiag' ]:
35- raise (ValueError (' method {} not avaliable' . format ( method ) ))
35+ raise (ValueError (f" method { method } not avaliable" ))
3636
3737 self .weights_0 = weights_0
3838 self .method = method
@@ -68,7 +68,6 @@ def predict(self, S):
6868
6969 return np .asarray (self .predict_proba (S ))
7070
71-
7271 def fit (self , X , y , * args , ** kwargs ):
7372
7473 self .__setup ()
@@ -99,8 +98,8 @@ def fit(self, X, y, *args, **kwargs):
9998
10099 self .weights_0_ = self ._get_initial_weights (self .initializer )
101100
102- if (self .optimizer == 'newton'
103- or (self .optimizer == 'auto' and k <= 36 )):
101+ if (self .optimizer == 'newton' or
102+ (self .optimizer == 'auto' and k <= 36 )):
104103 weights = _newton_update (self .weights_0_ , X_ , XXT , target , k ,
105104 self .method , reg_lambda = self .reg_lambda ,
106105 reg_mu = self .reg_mu , ref_row = self .ref_row ,
@@ -141,6 +140,8 @@ def _get_initial_weights(self, ref_row, initializer='identity'):
141140
142141 k = len (self .classes )
143142
143+ weights_0 = self .weights_0_
144+
144145 if self .weights_0_ is None :
145146 if initializer == 'identity' :
146147 weights_0 = _get_identity_weights (k , ref_row , self .method )
@@ -185,33 +186,33 @@ def _objective(params, *args):
185186
186187
187188def _get_weights (params , k , ref_row , method ):
188- ''' Reshapes the given params (weights) into the full matrix including 0
189- '''
189+ '''Reshapes the given params (weights) into the full matrix including 0
190+ '''
190191
191- if method in ['Full' , None ]:
192- raw_weights = params .reshape (- 1 , k + 1 )
193- # weights = np.zeros([k, k+1])
194- # weights[:-1, :] = params.reshape(-1, k + 1)
192+ if method in ['Full' , None ]:
193+ raw_weights = params .reshape (- 1 , k + 1 )
194+ # weights = np.zeros([k, k+1])
195+ # weights[:-1, :] = params.reshape(-1, k + 1)
195196
196- elif method == 'Diag' :
197- raw_weights = np .hstack ([np .diag (params [:k ]),
198- params [k :].reshape (- 1 , 1 )])
199- # weights[:, :-1][np.diag_indices(k)] = params[:]
197+ elif method == 'Diag' :
198+ raw_weights = np .hstack ([np .diag (params [:k ]),
199+ params [k :].reshape (- 1 , 1 )])
200+ # weights[:, :-1][np.diag_indices(k)] = params[:]
200201
201- elif method == 'FixDiag' :
202- raw_weights = np .hstack ([np .eye (k ) * params [0 ], np .zeros ((k , 1 ))])
203- # weights[np.dgag_indices(k - 1)] = params[0]
204- # weights[np.diag_indices(k)] = params[0]
205- else :
206- raise (ValueError ("Unknown calibration method {}" . format ( method ) ))
202+ elif method == 'FixDiag' :
203+ raw_weights = np .hstack ([np .eye (k ) * params [0 ], np .zeros ((k , 1 ))])
204+ # weights[np.dgag_indices(k - 1)] = params[0]
205+ # weights[np.diag_indices(k)] = params[0]
206+ else :
207+ raise (ValueError (f "Unknown calibration method { method } " ))
207208
208- if ref_row :
209- weights = raw_weights - np .repeat (
210- raw_weights [- 1 , :].reshape (1 , - 1 ), k , axis = 0 )
211- else :
212- weights = raw_weights
209+ if ref_row :
210+ weights = raw_weights - np .repeat (
211+ raw_weights [- 1 , :].reshape (1 , - 1 ), k , axis = 0 )
212+ else :
213+ weights = raw_weights
213214
214- return weights
215+ return weights
215216
216217
217218def _get_identity_weights (n_classes , ref_row , method ):
0 commit comments