File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -28,20 +28,13 @@ def __init__(
2828 self .use_bias = use_bias
2929 self .dropout = dropout
3030
31- linear = nn .Linear (in_dim , out_dim , bias = use_bias )
32- weight = linear .weight
33- bias = linear .bias if self .use_bias else None
34- self .register_parameter ("weight" , nn .Parameter (weight ))
35- self .register_parameter (
36- "bias" , nn .Parameter (bias ) if bias is not None else None
37- )
38-
31+ self .linear = nn .Linear (in_dim , out_dim , bias = use_bias )
3932 self .dropout = nn .Dropout (p = dropout ) if dropout > 0.0 else nn .Identity ()
4033 self .lora_a = nn .Linear (in_features = in_dim , out_features = rank , bias = False )
4134 self .lora_b = nn .Linear (in_features = rank , out_features = out_dim , bias = False )
4235
4336 def forward (self , x : torch .Tensor ) -> torch .Tensor :
44- out = torch . nn . functional . linear (x , self . weight , self . bias )
37+ out = self . linear (x )
4538 lora_out = self .lora_a (self .dropout (x ))
4639 lora_out = (self .alpha / self .rank ) * self .lora_b (lora_out )
4740
You can’t perform that action at this time.
0 commit comments