- A quick way to build batches from a sequence dataset:
def next_batch(self):
    B, T = self.B, self.T
    buf = self.tokens[self.current_pos : self.current_pos + B * T + 1]
    x = buf[:-1].view(B, T)  # inputs
    y = buf[1:].view(B, T)   # targets, shifted by one token
    # advance the position; wrap around when the next batch would run past the end
    self.current_pos += B * T
    if self.current_pos + B * T + 1 > len(self.tokens):
        self.current_pos = 0
    return x, y
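For context, a minimal constructor that this next_batch assumes might look like the following sketch (not the note's original code; the 'input.txt' path and the tiktoken GPT-2 tokenizer are assumptions):
import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T
        # load and tokenize the whole text file once, up front
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text))
        self.current_pos = 0
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")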
- Initialize parameters with a fixed variance (not a variance that scales with the hidden size):
def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
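As a usage note (a sketch, assuming _init_weights is a method of the GPT nn.Module itself), the initializer is hooked up once in the constructor so it runs over every submodule:
# in GPT.__init__, after all submodules are created
self.apply(self._init_weights)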
- Weight sharing: the token-embedding matrix and the final linear head use the same weights:
# share weights
self.transformer.wte.weight = self.lm_head.weight
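A quick sanity check (a sketch, assuming a constructed model; GPT and GPTConfig come from the surrounding code) that the tying really makes both modules point at one tensor, so the large (vocab_size x n_embd) matrix is stored only once:
model = GPT(GPTConfig())
assert model.transformer.wte.weight is model.lm_head.weight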
- Residual scaling, for more stable training:
class MLP(nn.Module):
    # ...
    # flag this projection for residual scaling
    self.c_proj.NANOGPT_SCALE_N = 1
    # ...

class CausalSelfAttention(nn.Module):
    # ...
    # flag this projection for residual scaling
    self.c_proj.NANOGPT_SCALE_N = 1
    # ...

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        std = 0.02
        # scale down the init of the residual-projection layers
        # the factor of 2: each block (layer) contributes 2 residual additions (attention and MLP)
        if hasattr(module, 'NANOGPT_SCALE_N'):
            std *= (2 * self.config.n_layers) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
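A small numerical sketch (not from the note) of why the factor is (2 * n_layer) ** -0.5: summing n independent residual contributions grows the activation std like sqrt(n), and the scaling cancels that growth:
import torch

n_layer = 12
x = torch.zeros(768)
for _ in range(2 * n_layer):    # 2 residual additions per block
    x = x + torch.randn(768)    # unscaled residual stream
print(x.std())                  # roughly sqrt(2 * n_layer) ~ 4.9

x = torch.zeros(768)
for _ in range(2 * n_layer):
    x = x + (2 * n_layer) ** -0.5 * torch.randn(768)  # scaled contributions
print(x.std())                  # roughly 1.0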
- Train with lower precision:
# lower precision for float32 matmuls (enables TF32 on supported GPUs)
torch.set_float32_matmul_precision('high')
# ...
# use bfloat16 autocast for the forward pass
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
loss.backward()
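A hedged sketch (reusing model, x, y from the surrounding code, on a CUDA device) of what autocast actually changes: matmul outputs such as the logits come out in bfloat16, while the parameters themselves stay in float32:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits, loss = model(x, y)
print(logits.dtype)                        # torch.bfloat16
print(model.transformer.wte.weight.dtype)  # torch.float32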
- torch.compile:
model = torch.compile(model)
- flash attention:
# flash_attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
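For reference, a sketch of the manual attention that this one fused call replaces (as in common GPT implementations; it assumes q, k, v shaped (B, n_head, T, head_dim) and a registered lower-triangular mask buffer self.bias):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v  # (B, n_head, T, head_dim)
The fused kernel never materializes the (T, T) attention matrix in GPU memory, which is where most of the speedup comes from.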
- Take advantage of caches / kernel tiling by making matrix dimensions powers of two (or at least contain as many factors of two as possible):
For example, pad the vocabulary size:
# vocab_size was 50257 before
model = GPT(GPTConfig(vocab_size=50304))
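A quick check (not from the note) of why 50304 is the "nicer" number: it is divisible by a high power of two, while 50257 is odd; the extra token ids never occur in the data, so the model simply learns to give them negligible probability:
print(50257 % 2)    # 1 - odd, no factors of two
print(50304 % 128)  # 0 - 50304 = 128 * 393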
GPT-2 training details
- Optimizer and gradient clipping:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
# ...
loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# ...
- Scheduler (cosine decay with warmup):
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50

def get_lr(it):
    # linear warmup
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # after the schedule ends, hold the minimum lr
    if it > max_steps:
        return min_lr
    # cosine decay in between
    decay_co = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_co <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_co))
    return min_lr + coeff * (max_lr - min_lr)

# in training
loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# set the lr for every param group at this step (i is the training-loop index)
lr = get_lr(i)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr
optimizer.step()
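A tiny sketch (not in the original) to eyeball the schedule: printing get_lr over the 50 demo steps shows the linear ramp up to 6e-4 during the first 10 steps and the cosine decay toward 6e-5 afterwards:
for it in range(0, max_steps + 1, 5):
    print(f"step {it:3d}: lr = {get_lr(it):.2e}")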
- weight_decay:
def configure_optimizers(self, weight_decay, learning_rate, device):
    import inspect
    param_dict = {pn: p for pn, p in self.named_parameters()}
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # key insight: only decay parameters with dim >= 2 (matmul weights, embeddings);
    # biases and LayerNorm parameters are not decayed
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_group = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # create AdamW and use the fused version when it is available
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in device
    print(f"using fused AdamW: {use_fused}")
    # when fused, the parameter update runs as a single kernel instead of a Python loop over all parameters
    optimizer = torch.optim.AdamW(optim_group, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
    return optimizer
# ...
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)
# ...
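A quick sketch (toy module, not the GPT) of how the p.dim() >= 2 rule splits parameters: weight matrices of Linear/Embedding layers land in the decay group, while 1-D tensors such as biases and LayerNorm parameters do not:
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
for name, p in toy.named_parameters():
    print(f"{name:10s} dim={p.dim()} -> {'decay' if p.dim() >= 2 else 'no decay'}")
# 0.weight (2-D) -> decay; 0.bias, 1.weight, 1.bias (1-D) -> no decay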
- Gradient accumulation (the GPU cannot fit the full desired batch of ~0.5M tokens at once, so it is built up from several smaller micro-batches):
total_batch_size = 524288  # 2 ** 19 tokens, ~0.5M
B = 4      # micro-batch size
T = 256    # sequence length
assert total_batch_size % (B * T) == 0, "make sure total_batch_size is divisible by B * T"
grad_accum_steps = total_batch_size // (B * T)
print(f"total desired batch size: {total_batch_size}")
print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
train_loader = DataLoaderLite(B=B, T=T)
# ...
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    # with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
    # Notice this carefully! each micro-batch loss is a mean, so it must be scaled
    # by 1 / grad_accum_steps to match the gradient of one big batch
    loss = loss / grad_accum_steps
    loss.backward()
# ...
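A small sketch (toy linear model, not from the note) demonstrating why the division matters: gradients add up across repeated backward() calls, and only the scaled micro-batch losses reproduce the gradient of one big mean-reduced batch:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
xb = torch.randn(16, 4)
yb = torch.randn(16)

# one big batch of 16
F.mse_loss(xb @ w, yb).backward()
big_grad = w.grad.clone()

# 4 micro-batches of 4, with the 1/4 scaling
w.grad = None
for i in range(4):
    xm, ym = xb[i*4:(i+1)*4], yb[i*4:(i+1)*4]
    loss = F.mse_loss(xm @ w, ym) / 4   # the crucial division
    loss.backward()
print(torch.allclose(big_grad, w.grad, atol=1e-6))  # True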
- DDP:
import os
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

ddp = int(os.environ.get('RANK', -1)) != -1  # check if this is a ddp run
if ddp:
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process handles logging, checkpointing, etc.
else:
    # vanilla, non-DDP run
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")
# ...
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
# ...
# autocast wants a device *type* ('cuda' / 'cpu'), not an index like 'cuda:1'
device_type = "cuda" if device.startswith("cuda") else "cpu"
loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss = loss / grad_accum_steps
    loss_accum += loss.detach()
    # cancel gradient sync across ranks except on the last micro-step
    if ddp:
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
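To complete the picture, a hedged sketch (not in the original snippet) of the pieces a DDP run typically needs around this loop: averaging the logged loss across ranks, cleaning up the process group, and the launch command; the script name train_gpt2.py is a placeholder:
import torch.distributed as dist
from torch.distributed import destroy_process_group

if ddp:
    # make loss_accum the average over all ranks so every process logs the same value
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
if master_process:
    print(f"loss: {loss_accum.item():.6f}")
if ddp:
    destroy_process_group()

# launch with e.g. 8 GPUs on a single node:
#   torchrun --standalone --nproc_per_node=8 train_gpt2.py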