GPT2

  1. A fast way to build (input, target) batches from a token stream:
def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_pos : self.current_pos + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets: the inputs shifted by one token
        # advance the position; wrap around when the next batch would run past the end
        self.current_pos += B * T
        if self.current_pos + B * T + 1 > len(self.tokens):
            self.current_pos = 0
        return x, y
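
For context, a minimal sketch (not from the notes) of how the DataLoaderLite used later could be constructed, assuming a plain-text corpus tokenized with tiktoken; the file name input.txt is a placeholder:

import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T):
        self.B, self.T = B, T
        # load and tokenize the whole file once; keep it in memory as one long 1-D tensor
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text))
        self.current_pos = 0
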
  1. Initialize parameters with a fixed variance (rather than one that scales with the hidden size):
def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
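For reference, std = 0.02 is the value used in the original GPT-2 code; for n_embd = 768 it is the same order of magnitude as a Xavier-style 1/sqrt(n_embd) ≈ 0.036, but it is kept constant instead of being recomputed per layer width.
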
  1. Weight sharing: the token-embedding weights and the final linear projection (lm_head) weights are the same tensor:
# share weights
self.transformer.wte.weight = self.lm_head.weight
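Both tensors have shape (vocab_size, n_embd): nn.Embedding stores its weight that way, and nn.Linear stores (out_features, in_features), so the assignment ties them directly. At GPT-2 small scale (50257 x 768) this saves roughly 38M parameters, about 30% of the 124M total.
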
  1. Residual scaling, for more stable training:
class MLP:
# ...

# used for residual scaling
self.c_proj.NANOGPT_SCALE_N = 1

#...

class CausalSelfAttention:
# ...

# used for residual scaling
self.c_proj.NANOGPT_SCALE_N = 1

#...

def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            # Scale down the init std of the residual projections (c_proj).
            # The factor is 2 * n_layers because each block (layer) adds two residual contributions: attention and MLP.
            if hasattr(module, 'NANOGPT_SCALE_N'):
                std *= (2 * self.config.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
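
A small standalone sketch (not part of the training code) of why the 1/sqrt(2 * n_layers) factor helps: every block adds two residual contributions to the stream, and summing 2 * n_layers roughly independent contributions grows the standard deviation by sqrt(2 * n_layers) unless each one is scaled down:

import torch

n_layers = 12
x = torch.zeros(768)
for _ in range(2 * n_layers):       # two residual additions per block
    x = x + torch.randn(768)        # unscaled contributions
print(x.std())                      # grows to ~ sqrt(2 * n_layers) ≈ 4.9

x = torch.zeros(768)
for _ in range(2 * n_layers):
    x = x + (2 * n_layers) ** -0.5 * torch.randn(768)   # scaled as in the init above
print(x.std())                      # stays ~ 1.0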

  1. Train with lower precision:
# lower-precision (TF32) float32 matmuls where the GPU supports it
torch.set_float32_matmul_precision('high')

#...

# Using bfloat16
with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
loss.backward()
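Note that set_float32_matmul_precision('high') only affects float32 matmuls (allowing TF32 on GPUs that support it), while the autocast context runs most of the forward pass in bfloat16; the parameters themselves stay in float32, and because bfloat16 keeps the float32 exponent range, no GradScaler is needed.
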
  1. torch.compile:
model = torch.compile(model)
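torch.compile captures the model into a graph and fuses many element-wise operations into fewer GPU kernels, which mostly saves Python overhead and round trips to GPU memory; the first iteration is noticeably slower while compilation runs.
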
  1. flash attention:
# flash_attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
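
For comparison, a self-contained sketch (shapes are illustrative) of the manual attention this single call replaces; the manual version materializes the full (T, T) attention matrix, which the fused flash-attention kernel avoids:

import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 4, 8, 16                       # batch, heads, sequence length, head size
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# manual attention: builds the (T, T) score matrix explicitly
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y_manual = att @ v

# fused kernel: same result, no (T, T) matrix held in memory
y_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_manual, y_flash, atol=1e-5))   # True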
  1. To make better use of caches and GPU kernels, set matrix dimensions to powers of 2 (or to numbers containing as many factors of 2 as possible):
For example, change the vocabulary size:
# 50257 in previous
model = GPT(GPTConfig(vocab_size=50304))
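50304 is 50257 rounded up to the next multiple of 128 (50304 = 128 * 393), so the vocabulary dimension divides cleanly into the block sizes CUDA kernels prefer. The extra token ids never occur in the data, so the model quickly learns to push their probability toward zero; the small amount of wasted compute is more than paid back by the faster kernels.
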
 

GPT2 Training Details

  1. Optimizer and gradient clipping:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

# ...

loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# ...
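The AdamW settings betas=(0.9, 0.95) and eps=1e-8, as well as clipping the global gradient norm at 1.0, follow the GPT-3 paper. clip_grad_norm_ rescales the gradients in place whenever their combined L2 norm exceeds 1.0 and returns the pre-clipping norm, which is worth logging: spikes in it point to unstable batches, especially early in training.
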
  1. Scheduler (cosine decay with warmup):
import math

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50
def get_lr(it):
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    
    if it > max_steps:
        return min_lr
    
    decay_co = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_co <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_co))
    return min_lr + coeff * (max_lr - min_lr)

# in training
loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# set lr for each param
lr = get_lr(i)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr

optimizer.step()
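get_lr is a linear warmup to max_lr followed by a cosine decay down to min_lr = 0.1 * max_lr, the schedule shape used for GPT-3; warmup_steps = 10 and max_steps = 50 are small demo values here, not settings for a real run.
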
  1. weight_decay:
def configure_optimizers(self, weight_decay, learning_rate, device):
        param_dict = {pn : p for pn, p in self.named_parameters()}
        param_dict = {pn : p for pn, p in param_dict.items() if p.requires_grad}

        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        
        # key insight: 1-D tensors (biases, LayerNorm/scale params) get no weight decay
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        
        optim_group = [
            {'params' : decay_params, 'weight_decay' : weight_decay},
            {'params' : nodecay_params, 'weight_decay' : 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        
        # Create AdamW and use the fused version when it is available
        import inspect
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print(f"using fused AdamW: {use_fused}")
        
        # if fused, updating parameters will be fused into a single kernel without iterating all the parameters
        optimizer = torch.optim.AdamW(optim_group, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
        
# ...

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)

# ...
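Only tensors with at least two dimensions (matmul weights and embeddings) receive weight decay; one-dimensional tensors such as biases and LayerNorm parameters are left unregularized, which is the standard practice for AdamW on transformers. The fused AdamW variant performs the whole parameter update in one kernel launch instead of looping over parameters in Python, which is noticeably faster on CUDA.
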
  1. Gradient accumulation (the GPU cannot fit the full ~0.5M-token batch at once, so it is split into 512 micro-batches):
total_batch_size = 524288 # 2 ** 19 ~0.5M
B = 4
T = 256
assert total_batch_size % (B * T) == 0, "make sure total_batch_size is divisible by B * T"
grad_accum_steps = total_batch_size // (B * T)
print(f"total desired batch size: {total_batch_size}")
print(f"=> calculated radient accumulation steps: {grad_accum_steps}")

train_loader = DataLoaderLite(B=B, T=T)

# ...

for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        # with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
        # Important: scale the loss so the accumulated gradient averages over the full batch
        loss = loss / grad_accum_steps
        loss.backward()
        
 # ...
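The division by grad_accum_steps matters because loss.backward() accumulates (sums) gradients across micro steps: scaling each micro-batch loss by 1 / grad_accum_steps makes the accumulated gradient equal to the gradient of the mean loss over the full 524,288-token batch, exactly as if it had fit in a single forward/backward pass.
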
 
  1. DDP:
import os
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

ddp = int(os.environ.get('RANK', -1)) != -1 # check if a ddp run
if ddp:
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ.get('RANK'))
    ddp_local_rank = int(os.environ.get('LOCAL_RANK'))
    ddp_world_size = int(os.environ.get('WORLD_SIZE'))
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # global rank 0 handles logging / printing
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")
    
# ... 

if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
    
# ...

for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        # only all-reduce gradients on the last micro step
        if ddp:
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        loss.backward()
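
Setting require_backward_grad_sync only on the last micro step postpones DDP's gradient all-reduce until the accumulated gradients are complete, instead of paying for one all-reduce per micro step (the officially documented way to get the same effect is the model.no_sync() context manager). After the loop, each rank still holds only its own loss_accum; a sketch of averaging it across ranks before logging, assuming torch.distributed is imported as dist:

import torch.distributed as dist

if ddp:
    # average the accumulated loss over all ranks so every process reports the same value
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
if master_process:
    print(f"loss: {loss_accum.item():.6f}")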