This repository was archived by the owner on Aug 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathUtils.cs
More file actions
211 lines (186 loc) · 9.54 KB
/
Utils.cs
File metadata and controls
211 lines (186 loc) · 9.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp.Modules;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
public static class Utils
{
/// <summary>
/// Applies rotary position embeddings to <paramref name="input"/>: every two consecutive values of the last
/// dimension are treated as one complex number and multiplied by the matching entry of
/// <paramref name="freqsComplex"/>, which rotates it.
/// </summary>
/// <param name="input">Tensor indexed as (B, Seq_Len, H, Head_Dim); its last dimension is split into (real, imaginary) pairs.</param>
/// <param name="freqsComplex">Complex rotation factors, expected as (Seq_Len, Head_Dim / 2). Moved to the device of <paramref name="input"/> before use.</param>
/// <returns>The rotated tensor, converted back to the element type of <paramref name="input"/>.</returns>
/// <exception cref="InvalidOperationException">The rotated tensor does not have the same shape as <paramref name="input"/>.</exception>
public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex)
{
    var inputShape = input.shape;

    // Separate the last dimension into pairs of two values, representing the real and imaginary parts of a complex number.
    // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim / 2, 2) -> complex (B, Seq_Len, H, Head_Dim / 2)
    var inputComplex = input.to_type(ScalarType.Float32)
        .reshape(inputShape[0], inputShape[1], inputShape[2], -1, 2)
        .view_as_complex();

    // Add the batch and head dimensions so the factors broadcast against the input.
    // (Seq_Len, Head_Dim / 2) -> (1, Seq_Len, 1, Head_Dim / 2)
    var freqsBroadcast = freqsComplex.to(input.device).unsqueeze(0).unsqueeze(2);

    // Complex multiplication performs the rotation.
    // (B, Seq_Len, H, Head_Dim / 2) * (1, Seq_Len, 1, Head_Dim / 2) = (B, Seq_Len, H, Head_Dim / 2)
    var rotatedComplex = inputComplex * freqsBroadcast;

    // Back to real numbers, then merge the (real, imaginary) pairs into the last dimension again.
    // (B, Seq_Len, H, Head_Dim / 2) -> (B, Seq_Len, H, Head_Dim / 2, 2) -> (B, Seq_Len, H, Head_Dim)
    var rotatedReal = rotatedComplex.view_as_real();
    var rotated = rotatedReal.reshape(rotatedReal.shape[0], rotatedReal.shape[1], rotatedReal.shape[2], -1);

    // Broadcasting against a mismatched freqsComplex can silently enlarge the result, so verify the shape.
    // The comparison is ordered and done directly: this is a runtime invariant on a hot path, not a unit test.
    var rotatedShape = rotated.shape;
    if (!inputShape.SequenceEqual(rotatedShape))
    {
        throw new InvalidOperationException(
            $"Rotary embedding changed the tensor shape from [{string.Join(", ", inputShape)}] to [{string.Join(", ", rotatedShape)}].");
    }

    return rotated.type_as(input);
}
// def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
// # As written in the paragraph 3.2.2 of the paper
// # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
// assert head_dim % 2 == 0, "Dimension must be divisible by 2"
// # Build the theta parameter
// # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2]
// # Shape: (Head_Dim / 2)
// theta_numerator = torch.arange(0, head_dim, 2).float()
// # Shape: (Head_Dim / 2)
// theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device) # (Dim / 2)
// # Construct the positions (the "m" parameter)
// # Shape: (Seq_Len)
// m = torch.arange(seq_len, device=device)
// # Multiply each theta by each position using the outer product.
// # Shape: (Seq_Len) outer_product* (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
// freqs = torch.outer(m, theta).float()
// # We can compute complex numbers in the polar form c = R * exp(i * m * theta), where R = 1 as follows:
// # (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
// freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
// return freqs_complex
/// <summary>
/// Precomputes the complex rotation factors used by rotary position embeddings:
/// one unit-magnitude complex number for every (position, frequency) pair.
/// </summary>
/// <param name="headDim">Head dimension; must be divisible by 2.</param>
/// <param name="seqLen">Number of positions to generate factors for.</param>
/// <param name="device">Device the tensors are placed on.</param>
/// <param name="theta">Base of the frequency schedule.</param>
/// <returns>Complex tensor of shape (Seq_Len, Head_Dim / 2).</returns>
/// <exception cref="ArgumentException"><paramref name="headDim"/> is not divisible by 2.</exception>
public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f)
{
    // Paragraph 3.2.2 of the paper generalizes the 2D result to any dimension d, provided d is even.
    if (headDim % 2 is not 0)
    {
        throw new ArgumentException("Dimension must be divisible by 2", nameof(headDim));
    }

    // The positions (the "m" parameter).
    // Shape: (Seq_Len)
    var positions = torch.arange(seqLen, device: device);

    // Numerators 0, 2, 4, ... of the exponent in theta_i = 10000^(-2(i-1)/dim), i = [1, 2, ... dim/2].
    // Shape: (Head_Dim / 2)
    var evenIndices = torch.arange(0, headDim, 2).to(torch.float32).to(device);

    // One frequency per pair of dimensions.
    // Shape: (Head_Dim / 2)
    var exponents = -1.0f * (evenIndices / headDim);
    var frequencies = torch.pow(theta, exponents).to(device);

    // The outer product multiplies every position by every frequency.
    // Shape: (Seq_Len) outer (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
    var angles = torch.outer(positions, frequencies).to(torch.float32).to(device);

    // Polar form with magnitude 1 and the angles computed above.
    // (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
    return torch.polar(torch.ones_like(angles), angles);
}
// python
// def rotate_half(x):
// """Rotates half the hidden dims of the input."""
// x1 = x[..., : x.shape[-1] // 2]
// x2 = x[..., x.shape[-1] // 2 :]
// return torch.cat((-x2, x1), dim=-1)
/// <summary>
/// Rotates half the hidden dims of the input: splits the last dimension into a first part x1 and a
/// second part x2, and returns the concatenation (-x2, x1) along the last dimension.
/// </summary>
/// <param name="x">Tensor whose last dimension is split. Any rank of at least 1 is accepted.</param>
/// <returns>A tensor with the same shape as <paramref name="x"/>.</returns>
public static Tensor RotateHalf(Tensor x)
{
    var lastDimSize = x.shape[^1];
    var half = lastDimSize / 2;

    // narrow(-1, ...) slices the last dimension regardless of how many leading dimensions there are,
    // matching the Python reference x[..., :half] / x[..., half:] instead of assuming a rank-4 tensor.
    var x1 = x.narrow(-1, 0, half);
    var x2 = x.narrow(-1, half, lastDimSize - half);

    return torch.cat([-x2, x1], dim: -1);
}
// python
// # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
// def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
// """Applies Rotary Position Embedding to the query and key tensors.
// Args:
// q (`torch.Tensor`): The query tensor.
// k (`torch.Tensor`): The key tensor.
// cos (`torch.Tensor`): The cosine part of the rotary embedding.
// sin (`torch.Tensor`): The sine part of the rotary embedding.
// position_ids (`torch.Tensor`):
// The position indices of the tokens corresponding to the query and key tensors. For example, this can be
// used to pass offsetted position ids when working with a KV-cache.
// unsqueeze_dim (`int`, *optional*, defaults to 1):
// The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
// sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
// that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
// k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
// cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
// the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
// Returns:
// `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
// """
// cos = cos[position_ids].unsqueeze(unsqueeze_dim)
// sin = sin[position_ids].unsqueeze(unsqueeze_dim)
// q_embed = (q * cos) + (rotate_half(q) * sin)
// k_embed = (k * cos) + (rotate_half(k) * sin)
// return q_embed, k_embed
/// <summary>
/// Applies rotary position embedding to the query and key tensors.
/// </summary>
/// <param name="q">The query tensor.</param>
/// <param name="k">The key tensor.</param>
/// <param name="cos">The cosine part of the rotary embedding.</param>
/// <param name="sin">The sine part of the rotary embedding.</param>
/// <param name="positionIds">
/// Optional position indices. When given, <paramref name="cos"/> and <paramref name="sin"/> are indexed
/// with them first; when null, they are used as they are.
/// </param>
/// <param name="unsqueezeDim">
/// Dimension inserted into cos and sin so they broadcast against q and k: 1 when q and k are
/// [batch_size, heads, seq_len, head_dim], 2 when they are [batch_size, seq_len, heads, head_dim].
/// </param>
/// <returns>The rotated query and key tensors, in that order.</returns>
public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? positionIds = null, int unsqueezeDim = 1)
{
    // Select the entries for the requested positions, if any were supplied.
    if (positionIds is { } ids)
    {
        cos = cos[ids];
        sin = sin[ids];
    }

    // Insert the broadcast dimension so cos and sin line up with q and k.
    cos = cos.unsqueeze(unsqueezeDim);
    sin = sin.unsqueeze(unsqueezeDim);

    // embed = (t * cos) + (rotate_half(t) * sin)
    Tensor Rotate(Tensor t) => (t * cos) + (RotateHalf(t) * sin);

    var qEmbed = Rotate(q);
    var kEmbed = Rotate(k);
    return (qEmbed, kEmbed);
}
/// <summary>
/// Creates the activation module registered under the given name.
/// </summary>
/// <param name="act_fn">
/// Name of the activation: "silu", "relu", "gelu", "tanh" or "swish" ("swish" maps to the same SiLU module as "silu").
/// The comparison is case-sensitive.
/// </param>
/// <returns>A new activation module instance.</returns>
/// <exception cref="ArgumentException"><paramref name="act_fn"/> is not one of the supported names.</exception>
public static Module<Tensor, Tensor> GetActivation(string act_fn)
{
    return act_fn switch
    {
        "silu" or "swish" => nn.SiLU(),
        "relu" => nn.ReLU(),
        "gelu" => nn.GELU(),
        "tanh" => nn.Tanh(),
        // The second constructor argument is the parameter *name*; the rejected value belongs in the message.
        _ => throw new ArgumentException($"Invalid activation function: '{act_fn}'", nameof(act_fn)),
    };
}
/// <summary>
/// Repeats every key/value head <paramref name="nRep"/> times along the head dimension.
/// </summary>
/// <param name="x">Tensor indexed as (batchSize, seqLen, nKVHeads, headDim).</param>
/// <param name="nRep">Number of copies of each head; 1 returns <paramref name="x"/> unchanged.</param>
/// <returns>
/// <paramref name="x"/> itself when <paramref name="nRep"/> is 1; otherwise a tensor of shape
/// (batchSize, seqLen, nKVHeads * nRep, headDim) in which the copies of a head are adjacent.
/// </returns>
public static Tensor Phi2RepeatKV(Tensor x, int nRep)
{
    if (nRep == 1)
    {
        return x;
    }

    var shape = x.shape;
    var batchSize = shape[0];
    var seqLen = shape[1];
    var nKVHeads = shape[2];
    var headDim = shape[3];

    // (B, S, KV, D) -> (B, S, KV, 1, D) -> (B, S, KV, nRep, D) -> (B, S, KV * nRep, D)
    // expand() yields a stride-0 (non-contiguous) dimension, which view() cannot merge with the head
    // dimension; reshape() copies when a view is not possible.
    return x.unsqueeze(3)
        .expand(batchSize, seqLen, nKVHeads, nRep, headDim)
        .reshape(batchSize, seqLen, nKVHeads * nRep, headDim);
}
/// <summary>
/// Repeats every key/value head <paramref name="nRep"/> times along the head dimension.
/// </summary>
/// <param name="x">Tensor indexed as (batchSize, nKVHeads, seqLen, headDim).</param>
/// <param name="nRep">Number of copies of each head; 1 returns <paramref name="x"/> unchanged.</param>
/// <returns>
/// <paramref name="x"/> itself when <paramref name="nRep"/> is 1; otherwise a tensor of shape
/// (batchSize, nKVHeads * nRep, seqLen, headDim) in which the copies of a head are adjacent.
/// </returns>
public static Tensor Phi3RepeatKV(Tensor x, int nRep)
{
    if (nRep == 1)
    {
        return x;
    }

    var shape = x.shape;
    var batchSize = shape[0];
    var nKVHeads = shape[1];
    var seqLen = shape[2];
    var headDim = shape[3];

    // (B, KV, S, D) -> (B, KV, 1, S, D) -> (B, KV, nRep, S, D) -> (B, KV * nRep, S, D)
    // The singleton axis must be inserted right after the head axis (index 2) so that it lines up with
    // nRep in the expand target. expand() yields a stride-0 dimension that view() cannot merge, so
    // reshape() is used for the final step.
    return x.unsqueeze(2)
        .expand(batchSize, nKVHeads, nRep, seqLen, headDim)
        .reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
}
}