    [Multi-Node Distributed Training] Multi-Node Training with FSDP + Accelerate

    I need to run multi-node training on H100s, so here is a summary of how to do it.

    Multi-Node Training with FSDP + Accelerate

    Takeaway Message.

    • When running distributed training with FSDP + Accelerate across multiple nodes, SLURM (Simple Linux Utility for Resource Management) is typically used to launch the job
      • What is SLURM?
        • A program for managing jobs on a cluster server -> think of it as a scheduler that allocates resources and automatically launches training across several servers
          • Multi-node training needs GPU, CPU, RAM, and network resources reserved on several nodes at the same time, with the processes started in sync; SLURM automates this
          • In other words, there is no need to SSH into each node separately; submitting one SLURM job is enough
        • SLURM submission command: sbatch (a short usage sketch follows this list)
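
    As a quick sketch (the script file name train_llama.sbatch and <jobid> are hypothetical), submitting and monitoring a job from the login node looks like this:

    # submit the batch script to SLURM; it prints a job id
    sbatch train_llama.sbatch

    # check the queue: job state, allocated nodes, elapsed time
    squeue -u $USER

    # inspect one job in detail (replace <jobid> with the id printed by sbatch)
    scontrol show job <jobid>

    # cancel the job if something goes wrong
    scancel <jobid>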

    Example environment

    Example SLURM script

    • Explanation
      • Lines starting with #SBATCH pass options to SLURM when the script is submitted with sbatch
      • Key options (a minimal header combining them is sketched below, followed by the full script)
        • --nodes : how many nodes to use
          • --nodelist : pin the job to specific nodes
        • --gres : how many GPUs to request per node
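
    For example, the options above could be combined into a header like this (the node names gpu-node01 and gpu-node02 are hypothetical); the full script actually used is shown next:

    #!/bin/bash
    #SBATCH --nodes=2                          # use 2 nodes
    #SBATCH --nodelist=gpu-node01,gpu-node02   # pin the job to specific nodes
    #SBATCH --gres=gpu:8                       # request 8 GPUs on each node
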
    #!/bin/bash
    #SBATCH --job-name=ift_llama
    #SBATCH --nodes=8
    #SBATCH --ntasks-per-node=1          # crucial - only 1 task per node; accelerate spawns the per-GPU processes
    #SBATCH --cpus-per-task=96
    #SBATCH --mem-per-cpu=11G # Important to enable "mix" use of GPUs across cluster users
    #SBATCH --partition=XXXXX
    #SBATCH --gres=gpu:8 # Adjust number of GPUs here
    #SBATCH --output=shared_storage/sourab/temp/logs/%x-%j.out
    #SBATCH --error=shared_storage/sourab/temp/logs/%x-%j.err
    
    set -x -e
    
    # CHANGE THE CONDA ENV AND ANY STARTUP SCRIPTS HERE
    source ~/sourab/.bashrc
    source shared_storage/sourab/miniconda3/etc/profile.d/conda.sh
    conda activate hf
    cd shared_storage/sourab/DHS-LLM-Workshop/code_assistant/training
    git pull
    
    # enable the exports below when debugging NCCL issues such as timeouts.
    # export NCCL_DEBUG=INFO
    # export NCCL_DEBUG_SUBSYS=ALL
    # export TORCH_DISTRIBUTED_DEBUG=INFO
    # hide duplicated errors using this hack - will be properly fixed in pt-1.12
    # export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
    
    # force crashing on nccl issues like hanging broadcast
    export NCCL_ASYNC_ERROR_HANDLING=1
    # export NCCL_DEBUG=INFO
    # export NCCL_DEBUG_SUBSYS=COLL
    # export NCCL_SOCKET_NTHREADS=1
    # export NCCL_NSOCKS_PERTHREAD=1
    # export CUDA_LAUNCH_BLOCKING=1
    
    # AWS specific
    export NCCL_PROTO=simple
    export RDMAV_FORK_SAFE=1
    export FI_EFA_FORK_SAFE=1
    export FI_EFA_USE_DEVICE_RDMA=1
    export FI_PROVIDER=efa
    export FI_LOG_LEVEL=1
    export NCCL_IB_DISABLE=1
    export NCCL_SOCKET_IFNAME=ens
    
    echo "START TIME: $(date)"
    
    # CHANGE TO CUMULATIVELY LOG OUTPUTS
    LOG_PATH="main_log.txt"
    
    GPUS_PER_NODE=8
    NNODES=$SLURM_NNODES
    NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
    
    # so processes know who to talk to
    MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
    MASTER_PORT=6000
    
    # OTHER LAUNCHERS CAN BE USED HERE
    export LAUNCHER="accelerate launch \
        --config_file configs/fsdp_config.yaml \
        --main_process_ip $MASTER_ADDR \
        --main_process_port $MASTER_PORT \
        --machine_rank \$SLURM_PROCID \
        --num_processes $NUM_PROCESSES \
        --num_machines $NNODES \
        "
    # Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
    
    export PROGRAM="\
    train.py \
        --model_name "meta-llama/Llama-2-70b-chat-hf" \
        --dataset_name "smangrul/code-chat-assistant-v1" \
        --max_seq_len 2048 \
        --max_steps 500 \
        --logging_steps 25 \
        --eval_steps 100 \
        --save_steps 250 \
        --bf16 True \
        --packing True \
        --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 1 \
        --dataset_text_field "content" \
        --use_gradient_checkpointing True \
        --learning_rate 5e-5  \
        --lr_scheduler_type "cosine" \
        --weight_decay 0.01 \
        --warmup_ratio 0.03 \
        --use_flash_attn True
    "
    
    
    export CMD="$LAUNCHER $PROGRAM"
    
    srun --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $LOG_PATH
    
    echo "END TIME: $(date)"
     

    Example accelerate launch command

    accelerate launch \
        --config_file configs/fsdp_config.yaml \
        --main_process_ip $MASTER_ADDR \
        --main_process_port $MASTER_PORT \
        --machine_rank $MACHINE_RANK \
        --num_processes 16 \
        --num_machines 2 \
        train.py \
        --model_name "meta-llama/Llama-2-70b-chat-hf" \
        --dataset_name "smangrul/code-chat-assistant-v1" \
        --max_seq_len 2048 \
        --max_steps 500 \
        --save_steps 250 \
        --eval_steps 100 \
        --bf16 True \
        --gradient_checkpointing True \
        --use_flash_attn True \
        --output_dir "/shared_storage/.../full-finetune-llama-chat-asst"