
[Multi-node Distributed Training] Multi-Node Training with FSDP + Accelerate

땽뚕 2025. 5. 24. 11:21

I need to run multi-node training on H100s, so I'm writing down how to do it.

Multi-Node Training with FSDP + Accelerate

Takeaway Message.

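  • For multi-node FSDP + Accelerate training on a cluster, jobs are reportedly launched through SLURM (Simple Linux Utility for Resource Management)
    • What is SLURM?
      • A program for managing jobs on a cluster: think of it as a scheduler that allocates resources and automatically runs training distributed across multiple servers
        • Multi-node training needs GPU, CPU, RAM, and network resources reserved on several nodes at the same time, with the processes started in sync; SLURM automates this
        • So instead of ssh-ing into each node separately, you only need to submit the job to SLURM
      • SLURM command: sbatch (submits a batch script)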
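The submit-and-check cycle can be sketched as follows. This is a minimal example under assumptions: the job file name `hello_multinode.sh` is hypothetical, and on a machine without SLURM the script only writes the job file and skips `sbatch`/`squeue`.

```shell
#!/usr/bin/env bash
# Minimal sketch: write a tiny SLURM job script and submit it with sbatch.
# The file name hello_multinode.sh is hypothetical.
set -euo pipefail

JOB_FILE="hello_multinode.sh"

# Lines starting with #SBATCH are read by sbatch as job options;
# bash itself treats them as ordinary comments.
cat > "$JOB_FILE" <<'EOF'
#!/bin/bash
#SBATCH --job-name=hello_multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --output=%x-%j.out

# srun starts one task per node; each prints its host name and rank.
srun bash -c 'echo "$(hostname) rank=$SLURM_PROCID"'
EOF

if command -v sbatch >/dev/null 2>&1; then
    sbatch "$JOB_FILE"        # prints "Submitted batch job <id>"
    squeue -u "$USER"         # show the job's state in the queue
else
    echo "sbatch not found; wrote $JOB_FILE without submitting"
fi
```

On a real cluster the log lands in `hello_multinode-<jobid>.out` (`%x` is the job name, `%j` the job ID), with one line per node.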

Example environment

Example SLURM script

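  • Notes
    • Lines starting with #SBATCH are read by sbatch as job options (bash itself treats them as comments)
    • Key options
      • --nodes : how many nodes to use
        • --nodelist : which specific nodes to use
      • --gres : number of GPUs per node (e.g. gpu:8)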
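Once the job starts, SLURM exposes the allocation through environment variables such as `SLURM_NNODES` (from `--nodes`) and `SLURM_PROCID` (the rank of each task started by `srun`). The sketch below reproduces, outside a cluster, the world-size arithmetic and the rank lookup that the script below relies on. The fallback of 2 nodes is an assumption for illustration, and the loop stands in for the one-task-per-node launch that `srun` performs. It also shows why the script writes `\$SLURM_PROCID` escaped: the literal string survives until each task's own shell expands it.

```shell
#!/usr/bin/env bash
# Sketch: how SLURM variables turn into accelerate's world size and ranks.
# Outside a real job, SLURM_NNODES falls back to 2 (an assumed value).
set -euo pipefail

GPUS_PER_NODE=8                       # must match --gres=gpu:8
NNODES=${SLURM_NNODES:-2}             # set by SLURM from --nodes
NUM_PROCESSES=$(( NNODES * GPUS_PER_NODE ))
echo "num_machines=$NNODES num_processes=$NUM_PROCESSES"

# Escaped: \$SLURM_PROCID stays literal here and is expanded later
# by each task's own shell, so every node gets a different rank.
CMD="echo machine_rank=\$SLURM_PROCID"

# Simulate the per-node tasks srun would start (one task per node).
for rank in $(seq 0 $(( NNODES - 1 ))); do
    SLURM_PROCID=$rank bash -c "$CMD"
done
```

Without the backslash, `$SLURM_PROCID` would be expanded once when `CMD` is built, and every node would receive the same `--machine_rank`.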
#!/bin/bash
#SBATCH --job-name=ift_llama
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per node; accelerate launch spawns the 8 GPU processes itself
#SBATCH --cpus-per-task=96
#SBATCH --mem-per-cpu=11G # Important to enable "mix" use of GPUs across cluster users
#SBATCH --partition=XXXXX
#SBATCH --gres=gpu:8 # Adjust number of GPUs here
#SBATCH --output=shared_storage/sourab/temp/logs/%x-%j.out
#SBATCH --error=shared_storage/sourab/temp/logs/%x-%j.err

set -x -e

# CHANGE THE CONDA ENV AND ANY STARTUP SCRIPTS HERE
source ~/sourab/.bashrc
source shared_storage/sourab/miniconda3/etc/profile.d/conda.sh
conda activate hf
cd shared_storage/sourab/DHS-LLM-Workshop/code_assistant/training
git pull

# have the below in case of debugging nccl issues such as nccl timeout.
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
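export NCCL_ASYNC_ERROR_HANDLING=1
# newer PyTorch releases read this name instead (the old one is deprecated)
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1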
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens

echo "START TIME: $(date)"

# CHANGE TO CUMULATIVELY LOG OUTPUTS
LOG_PATH="main_log.txt"

GPUS_PER_NODE=8
NNODES=$SLURM_NNODES
NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

# OTHER LAUNCHERS CAN BE USED HERE
export LAUNCHER="accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    "
# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable

export PROGRAM="\
train.py \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --eval_steps 100 \
    --save_steps 250 \
    --bf16 True \
    --packing True \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing True \
    --learning_rate 5e-5  \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --use_flash_attn True
"


export CMD="$LAUNCHER $PROGRAM"

srun --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $LOG_PATH

echo "END TIME: $(date)"
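The launcher passes `--config_file configs/fsdp_config.yaml`, whose contents the post does not show. A minimal sketch of such a file follows, assuming a recent Accelerate release that accepts the string forms of the FSDP options below. The values are illustrative choices, not the author's config. Values given explicitly on the `accelerate launch` command line (`--num_machines`, `--num_processes`, `--machine_rank`, `--main_process_ip`) take precedence over the ones in the file.

```shell
#!/usr/bin/env bash
# Sketch: write an Accelerate FSDP config for 2 nodes x 8 GPUs.
# Keys follow the format produced by `accelerate config`; the values
# are illustrative assumptions, not the original author's settings.
set -euo pipefail
mkdir -p configs

cat > configs/fsdp_config.yaml <<'EOF'
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 2          # overridden by --num_machines
num_processes: 16        # total GPUs across all nodes; overridden by --num_processes
machine_rank: 0          # overridden by --machine_rank on each node
rdzv_backend: static
same_network: true
use_cpu: false
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer  # one FSDP unit per decoder layer
  fsdp_sharding_strategy: FULL_SHARD             # shard params, grads, optimizer state
  fsdp_state_dict_type: SHARDED_STATE_DICT       # each rank saves its own shard
  fsdp_offload_params: false
  fsdp_sync_module_states: true                  # broadcast rank-0 weights at startup
  fsdp_cpu_ram_efficient_loading: true           # only rank 0 loads full weights into CPU RAM
  fsdp_use_orig_params: true
EOF

echo "wrote configs/fsdp_config.yaml"
```

`fsdp_cpu_ram_efficient_loading` relies on `fsdp_sync_module_states: true`, which is why both are enabled; for a 70B model this avoids every process holding a full copy of the weights in CPU RAM.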
 

Example accelerate training command (2 nodes × 8 GPUs = 16 processes)

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --save_steps 250 \
    --eval_steps 100 \
    --bf16 True \
    --gradient_checkpointing True \
    --use_flash_attn True \
    --output_dir "/shared_storage/.../full-finetune-llama-chat-asst"
 









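Without SLURM, the accelerate command above has to be started by hand on every node, with `MACHINE_RANK` set to 0 on the node whose address is `MASTER_ADDR` and to 1 on the other. A minimal dry-run sketch, assuming a hypothetical master IP `10.0.0.1` and port 6000 (the port used in the SLURM script); it only prints the launcher part of the command for each rank, with the training arguments omitted, so it runs anywhere.

```shell
#!/usr/bin/env bash
# Dry-run sketch: build the per-node accelerate launcher for a 2-node job.
# MASTER_ADDR=10.0.0.1 is a hypothetical IP; use node 0's real address.
set -euo pipefail

MASTER_ADDR=${MASTER_ADDR:-10.0.0.1}
MASTER_PORT=${MASTER_PORT:-6000}
NUM_MACHINES=2
GPUS_PER_NODE=8

# Print the launcher for one node; only --machine_rank differs per node.
launcher_for_rank() {
    local rank=$1
    echo "accelerate launch --config_file configs/fsdp_config.yaml" \
         "--main_process_ip $MASTER_ADDR --main_process_port $MASTER_PORT" \
         "--machine_rank $rank" \
         "--num_processes $(( NUM_MACHINES * GPUS_PER_NODE ))" \
         "--num_machines $NUM_MACHINES train.py"
}

# Node 0 (the main process host) and node 1 run the same command.
for rank in $(seq 0 $(( NUM_MACHINES - 1 ))); do
    echo "node $rank: $(launcher_for_rank "$rank")"
done
```

All nodes must be able to reach `MASTER_ADDR:MASTER_PORT`, and every node needs the same code, config file, and environment; this is exactly the per-node bookkeeping that the SLURM script automates.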