One-sentence summary: distributed training spreads a large model's training workload across multiple compute nodes, addressing both the model sizes a single machine cannot hold and the need for higher training throughput.
| Strategy | Per-GPU model memory (vs. single GPU) | Communication overhead | Compute efficiency | Implementation difficulty |
|---|---|---|---|---|
| DP | 1× (full replica on every GPU) | Medium (gradient sync) | High (~100%) | Low |
| PP | ~1/k (model split into k stages) | Medium (activation passing) | Medium (pipeline bubbles) | Medium |
| TP | ~1/k (each matrix split k ways) | High (all-gather / all-reduce) | High (~95%) | Medium |
| DP+PP | ~1/k per replica | High | Medium | High |
| DP+TP | ~1/k per replica | Medium | High | Medium |
| DP+PP+TP | ~1/k² (k-way PP × k-way TP) | Very high | Medium | Very high |
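To make the memory column concrete, here is a rough sizing sketch (the helper name and the 16 bytes-per-parameter figure for fp16 weights plus fp32 Adam states are illustrative assumptions, not from any framework): PP and TP shard the model states, while plain DP replicates them on every replica.

```python
def model_state_gb_per_gpu(n_params, pp=1, tp=1, bytes_per_param=16):
    # ~16 bytes/param: fp16 weight + fp16 grad + fp32 master weight, momentum, variance.
    # Model states only; activation memory comes on top.
    shard = n_params / (pp * tp)          # PP and TP both shard the parameters
    return shard * bytes_per_param / 1024 ** 3

# A 7B-parameter model with 2-way PP and 4-way TP:
print(model_state_gb_per_gpu(7e9, pp=2, tp=4))   # ≈ 13 GB of model states per GPU
```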
```python
import torch
import torch.distributed as dist

class DataParallelTrainer:
    def __init__(self, model, devices):
        self.devices = devices
        # Reads RANK / WORLD_SIZE / MASTER_ADDR from the environment (e.g. set by torchrun)
        dist.init_process_group(backend='nccl')
        self.rank = dist.get_rank()
        # DDP all-reduces gradients across ranks during backward
        self.model = torch.nn.parallel.DistributedDataParallel(
            model.to(devices[0]),
            device_ids=[devices[0]],
            output_device=devices[0]
        )

    def train_step(self, batch):
        # Each rank is fed its own shard of the data (typically via DistributedSampler),
        # so the batch here is already local to this rank.
        loss = self.model(batch.to(self.devices[0]))  # model is assumed to return a scalar loss
        loss.backward()
        return loss.item()
```
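In practice each GPU runs its own process: launching with `torchrun --nproc_per_node=<num_gpus> train.py` (the script name is illustrative) sets the RANK, WORLD_SIZE and MASTER_ADDR environment variables that `init_process_group` reads, and a `DistributedSampler` on the DataLoader gives each rank its own shard of every batch.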
```python
# Gradient accumulation: only step/sync every N micro-batches
gradient_accumulation_steps = 4

for i, (inputs, labels) in enumerate(dataloader):
    output = model(inputs)
    # Scale the loss so the accumulated gradient matches a single large-batch gradient
    loss = criterion(output, labels) / gradient_accumulation_steps
    loss.backward()
    if (i + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
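With DDP, the same pattern normally also skips the gradient all-reduce on the non-boundary micro-batches via `no_sync()`, which is where the communication saving actually comes from. A minimal sketch, assuming `model` is a DistributedDataParallel instance:

```python
import contextlib

for i, (inputs, labels) in enumerate(dataloader):
    sync = (i + 1) % gradient_accumulation_steps == 0
    # Suppress the all-reduce except on the micro-batch that triggers optimizer.step()
    with model.no_sync() if not sync else contextlib.nullcontext():
        loss = criterion(model(inputs), labels) / gradient_accumulation_steps
        loss.backward()
    if sync:
        optimizer.step()
        optimizer.zero_grad()
```

The effective global batch size is then micro-batch size × accumulation steps × data-parallel world size.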
```python
import torch.distributed as dist

class TensorModelParallel:
    """Tensor parallelism: shard large weight matrices across GPUs."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.rank = dist.get_rank()

    def split_column(self, weight, bias=None):
        """Column split: shard the output features."""
        shard_size = weight.shape[0] // self.world_size
        start = self.rank * shard_size
        end = start + shard_size
        # The bias follows the output features, so it is sharded the same way
        return weight[start:end, :], bias[start:end] if bias is not None else None

    def split_row(self, weight, bias=None):
        """Row split: shard the input features."""
        shard_size = weight.shape[1] // self.world_size
        start = self.rank * shard_size
        end = start + shard_size
        # The bias lives on the (un-sharded) output features, so it stays whole
        return weight[:, start:end], bias
```
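A quick shape check of the two helpers (toy sizes; assumes the process group is already initialized so `dist.get_rank()` works, and that `world_size` divides both dimensions):

```python
tp = TensorModelParallel(world_size=4)
w = torch.randn(1024, 512)        # full (out_features, in_features) weight
w_col, _ = tp.split_column(w)     # shape (256, 512): output features sharded
w_row, _ = tp.split_row(w)        # shape (1024, 128): input features sharded
```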
```python
import torch
import torch.distributed as dist

class PipelineParallelModel:
    def __init__(self, model, n_stages):
        self.n_stages = n_stages
        self.rank = dist.get_rank()
        # Layer-wise partition: each rank keeps only its own stage
        self.stage = self.split_model(model, n_stages)[self.rank]

    def forward(self, inputs):
        """Pipeline forward pass: receive activations, run the local stage, send downstream."""
        if self.rank == 0:
            x = inputs.to(self.rank)
        else:
            # Activation shapes are assumed fixed and known; real implementations exchange metadata
            x = torch.empty_like(inputs, device=self.rank)
            dist.recv(x, src=self.rank - 1)   # activations from the previous stage
        x = self.stage(x)
        if self.rank < self.n_stages - 1:
            dist.send(x, dst=self.rank + 1)   # pass activations to the next stage
        return x

    def backward(self, loss):
        """Pipeline backward pass (only the last stage holds the loss)."""
        loss.backward()
```
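The pipeline "bubble" is why micro-batching matters: with k stages and m micro-batches per optimizer step, roughly (k-1)/(m+k-1) of the step is idle time, so GPipe-style schedules keep m well above k to keep every stage busy.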
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Column-parallel linear layer: the weight is sharded along the output features."""

    def __init__(self, in_features, out_features, tp_size, rank):
        super().__init__()
        self.tp_size = tp_size
        self.rank = rank
        self.out_features_per_gpu = out_features // tp_size
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_gpu, in_features)
        )
        self.bias = nn.Parameter(
            torch.zeros(self.out_features_per_gpu)
        )
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, x):
        # Each rank computes its own slice of the output features
        output = F.linear(x, self.weight, self.bias)
        if self.tp_size > 1:
            # Gather the slices from all ranks and concatenate along the feature dim.
            # (For training, an autograd-aware gather is needed; this is a sketch.)
            gathered = [torch.empty_like(output) for _ in range(self.tp_size)]
            dist.all_gather(gathered, output)
            output = torch.cat(gathered, dim=-1)
        return output

class RowParallelLinear(nn.Module):
    """Row-parallel linear layer: the weight is sharded along the input features."""

    def __init__(self, in_features, out_features, tp_size, rank):
        super().__init__()
        self.tp_size = tp_size
        self.rank = rank
        self.in_features_per_gpu = in_features // tp_size
        self.weight = nn.Parameter(
            torch.empty(out_features, self.in_features_per_gpu)
        )
        self.bias = nn.Parameter(torch.zeros(out_features))
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, x):
        # x is this rank's slice of the input features (e.g. the un-gathered output
        # of a preceding column-parallel layer); the local matmul is a partial sum.
        output = F.linear(x, self.weight)
        if self.tp_size > 1:
            dist.all_reduce(output)   # sum the partial results across ranks
        return output + self.bias
```
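These two layers are meant to compose: in the Megatron-LM MLP pattern, a column-parallel up-projection feeds its un-gathered, per-rank output straight into a row-parallel down-projection, so the whole MLP costs a single all-reduce in the forward pass; the explicit gather in the ColumnParallelLinear sketch above is only needed when the layer is used standalone.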
```python
# Parallelism layout: how many ways to split along each axis (DP x PP x TP)
config = {
    "data_parallel_size": 4,
    "pipeline_parallel_size": 2,
    "tensor_parallel_size": 4,
}
```
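With this layout the required world size is data_parallel_size × pipeline_parallel_size × tensor_parallel_size = 4 × 2 × 4 = 32 GPUs; tensor parallelism is normally kept within a node (NVLink bandwidth), while pipeline and data parallelism span nodes.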
```python
ds_config = {
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 8,
    "steps_per_print": 10,
    "pipeline": {
        "activation_checkpointing": True,
        "partition_activations": True,
        "contiguous_memory_optimization": True
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto"
    }
}
```
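A minimal sketch of feeding this config to DeepSpeed (assuming `model`, `criterion` and `dataloader` already exist and the script is started with the `deepspeed` launcher or torchrun):

```python
import deepspeed

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,                  # the dict above
)

for inputs, labels in dataloader:
    loss = criterion(engine(inputs.to(engine.device)), labels.to(engine.device))
    engine.backward(loss)              # handles ZeRO-3 gradient partitioning
    engine.step()                      # optimizer step plus ZeRO bookkeeping
```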
| Strategy | Description | Effect |
|---|---|---|
| NCCL backend | NVIDIA Collective Communications Library for GPU collectives | 2-3× faster |
| Gradient compression | Quantize gradients (8-bit / 4-bit) | 2-4× less communication volume |
| Gradient accumulation | Synchronize less frequently | Communication volume cut by the accumulation factor k |
| Communication overlap | Overlap compute with communication | 20-30% higher effective utilization |
| All-gather optimization | Pre-allocate buffers | Lower latency |
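As one concrete instance of gradient compression from the table, PyTorch DDP ships a built-in fp16 compression hook (coarser than 8-bit/4-bit schemes, but a one-line change). A minimal sketch, assuming `model` is already wrapped in DistributedDataParallel:

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Gradients are cast to fp16 before the all-reduce and decompressed afterwards,
# roughly halving the gradient communication volume.
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
```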
```python
class CommComputeOverlap:
    """Overlap gradient communication with backward computation."""

    def forward_backward(self, batch, labels):
        output = self.model(batch)
        loss = self.criterion(output, labels)
        # Gradient all-reduce is launched asynchronously per bucket during backward
        # (as DDP does), so it overlaps with the remaining backward computation.
        loss.backward()
        dist.barrier()              # all ranks reach this point before stepping
        self.optimizer.step()
        self.optimizer.zero_grad()
```