Skip to content

Commit d23c0d2

Browse files
committed
Update lora def
1 parent d0820e1 commit d23c0d2

1 file changed

Lines changed: 2 additions & 9 deletions

File tree

examples/models/llama/lora.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,13 @@ def __init__(
2828
self.use_bias = use_bias
2929
self.dropout = dropout
3030

31-
linear = nn.Linear(in_dim, out_dim, bias=use_bias)
32-
weight = linear.weight
33-
bias = linear.bias if self.use_bias else None
34-
self.register_parameter("weight", nn.Parameter(weight))
35-
self.register_parameter(
36-
"bias", nn.Parameter(bias) if bias is not None else None
37-
)
38-
31+
self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
3932
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
4033
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
4134
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
4235

4336
def forward(self, x: torch.Tensor) -> torch.Tensor:
44-
out = torch.nn.functional.linear(x, self.weight, self.bias)
37+
out = self.linear(x)
4538
lora_out = self.lora_a(self.dropout(x))
4639
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
4740

0 commit comments

Comments (0)