LLaMA

LLaMA

花两天时间写了遍LLaMA,发现里面涉及的知识点很多,记录一下

模型架构

notion image
顺着图一点点写

1. RMS Norm

RMS Norm 是一个新的Normalization的方式,与 Layer NormBatch Norm 类似。公式如下:
notion image
这样做避免了计算了所有特征平均值(据论文中说,只有方差才是影响Normalization的主要因素)。
代码实现也很简单:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # The gamma parameters
        # (dim)
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (b, s, d) -> (b, s, d)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
    
    def forward(self, x: torch.Tensor):
        # (d) * (b, s, d) -> (b, s, d)
        return self.weight * self._norm(x)

2. RoPE

Rotary Position Encodings (旋转位置编码)是LLaMA模型中很重要的一个部分,但是对我来说,实现它的方式更巧妙。公式如下:
notion image
对于一个有 d 维特征的向量,我们需要构造以下的矩阵来完成编码:
notion image
其中 等于:
notion image
但是这个矩阵过于稀疏,计算效率太低。在经过数学推导后可以得到等价计算式:
notion image
于是真正要实现的是上述的公式。
详细推导过程与更多细节详见原始论文:RoFormer: Enhanced Transformer with Rotary Position Embedding

实现的方式非常🐮:

  1. 先计算所有的
theta_numerator = torch.arange(0, head_dim, 2).float()
theta = 1 / (theta ** (theta_numerator / head_dim)).to(device)
  1. 再构造出所有的
m = torch.arange(0, max_seq_len, device=device)
  1. 计算出所有的
freqs = torch.outer(m, theta).float()
  1. 转成复数形式:(最重要的一步):
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
这样就把每一个 转成了
这里用到了欧拉公式
转成复数的具体原因(用一个例子说明):
假设我们有输入特征:
我们将其视作一个(2,2)的矩阵:
通过torch.view_as_complex 函数,我们可以把这个矩阵看成两个复数,其中:
第一个复数是:
第二个复数是:
回顾一下,对于一个特征为4的向量,我们得到的freqs 长这样:
二者对应(elem wise) 点积后得到:
通过torch.view_as_real 函数,得到:
再把形状变回去,得到:
会发现结果和公式一模一样:
notion image
完整的代码实现:
def precompute_theta_pos_frequencies(head_dim: int, max_seq_len: int, device: str, theta: float = 10000.0):
    # according to the paper, head_dim should be even
    assert head_dim % 2 == 0, "Head dim should be even"
    # Shape: (head_dim / 2), according to the paper, represents theta [0, 2, ..., d - 2]
    theta_numerator = torch.arange(0, head_dim, 2).float()
    # Shape: (head_dim / 2), formula: theta_i = 10000 ^ (-2(i - 1) / dim), for i = 1, 2, ..., d / 2
    theta = 1 / (theta ** (theta_numerator / head_dim)).to(device)
    # Shape: (max_seq_len), according to the paper, represents the position [1, 2, ..., n]
    m = torch.arange(0, max_seq_len, device=device)
    # Shape: (max_seq_len, head_dim / 2)
    freqs = torch.outer(m, theta).float()
    # Shape: (max_seq_len, head_dim / 2)
    # the first numberrepresents the cos, the second numer represents the sin
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # (b, s, h, d) -> (b, s, h, d / 2, 2) -> (b, s, h, d / 2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (b, s, h, d / 2) * (1, s, 1, d / 2) -> (b, s, h, d / 2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    x_complex = x_complex * freqs_complex
    # (b, s, h, d / 2) -> (b, s, h, d / 2, 2) -> (b, s, h, d)
    x_out = torch.view_as_real(x_complex).reshape(*x.shape)
    return x_out.type_as(x).to(device)

3. KV cache

原理不再赘述,大概就是:
推理的时候前面的KV不需要重新计算,可以缓存下来直接用
听起来很复杂,但是在代码里面就几行:
# declare
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads_kv, self.head_dim), device=args.device)
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads_kv, self.head_dim), device=args.device)

# ...
# in forward

# NOTE!!: do this after applying Linear proj and RoPE
# add KV into cache
self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

# retrieve KV from cache
keys = self.cache_k[:batch_size, :start_pos + seq_len]
values = self.cache_v[:batch_size, :start_pos + seq_len]

# Attention as the same
# ...
同样的,关于更多的KV cache的原理和细节看论文:Efficiently Scaling Transformer Inference

4. SwiGLU

这是两个东西的结合版

1. GLU

GLU(Gated Linear Units,门控线性单元) 引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
这里W , V 以及b , c 分别是这两个线性层的参数;σ ( x W + b ) 作为门控,控制x V + c的输出。

2. Swish

Swish激活函数的形式为:
其中σ ( x ) 是Sigmoid函数;β 是一个可学习的参数。

3. SwiGLU

如前文所述,将GLU的激活函数改为Swish即变成了所谓的SwiGLU激活函数:
代码也比较简单:
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden_dim = 4 * args.input_dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(hidden_dim * args.ffn_dim_multiplier)
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

        self.w1 = nn.Linear(args.input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.input_dim, bias=False)
        self.w3 = nn.Linear(args.input_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
		    # SwiGLU begin
        swish = nn.SiLU(self.w1(x))
        x_V = self.w3(x)
        x = swish * x_V
        # SwiGLU end
        x = self.w2(x)
        return x
那一大堆hidden_dim的计算不知道在干嘛,但是官方仓库里的代码有这些