FSDP(Fully Sharded Data Parallel)是一种数据并行技术,最早在2021年由FairScale-FSDP提出,并在PyTorch 1.11版本中被集成。FSDP可以看作是微软DeepSpeed框架中提出的ZERO算法的ZERO-3级别的实现。以下是FSDP的一些关键特点:

  1. 模型参数分片:FSDP通过对模型参数、梯度和优化器状态进行分片处理,使得每个GPU只存储部分参数信息,从而减少单个GPU的内存占用。

  2. 减少内存占用:FSDP使得训练更大的模型成为可能,因为它允许更大的模型或批量大小适配GPU容量,减少内存需求。

  3. 通信开销:虽然FSDP带来的问题是增加了process/worker节点之间的通信开销,但可以通过在PyTorch内部进行优化来降低通信代价,例如通过重叠通信和计算来减少网络开销。

  4. 与DDP的区别:传统的数据并行(DDP)在每个GPU上保存整个模型的参数、梯度和优化器状态,而FSDP则将这些信息分片,每个GPU只保存部分信息。

  5. 实现方式:FSDP的实现涉及到将DDP中的all-reduce操作拆解为reduce-scatter和all-gather操作,以实现参数的分片处理。

  6. 性能:实验表明,FSDP能够实现与分布式数据并行相当的性能,并为更大的模型提供近线性可扩展性的支持。

  7. 易用性:FSDP提供了易用的API,可以非常方便地解决大模型分布式训练的难题。

  8. 与DeepSpeed ZeRO-DP的比较:FSDP受启发于DeepSpeed ZeRO-DP,并进行了进一步的延申和拓展。FSDP包括了NO_SGARD(等效于DDP)、SHARD_GRAD_OP(对标ZeRO2)和FULL_SHARD(对标ZeRO3)等分片策略。

FSDP是一种高效的大规模模型训练解决方案,它通过分片模型参数和优化器状态,提高了内存利用效率和扩展能力,使得在资源有限的情况下训练超大模型成为可能。