0. 前言

Node: 一个节点, 可以理解为一台电脑.

Device: 工作设备, 可以简单理解为一张卡, 即一个GPU.

Process: 一个进程, 可以简单理解为一个Python程序.

Threading: 一个线程, 一个进程可以有多个线程, 它们共享资源.


  1. 建议: 使用torchrun, 不要使用multiprocessing, torchrun会在程序中断后帮你杀死线程, 使用multiprocessing很容易造成僵尸线程, 资源无法释放.

1. 什么是数据并行化

随着模型参数和数据量越来越大, 分布式训练成为了深度学习模型训练中越来越重要的一环. 分布式训练包括两类: 模型并行化数据并行化. 在模型并行化中, 一个Device负责处理模型的一个切片 (例如模型的一层); 而在数据并行化中, 一个Device负责处理数据的一个切片 (即Batch的一部分). 我们今天讨论的torch.nn.parallel.DistributedDataParallel就是由pytorch提供的一种数据并行化方式.

2. 为什么要使用torch.nn.parallel.DistributedDataParallel

相较于torch.nn.parallel.DistributedDataParallel, 一个更易于使用也更被人熟知的接口是torch.nn.DataParallel. 该接口只需要一行修改即可实现"数据并行化" (具体参考知乎):

device_ids = [0, 1]
model = torch.nn.DataParallel(model, device_ids=device_ids)

此方法虽然简单, 但是存在若干问题, 例如设备间负载不均; 效率不高等. 现在官方推荐的方法为torch.nn.parallel.DistributedDataParallel ( Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel ).

为何如此呢? 简单来说, torch.nn.DataParallel 实现的是单机-多线程 (Single-Node Multi-threading), 而torch.nn.parallel.DistributedDataParallel 实现的是 单机/多机-多进程 (Single/Multi-Node Multi-process). 即torch.nn.parallel.DistributedDataParallel中的每一个模型是由一个独立的Process来控制的.1

此外, 官方文档中指出torch.nn.parallel.DistributedDataParallel是和模型并行化兼容的, 而torch.nn.DataParallel则不可以.

3. torch.nn.parallel.DistributedDataParallel 究竟在做什么?

Fig 1. torch.nn.parallel.DistributedDataParallel (https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html)

简单来说, 如果我们有$N$个Device (即$N$张卡), 每次的batch有$N\times M$个数据, 那我们如果能将模型分别复制到这$N$张卡上, 每张卡负责计算$M$个数据的$loss$的平均梯度, 然后将这些$N$张卡上梯度平均起来2, 同时将梯度更新到所有模型上3. 那我们相当于花了计算$M$个数据的时间, 完成了$N\times M$条数据的计算.

而多机希望实现的是: 简单来说, 就是你只需要知道我有多少个Device, 而不需要管这些Device分布在多少个Node上.

而为了保证模型间参数的同步, 多设备间需要通讯, 这是通过后端 (Backend 来完成的, 见torch.distributed), 简单来说: 如果想使用GPU训练用nccl, 如果使用CPU用gloo4.

  • Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.
  • Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

4. torch.nn.parallel.DistributedDataParallel 使用范例

导入需要的库

import os
import torch
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.nn as nn
import torch.distributed as dist
from datetime import timedelta

定义一个简单的模型

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

Fig 1. 中的多个Process形成了一个Process group, 在使用torch.nn.parallel.DistributedDataParallel之间我们需要初始化它, 初始化需要两个参数global_rankworld_size:

  • 其中world_size是指你一共有多少Process, 即world_size = 节点数量 * 每个节点上有多少Process = nnode * nproc_per_node.

  • 而对于每一个Process, 它都有一个local_rankglobal_rank, local_rank对应的就是该Process在自己的Node上的编号, 而global_rank就是全局的编号.

    • 比如你有$2$个Node, 每个Node上各有$2$个Proess (Process0, Process1, Process2, Process3). 那么对于Process2来说, 它的local_rank就是$0$ (即它在Node1上是第$0$个Process), global_rank 就是$2$.
    • 不难发现, local_rank对应的就是该Process需要使用的Device(GPU)编号 (并不一定, 但这是一种方便的方法).
def setup(global_rank, world_size):
    # 配置Master Node的信息
    # os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_ADDR'] = 'XXX.XXX.XXX.XXX'
    # os.environ['MASTER_PORT'] = '23555'
    os.environ['MASTER_PORT'] = 'XXXX'

    # 初始化Process Group
    # 关于init_method, 参数详见https://pytorch.org/docs/stable/distributed.html#initialization
    dist.init_process_group("nccl", init_method='env://', rank=global_rank, world_size=world_size, timeout=timedelta(seconds=5))

