AI/NLP
[멀티노드 분산학습] FSDP + Accelerate로 Multi Node Training하기
땽뚕
2025. 5. 24. 11:21
728x90
[멀티노드 분산학습] FSDP + Accelerate로 Multi Node Training하기
H100 멀티노드 학습해야 되는데, 방법을 정리한다
FSDP + Accelerate로 Multi Node Training
Takeaway Message.
- FSDP+Accelerate 기반으로 분산학습 수행할 땐 SLURM (Simple Linux Utility for Resource Management)라는 것을 사용한다고 함
- SLURM이란 ?
- cluster server 상에서 작업을 관리하기 위한 프로그램 -> 여러 대의 서버에서 학습을 자동으로 분산 실행하고 자원을 할당해주는 스케줄러 시스템이라고 생각하면 된다
- 멀티 노드 학습을 하려면 GPU, CPU, RAM, 네트워크 자원 등을 여러 노드에서 동시에 예약하고 동기화 실행해야 하는데 이것을 자동화해줌
- 즉 각 노드에 ssh 접속 따로 할 필요 없이 SLURM만 실행하면 됨
- slurm 명령어 : sbatch
- cluster server 상에서 작업을 관리하기 위한 프로그램 -> 여러 대의 서버에서 학습을 자동으로 분산 실행하고 자원을 할당해주는 스케줄러 시스템이라고 생각하면 된다
- SLURM이란 ?
예시 환경
- Fine-tuning Llama 2 70B using PyTorch FSDP
- 노드 수: 2개
- 노드당 GPU 수: 8개 (총 16 GPU)
- GPU 종류: A100 80GB
- 노드당 CPU: 96코어 / RAM 1TB
- 노드 간 연결
- 내부 연결: NVLink
Slurm 스크립트 작성 예시
- 설명
- #sbatch 뒤에 옵션을 달면, slurm 명령어가 실행
- 주요 옵션
- --nodes : 노드 몇 개 사용할지
- --nodelist : 사용할 노드 지정
- --gres : gpu 개수 설정
- --nodes : 노드 몇 개 사용할지
#!/bin/bash
#SBATCH --job-name=ift_llama
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=96
#SBATCH --mem-per-cpu=11G # Important to enable "mix" use of GPUs across cluster users
#SBATCH --partition=XXXXX
#SBATCH --gres=gpu:8 # Adjust number of GPUs here
#SBATCH --output=shared_storage/sourab/temp/logs/%x-%j.out
#SBATCH --err=shared_storage/sourab/temp/logs/%x-%j.err
set -x -e
# CHANGE HERE THE CONDA EVN AND ANY STARTUP SCRIPTS
source ~/sourab/.bashrc
source shared_storage/sourab/miniconda3/etc/profile.d/conda.sh
conda activate hf
cd shared_storage/sourab/DHS-LLM-Workshop/code_assistant/training
git pull
# have the below in case of debugging nccl issues such as nccl timeout.
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens
echo "START TIME: $(date)"
# CHANGE TO CUMMULATIVELY LOG OUTPUTS
LOG_PATH="main_log.txt"
GPUS_PER_NODE=8
NNODES=$SLURM_NNODES
NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
# OTHER LAUNCHERS CAN BE USED HERE
export LAUNCHER="accelerate launch \
--config_file configs/fsdp_config.yaml \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank \$SLURM_PROCID \
--num_processes $NUM_PROCESSES \
--num_machines $NNODES \
"
# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
export PROGRAM="\
train.py \
--model_name "meta-llama/Llama-2-70b-chat-hf" \
--dataset_name "smangrul/code-chat-assistant-v1" \
--max_seq_len 2048 \
--max_steps 500 \
--logging_steps 25 \
--eval_steps 100 \
--save_steps 250 \
--bf16 True \
--packing True \
--output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "content" \
--use_gradient_checkpointing True \
--learning_rate 5e-5 \
--lr_scheduler_type "cosine" \
--weight_decay 0.01 \
--warmup_ratio 0.03 \
--use_flash_attn True
"
export CMD="$LAUNCHER $PROGRAM"
srun --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $LOG_PATH
echo "END TIME: $(date)"
accelerate 학습 명령어 예시
accelerate launch \
--config_file configs/fsdp_config.yaml \
--main_process_ip
$MASTER_ADDR
\
--main_process_port
$MASTER_PORT
\
--machine_rank
$MACHINE_RANK
\
--num_processes 16 \
--num_machines 2 \
train.py \
--model_name
"meta-llama/Llama-2-70b-chat-hf"
\
--dataset_name
"smangrul/code-chat-assistant-v1"
\
--max_seq_len 2048 \
--max_steps 500 \
--save_steps 250 \
--eval_steps 100 \
--bf16 True \
--gradient_checkpointing True \
--use_flash_attn True \
--output_dir
"/shared_storage/.../full-finetune-llama-chat-asst"
728x90