Skip to content

Commit 7c17a8e

Browse files
authored
Merge pull request #1357 from gzrp/dev-postgresql
Add the implementations for the PoswiseFeedForwardNet
2 parents b65e244 + 7909461 commit 7c17a8e

1 file changed

Lines changed: 26 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -481,3 +481,29 @@ def forward(self, query, key, value, attn_mask):
481481
# attn: [batch_size, n_heads, len_q, len_k] value: [batch_size, n_heads, len_v(=len_k), d_v]
482482
context = matmul4d(attn, value)
483483
return context, attn
484+
485+
486+
class PoswiseFeedForwardNet(layer.Layer):
    """Position-wise feed-forward sub-layer with a residual connection.

    The forward pass applies ``Linear3D -> ReLU -> Linear3D``, adds the
    original input back onto that result, and passes the sum through
    ``LayerNorm``.

    Args:
        d_model: size of the last input dimension (default 512). It is the
            first argument of the first ``Linear3D``, the second argument of
            the second ``Linear3D``, and the argument of ``LayerNorm``.
        dim_feedforward: second argument of the first ``Linear3D`` and first
            argument of the second one (default 2048) -- presumably the
            hidden width; confirm against ``Linear3D``.
        bias: forwarded unchanged as ``bias=`` to both ``Linear3D`` layers
            (default False).
    """

    def __init__(self, d_model=512, dim_feedforward=2048, bias=False):
        super().__init__()

        # Keep the constructor arguments available on the instance.
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.bias = bias

        # Sub-layers, created in the order in which forward() uses them.
        self.linear1 = Linear3D(d_model, dim_feedforward, bias=bias)
        self.relu = layer.ReLU()
        self.linear2 = Linear3D(dim_feedforward, d_model, bias=bias)
        self.add = layer.Add()
        self.norm = LayerNorm(d_model)

    def forward(self, inputs):
        """Apply the feed-forward transform, residual add and normalisation.

        Args:
            inputs: tensor of shape [batch_size, seq_len, d_model].

        Returns:
            The output of ``self.norm`` applied to the sum of the
            feed-forward result and ``inputs``.
        """
        hidden = self.relu(self.linear1(inputs))
        # projected: [batch_size, seq_len, d_model]
        projected = self.linear2(hidden)
        # Residual connection: the untouched input is added back before the
        # normalisation step.
        return self.norm(self.add(projected, inputs))

0 commit comments

Comments
 (0)