def cleanup():
    dist.destroy_process_group()
  • 对于Process group来说需要有一个Master node, 可以理解为是这个Process group的根节点, 我们一般配置为为Node0, 在后续在各个节点上启动代码时, 最好也是先从Node0开始启动. 需要确保Master node的IP地址可访问, 且端口没有被占用.
  • 对于响应超时可以自行设置超时时间timeout, 官方建议设置一个较大的时间, 以防可能出现的网络延迟.

定义训练过程

接下来我们定义在每一个Process中我们希望执行的代码.

def run_demo(local_rank, args):
    # 计算global_rank和world_size
    global_rank = local_rank + args.node_rank * args.nproc_per_node
    world_size = args.nnode * args.nproc_per_node
    setup(global_rank=global_rank, world_size=world_size)
    # 设置seed
    torch.manual_seed(args.seed)

    # 创建模型, 并将其移动到local_rank对应的GPU上
    model = ToyModel().to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(local_rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    print([data for data in model.parameters()])

    cleanup()
  • 需要注意, 我们每一个Process的应该运行在它自己对应的GPU上, 所以我们在代码中需要加上以下三种方式的一种:

    • # 方法1
      torch.cuda.set_device(local_rank)
      # 方法2
      with torch.cuda.device(local_rank)
      # 方法3
      model = ToyModel().to(local_rank)
      
  • DDP中, 所有模型都是以相同的参数被初始化, 同时训练过程中的梯度会在backward pass中被同步, 这就保证了在optimizer的优化过程中所有模型的参数保持一致.

多线程执行

最后, 我们需要在每一个Node上启动nproc_per_nodeProcess, 这一步可以使用torch.distributed.launch/torchrun/multiprocessing来实现:

  • 虽然为了代码清晰这里我们用了multiprocessing来实现, 但是在代码意外退出的时候它容易出现僵尸进程等Bug (Github), 在工程代码中不建议使用.

  • torchrun是为了替代torch.distributed.launch的新型启动方式, 可以支持ELASTIC LAUNCH, 即动态控制启动的节点数量, 但是由于是新功能, 只有最新的torch 1.10, 处于兼容性考虑还是建议使用torch.distributed.launch.

    • Single-node multi-worker
      
      >>> torchrun
          --standalone
          --nnodes=1
          --nproc_per_node=$NUM_TRAINERS
          YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
      Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures):
      
      >>> torchrun
          --nnodes=$NUM_NODES
          --nproc_per_node=$NUM_TRAINERS
          --max_restarts=3
          --rdzv_id=$JOB_ID
          --rdzv_backend=c10d
          --rdzv_endpoint=$HOST_NODE_ADDR
          YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
      
  • torch.distributed.launch的使用也很简单:

    • python -m torch.distributed.launch --nnodes=NNODE --node_rank=NODE_RANK --nproc_per_node=NPROC_PER_NODE \
      YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)
      
    • 然后记得你的代码中一定需要设一个--local_rank的参数, torch.distributed.launch会传给你对应的local_rank.

      • >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> parser.add_argument("--local_rank", type=int)
        >>> args = parser.parse_args()
        
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int)
    parser.add_argument('--nproc_per_node', type=int)
    parser.add_argument('--nnode', type=int)
    parser.add_argument('--node_rank', type=int)
    args = parser.parse_args()

    mp.spawn(run_demo, args=(args,), nprocs=args.nproc_per_node)
  • 执行时

    • Node0: python DDP_test.py --seed 1 --nproc_per_node 1 --nnode 2 --node_rank 0
    • Node1: python DDP_test.py --seed 1 --nproc_per_node 1 --nnode 2 --node_rank 1

5. 其他注意事项

并行化中的数据集

上文中我们也提到, 我们希望将$N\times M$个数据不重合地拆分成$N$份, 每份$M$条数据, 为此我们需要使用到DistributedSampler:

# 载入数据集
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
# 配置sampler, 相当于我们希望将原本数据集划分成`world_size`份
# torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=world_size,
    rank=global_rank,
    seed=seed,
    shuffle=True,
    pin_memory=True
)

image.png

  • num_workers设置了加载数据集的使用的线程数, 0表示只有主进程读取数据集. 关于num_workers的配置可以见知乎:DataLoader的num_workers设置|加速.
  • pin_memory 锁页内存, 可以加速数据读取. (也可能会导致Bug)

