-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_utils.py
More file actions
108 lines (82 loc) · 3.19 KB
/
model_utils.py
File metadata and controls
108 lines (82 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Map from lowercase activation name to its torch.nn module class.
# FeedForward resolves its `act` constructor argument against this table
# (lookup is done on `act.lower()`, so keys must stay lowercase).
activation_functions = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "leaky_relu": nn.LeakyReLU,
    "gelu": nn.GELU,
}
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a 2D sine-cosine positional embedding table.

    Args:
        embed_dim: Total embedding dimension. Must be divisible by 4
            (half per axis, each half split into sin/cos).
        grid_size: Either a single int (square grid) or a
            (height, width) pair.

    Returns:
        np.ndarray of shape (grid_size[0] * grid_size[1], embed_dim),
        rows ordered row-major over the grid.
    """
    # Generalization: accept a bare int for square grids; the original
    # tuple interface is unchanged.
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    # meshgrid(w, h) makes w vary fastest, matching a row-major flatten.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Combine per-axis 1D sin-cos embeddings into a 2D table.

    Half of ``embed_dim`` encodes grid[0] and the other half grid[1];
    the two halves are concatenated along the feature axis.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    axis_embeds = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in grid[:2]]
    return np.concatenate(axis_embeds, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Classic transformer sin-cos embedding for a 1D position array.

    Returns an array of shape (pos.size, embed_dim): the first half of
    the features are sines and the second half cosines, evaluated at
    geometrically spaced frequencies 1 / 10000^(i / (embed_dim/2)).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Inverse frequencies spaced geometrically from 1 down to ~1/10000.
    inv_freq = 1.0 / (10000 ** (np.arange(half, dtype=np.float64) / half))
    # Outer product: one row of phase angles per position.
    angles = pos.reshape(-1)[:, None] * inv_freq[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head self-attention (no masking, no dropout).

    Projects the input once to packed Q/K/V, computes scaled
    dot-product attention per head, and mixes heads back with a final
    linear layer.

    Args:
        dim: Model/embedding dimension of the input and output.
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        # Fix: a bare `assert` is stripped under `python -O`; validate
        # constructor arguments with an explicit exception instead.
        if self.head_dim * num_heads != dim:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        # Single fused projection for query, key and value (dim -> 3*dim).
        self.qkv = nn.Linear(dim, dim * 3)
        self.fc_out = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, dim).

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        batch_size, seq_length, dim = x.size()
        qkv = self.qkv(x)
        # (B, S, 3*D) -> (B, S, 3, H, Dh) -> (3, B, H, S, Dh) so that
        # q/k/v unpack cleanly along the leading axis.
        qkv = qkv.view(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention: softmax(Q K^T / sqrt(Dh)) V.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Merge heads: (B, H, S, Dh) -> (B, S, H*Dh).
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, dim)
        out = self.fc_out(attn_output)
        return out
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: dim -> mlp_dim -> dim.

    The hidden activation is looked up by name in the module-level
    ``activation_functions`` table; unrecognized names silently fall
    back to ReLU.
    """

    def __init__(self, dim, mlp_dim, act="relu"):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        # Case-insensitive lookup; defaults to ReLU when `act` is unknown.
        act_cls = activation_functions.get(act.lower(), nn.ReLU)
        self.act = act_cls()

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder block.

    Applies LayerNorm -> self-attention -> residual add, followed by
    LayerNorm -> feed-forward -> residual add.
    """

    def __init__(self, dim, num_heads, mlp_dim, act="relu"):
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(dim, num_heads)
        self.feed_forward = FeedForward(dim, mlp_dim, act)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Attention sub-block: pre-normalize, then add skip connection.
        x = x + self.self_attn(self.ln1(x))
        # MLP sub-block, same pre-norm + residual pattern.
        return x + self.feed_forward(self.ln2(x))