花两天时间写了遍LLaMA,发现里面涉及的知识点很多,记录一下
模型架构
顺着图一点点写
1. RMS Norm
RMS Norm
是一个新的Normalization的方式,与 Layer Norm
和 Batch Norm
类似。公式如下:这样做避免了计算了所有特征平均值(据论文中说,只有方差才是影响Normalization的主要因素)。
代码实现也很简单:
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
# The gamma parameters
# (dim)
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x: torch.Tensor):
# (b, s, d) -> (b, s, d)
# rsqrt: 1 / sqrt(x)
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor):
# (d) * (b, s, d) -> (b, s, d)
return self.weight * self._norm(x)
2. RoPE
Rotary Position Encodings
(旋转位置编码)是LLaMA模型中很重要的一个部分,但是对我来说,实现它的方式更巧妙。公式如下:对于一个有 d 维特征的向量,我们需要构造以下的矩阵来完成编码:
其中 等于:
但是这个矩阵过于稀疏,计算效率太低。在经过数学推导后可以得到等价计算式:
于是真正要实现的是上述的公式。
详细推导过程与更多细节详见原始论文:RoFormer: Enhanced Transformer with Rotary Position Embedding
实现的方式非常🐮:
- 先计算所有的 :
theta_numerator = torch.arange(0, head_dim, 2).float()
theta = 1 / (theta ** (theta_numerator / head_dim)).to(device)
- 再构造出所有的 :
m = torch.arange(0, max_seq_len, device=device)
- 计算出所有的 :
freqs = torch.outer(m, theta).float()
- 转成复数形式:(最重要的一步):
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
这样就把每一个 转成了
这里用到了欧拉公式
转成复数的具体原因(用一个例子说明):
假设我们有输入特征:
我们将其视作一个(2,2)的矩阵:
通过
torch.view_as_complex
函数,我们可以把这个矩阵看成两个复数,其中:第一个复数是:
第二个复数是:
回顾一下,对于一个特征为4的向量,我们得到的
freqs
长这样:二者对应(elem wise) 点积后得到:
通过
torch.view_as_real
函数,得到:再把形状变回去,得到:
会发现结果和公式一模一样:
完整的代码实现:
def precompute_theta_pos_frequencies(head_dim: int, max_seq_len: int, device: str, theta: float = 10000.0):
# according to the paper, head_dim should be even
assert head_dim % 2 == 0, "Head dim should be even"
# Shape: (head_dim / 2), according to the paper, represents theta [0, 2, ..., d - 2]
theta_numerator = torch.arange(0, head_dim, 2).float()
# Shape: (head_dim / 2), formula: theta_i = 10000 ^ (-2(i - 1) / dim), for i = 1, 2, ..., d / 2
theta = 1 / (theta ** (theta_numerator / head_dim)).to(device)
# Shape: (max_seq_len), according to the paper, represents the position [1, 2, ..., n]
m = torch.arange(0, max_seq_len, device=device)
# Shape: (max_seq_len, head_dim / 2)
freqs = torch.outer(m, theta).float()
# Shape: (max_seq_len, head_dim / 2)
# the first numberrepresents the cos, the second numer represents the sin
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
return freqs_complex
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
# (b, s, h, d) -> (b, s, h, d / 2, 2) -> (b, s, h, d / 2)
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
# (b, s, h, d / 2) * (1, s, 1, d / 2) -> (b, s, h, d / 2)
freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
x_complex = x_complex * freqs_complex
# (b, s, h, d / 2) -> (b, s, h, d / 2, 2) -> (b, s, h, d)
x_out = torch.view_as_real(x_complex).reshape(*x.shape)
return x_out.type_as(x).to(device)
3. KV cache
原理不再赘述,大概就是:
推理的时候前面的KV不需要重新计算,可以缓存下来直接用
听起来很复杂,但是在代码里面就几行:
# declare
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads_kv, self.head_dim), device=args.device)
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads_kv, self.head_dim), device=args.device)
# ...
# in forward
# NOTE!!: do this after applying Linear proj and RoPE
# add KV into cache
self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv
# retrieve KV from cache
keys = self.cache_k[:batch_size, :start_pos + seq_len]
values = self.cache_v[:batch_size, :start_pos + seq_len]
# Attention as the same
# ...
同样的,关于更多的KV cache的原理和细节看论文:Efficiently Scaling Transformer Inference
4. SwiGLU
这是两个东西的结合版
1. GLU
GLU(Gated Linear Units,门控线性单元) 引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
这里W , V 以及b , c 分别是这两个线性层的参数;σ ( x W + b ) 作为门控,控制x V + c的输出。
2. Swish
Swish激活函数的形式为:
其中σ ( x ) 是Sigmoid函数;β 是一个可学习的参数。
3. SwiGLU
如前文所述,将GLU的激活函数改为Swish即变成了所谓的SwiGLU激活函数:
代码也比较简单:
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
hidden_dim = 4 * args.input_dim
hidden_dim = int(2 * hidden_dim / 3)
if args.ffn_dim_multiplier is not None:
hidden_dim = int(hidden_dim * args.ffn_dim_multiplier)
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.w1 = nn.Linear(args.input_dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, args.input_dim, bias=False)
self.w3 = nn.Linear(args.input_dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor):
# SwiGLU begin
swish = nn.SiLU(self.w1(x))
x_V = self.w3(x)
x = swish * x_V
# SwiGLU end
x = self.w2(x)
return x
那一大堆hidden_dim的计算不知道在干嘛,但是官方仓库里的代码有这些