并行化中Save和Load模型 (Pytorch)

当使用DDP涉及到保存和读取模型的时候, 我们自然希望的是: 只需要有一个Process保存模型, 同时能够将模型读取到所有Process.

在实现中我们需要注意的是:

  • 确保Saving的过程结束了才会有Process执行Loading: 借助barrier()实现.
  • 当读取模型的时候, 需要设置正确的map_localtion, 以防止将模型加载到其他Process所处的Device(GPU)上. 因为当没有指定map_location时, Pytorch是默认先把模型参数加载到CPU中, 然后把每个参数复制到它被保存时所在的Device上.
def run_demo_checkpoint(local_rank, args):
    # 计算global_rank和world_size
    global_rank = local_rank + args.node_rank * args.nproc_per_node
    world_size = args.nnode * args.nproc_per_node
    setup(global_rank=global_rank, world_size=world_size)
    print(f"Running DDP checkpoint example on rank {global_rank}.")

    # 设置seed
    torch.manual_seed(args.seed)

    model = ToyModel().to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"

    if global_rank == 0:
        # 只在Process0中保存模型
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # barrier(): 可以理解为只有当所有Process都到达了这一步才能继续往下执行
    # 以此保证其他Process只有在Process0完成保存后才可能读取模型
    dist.barrier()

    # 配置`map_location`.
    map_location = torch.device(f'cuda:{local_rank}')

    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(local_rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()
    print(outputs)

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if global_rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

Join: 处理uneven data

我们知道DDP中,Pytorch会在每次backward pass的时候做一次synchronization, 以保证梯度的同步, 但是这就存在一个问题, 如果不同的Process所分配到的数据长度不一样怎么办. 例如Process1中如果有$5$个batch, Process2中只有$6$个batch, 那么Process2在处理最后一个batch的时候就会无限挂起等待其他Process, DDP中提供了Join接口来解决这一问题. (Github Issue)

import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch.nn as nn

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    model = nn.Linear(1, 1, bias=False).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], output_device=rank
    )
    # Rank1 比 rank0 会多一个Batch.
    inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
    with model.join():
        for _ in range(5):
            for inp in inputs:
                loss = model(inp).sum()
                loss.backward()
    # 如果没有join() API, 下面的同步语句就会无限挂起等待Rank1的
    # allreduce完成.
    torch.cuda.synchronize(device=rank)
  • Join的大致逻辑是, 当某一个Process提前用完了它的数据, 它就会进入joining模式, 从而"欺骗"和其他进程之间的allreduce: 比如DDP中用尽数据的Process在梯度allreduce的时候会提供一个全为$0$的梯度用于同步.

  • DDPjoin接口具体为: join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)

    • 其中divide_by_initial_world_size指的是在average梯度的时候, 除以的是world_size(初始化时候的进程数), 还是现在剩有数据的进程数(即non-joining的进程数). 官方建议是如果不同进程间输入的差异是微小的, 则设置为True; 如果差异非常巨大则设置为False.

    • enable指的是是否能够检测uneven input.

    • throw_on_early_termination指是否在有进程耗尽数据时抛出异常. 如果设置为True, 则会在第一个进程耗尽数据后抛出异常. 注意如果设置了throw_on_early_termination, divide_by_initial_world_size会被忽视.

      • Note: If the model or training loop this context manager is wrapped around has additional distributed collective operations, such as SyncBatchNorm in the model’s forward pass, then the flag throw_on_early_termination must be enabled. This is because this context manager is not aware of non-DDP collective communication. This flag will cause all ranks to throw when any one rank exhausts inputs, allowing these errors to be caught and recovered from across all ranks.

Unused Parameters

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter dete
ction by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by                                                                                                   
making sure all `forward` function outputs participate in calculating loss.                                                                                                                                                  

If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure
of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

常用函数

6. 参考资料


  1. 线程和进程的区别: 建议百度. ↩︎

  2. 当不同卡上处理的数据量($M$)不同时, 不能直接算平均$loss$: Github Issue↩︎

  3. 由于在backward-pass计算梯度时, 该层的梯度不依赖于前面的层, 所以torch.nn.parallel.DistributedDataParallel中各卡上模型参数的同步是跟随着梯度backward-pass同时完成的. ↩︎

  4. 一个注释, 我在使用nccl的时有时候会出现Socket Timeout的错误, 目前还是一个Open Issue, 有建议说如果不行可以改使用gloo. (Github Issue). ↩︎