ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [2023 NLP GLUE Seminar - Week02 Pruning] From Dense to Sparse: Contrastive Pruning for Better Pre-trained Language Model Compression 리뷰 (AAAI, 2022)
    AI/NLP 2023. 1. 9. 15:28
    728x90
    [2023 NLP GLUE Seminar - Week02 Pruning] 
    From Dense to Sparse: Contrastive Pruning for Better Pre-trained Language Model Compression 리뷰 (AAAI, 2022)

     

     

     

     

    선수 지식 : pruning 기법, contrastive learning 

     

     

     

     

    이 글은 AI 분야에서의 pruning을 다루는 글입니다

     

     

    목차 

     

    • 들어가기 전에 ... : Pruning이란 ? 
    • Introduction
    • Methodology
    • Experiments
    • Discussions
    • Conclusion

     

     

     

    들어가기 전에 .... : Pruning이란 ? 

     

     

     

    • Pruning의 정의
      • 가지치기로 해석할 수 있는 Pruning은 Model의 weight들 중 중요도가 낮은 weight의 연결을 제거하여 모델의 파라미터를 줄이는 방법
      • Pruning은 네트워크의 성능이 크게 저하되지 않는 선에서 weight들을 최대한 sparse(희소하게 : 대부분의 값이 0이도록) 하게 만드는 방법 (bias는 상대적으로 파라미터에서 차지하는 비율이 작기 때문에 pruning하지 않음)

     

    • Pruning의 방법
      1. 일반적인 네트워크 학습 진행
      2. weight값이 작은 연결을 제거함 
      3. 남아 있는 연결을 유지한 상태로 가중치를 재학습(retraining) / Fine tuning 함 
        • pruning한 네트워크를 재 학습 하는 경우 성능이 더 상승한다고 알려져 있음

    Pruning은 한번에 수행되지 않고, 가지치기 이후 retraining(fine tuning이나 from scratch)과정을 몇번씩 거친다. 한번에 수행되면 weight들이 다 잘려나가 성능이 급격히 떨어진다. 그보다는, 여러번 반복하여 성능을 복원했다가 야금야금 pruning하는 방법이 주로 사용된다.

     

     

     

    논문 요약) Training과 Pruning의 시빌 워

     

     

     

    • Pruning의 결과 
      • 모델의 weight값들은 대다수가 0인 Sparse matrix구조가 됨

     

     

    • Pruning의 이점 
      • 모델의 정확도가 손상되지 않는 범위에서 메모리, 배터리, 하드웨어 소비량을 줄이고
      • Regularization이 일반화 성능을 높인다.
      • 추론 속도가 빨라진다
      • 기기에 경량화된 모델을 배치해서 개인이 이용하고 있는 기기에서 프라이버시가 보장될 수 있음 (ref. pytorch document)
    • Pruning의 단점
      • 정보의 손실 생기고 
      • 입자도(granularity, 세밀함)가 하드웨어 가속 디자인의 효율성에 영향을 미치기에
        • 또, 너무 sparse하게 만들어버리면 하드웨어 가속 효율이 떨어짐 

     

     

     👾🧐 그럼 dropout과는 뭐가 다른거지 ? 🧐👾 


    Pruning한번 잘라낸 뉴런을 보관하지 않는다.

    그러나 Dropout은 regularization이 목적이므로 학습 시에 뉴런들을 랜덤으로 껐다가 (보관해두고) 다시 켜는 과정을 반복한다. 추론 시에는 모든 뉴런을 켜고 수행한다.

     

     

     

     

    • Pruning할 때 설정해야 할 요소 세 가지 
      1. Pruning granularity(입자도, 세밀함)
        • Pruning할 요소를 정하는것이 필요하다.
        • Pruning할 때 각 요소를 pruning하는 것을 element-wise pruning 또는 fine-grained pruning이라고 함
        • 어떠한 요소들을 그룹으로 묶고 그 그룹 전체에 대해 pruning 하는것은 Coarse-grained pruning, structed pruning, group pruning 또는 block pruning 이라 함
      2. Pruning criteria
        • weight들 중 어떤 요소를 어떻게 prune할지 정하기 위한 기준
      3. Pruning schedule
        • 이 때, one-shot pruning을 할지 iterative pruning을 할지, iterative pruning이라면 얼마나 많이 반복할지, 매 iteration 마다 Prunning criteria을 어떻게 설정할지, 어떤 weight를 prune할지, 언제 멈출지와 같은 정보를 총칭하여 Pruning schedule이라고 함. pruning을 멈추는 타이밍 또한 schedule로 표현할 수 있습니다.
    • Pruning 기법
      • 무엇을 잘라낼 것인가?(What to prune?)
        • unstructured : 무작위로 각각의 weight들을 잘라내기
        • structured : 단위를 잡아서 한번에 잘라내기
      • 어떻게 잘라낼 것인가?(How to prune?)
      • 언제 잘라낼 것인가?(When to prune?)
      • 얼마나 자주 잘라낼 것인가?(How often?)
        • One shot pruning  :학습 완료 후 한번 pruning하는 것
        • iterative pruning : 학습과 pruning과정을 거친 후 sparse한 네트워크를 다시 학습하는 것 
          • pruning한 네트워크를 재 학습 하는 경우 성능이 더 상승한다고 알려져 있음

     

     

     


     

    Introduction 

     

     

    등장 배경 & 이전 연구들의 한계 

     

    • Pretraining Language Model(이하 PLM)은 NLP의 여러 분야에서 성공적인 성능 향상을 가져왔음
    • 그러나, Parameter들이 엄청 많은, 무지막지하게 큰 초거대 모델들이고, 불필요한 중복된 가중치들이 많아, 이를 압축하기 위한 시도들이 있었다 
      • 이를 해결하기 위해 Pruning 등장 ! 

     

     

    • 하지만 대부분의 방법들은 다음과 같은 한계점이 있다
      • downstream tasks에 대한 "task-specific knowledge"만을 고려하고, 필요한 "task-agnostic knowledge"들을 무시하고 Pruning한다  → 이로 인해 generalization 능력 저하 

     

     





    ❓❓ 위의 말이 무슨 뜻인지 잘 모르겠다면 ❓❓



    먼저 Task specific, Task agnostic이라는 용어부터 짚고 가자 








    NLP모델의 큰 흐름을 정리해보면 Task specific model이 제안되다가 Transformer를 기점으로 Task Agnostic(Task에 상관없는) model이 쏟아지고있다. 

    즉, 대용량 데이터로 학습한 pretrained model이 대부분의 downstream task에서 좋은 성능을 내기 시작했다는 것을 알 수 있다.


    예를 들어 GPT-1의 경우, 핵심은 대용량의 데이터를 학습하여 ‘언어 자체’를 잘 이해할 수 있는 representation을 학습하는 것이다. (== Task Agnostic / Pretraining 단계에서 하는 거) 


    이후 특정 task(= downstream task) 에 맞도록 Fine tuning (== Task Specific) !








    ❓❓ 그럼 " downstream tasks에 대한 "task-specific knowledge"만을 고려하고, 필요한 "task-agnostic knowledge"들을 무시하고 Pruning한다" 는 말의 뜻은 뭐냐고 물으신다면 ❓❓





    Pruning을 할 때, pretrain시에 학습된 weight들 위주로 솎아내는 방식들만 제시되었다는 말인갑다 ! 





     

     

     

    Contributions

     

    • 그래서 task-specific + task-agnostic 모두를 다 잘 유지할 수 있는 방식인 ContrAstive Pruning (CAP) 제안 !
      • 위에서 언급한 unstructured & structured pruning 방식 둘다 지원한다 ! 
      • contrastive learning 방식을 사용하여 효과적으로 representation을 학습함 
      • 또한, 프루닝된 모델의 성능을 더 잘 유지하기 위해 스냅샷(즉, 각 프루닝 반복에서 중간 모델)은 프루닝에 대한 효과적인 감독 역할도 한다

     

     

     




    ❓❓ Snapshot이 뭐지 ❓❓



    Snapshot





    위에서도 언급했듯이 Pruning은 한번에 수행되지 않고, 가지치기 이후 retraining(fine tuning이나 from scratch)과정을 몇번씩 거친다.  여러번 반복하여 성능을 복원했다가 야금야금 pruning하는 방법이 주로 사용된다.


    이 과정을 여러 번 반복하는 과정에서의 중간 정도의 모델들을 Snapshot이라고 부른다 


    ex) pruning 40번 반복하기로 했으면 그 중간, 즉 40번째 이전까지의 모델들 

     

     






    ❓❓ Contrastive Learning이 뭐지 ❓❓ 






    우리가 판단하고자 하는 Query Image인 A,
    그리고 그 이미지 A와 유사한 positive image 

    유사하지 않은 negative image 
    거리를 가지도록 하는 모델! 







    Loss는 아래와 같다 











     

     

     

     

    본 논문에서 쓰는 Contrastive learning의 Loss는 위와 같다 

     

     

     

     

     

    Methodology

     

     

     

     

    모델은 크게 세 가지 부분으로 나눠진다 ! 

    1. PrC: Contrastive Learning with Pre-trained Model
    2. SnC: Contrastive Learning with Snapshots
    3. FiC: Contrastive Learning with Fine-tuned Model

     

     

     

    간단하게 요약하자면, 

     

    • Pre-trained Model에서의 output들과, Snapshots에서의 output들,  Fine-tuned Model에서의 output들 사이에서,
    • 의미적으로 비슷한 애들은 가깝게, 먼 애들은 멀게 위치하도록 임베딩 공간을 조정하는 Contrastive Learning 통해,
    • Pruning을 해도 task-specific Knowledge(fine tuned model)와 task-agnostic Knowledge(pretrained model) 모두를 다 잘 유지할 수 있는 방식인 것이다 

     

     

    (안 간단한 요약)

     

     

    설명을 위해 ppt로 열심히 만든 figure

     

     

    아래의 표는 지도 방식일 때와, 비지도 방식일 때의 postive example들 

     

     

     

    1. PrC: Contrastive Learning with Pre-trained Model

     

     

     

    pruning한 모델의 output와 pretrained 모델의 output 간의 contrastive learning을 통해,

    앞서서 말했던, 기존 논문들의 한계점인 "task agnostic knowledge가 손실된다"는 점을 해결할 수 있음 

     

     

    즉 초거대 LM의 범용 지식을 잘 유지할 수 있다 ~ 

     

     

    최종 loss는 supervised 방식과 unsupervised 방식의 loss가 합쳐진 것 

     

     

     

    2. SnC: Contrastive Learning with Snapshots

     

     

     

     

    마찬가지로 pruning한 모델들의 output 간에 contrastive learning를 적용 ! 

     

     

     

    위에서도 언급했듯이 pruning은 보통 여러 번 수행하는데, 기존의 연구들에서는 그 중간 모델들(a.k.a. snap shot)을 아예 사용하지 않았다

     

     

    하지만 본 논문에서는 이 중간 단계의 모델들을 활용한다 ! 

    아래와 같이 sparsity를 20, 40, 60%로 조정한 모델들 각각으로 contrastive learning 적용 

     

     

     

     

     

    마찬가지로, 최종 loss는 supervised 방식과 unsupervised 방식의 loss가 합쳐진 것 

     

     

     

    3. FiC: Contrastive Learning with Fine-tuned Model

     

     

    pruning한 모델의 output와 fine-tuned 모델의 output 간의 contrastive learning을 통해,

    task specific한 knowledge도 학습 가능 

     

     

     

    여기도 최종 loss는 supervised 방식과 unsupervised 방식의 loss가 합쳐진 것 

     

     

     

     

    위 세 가지 loss는 cross-entropy loss와 함께,

    요렇게 합쳐진답니다 ~

     

     

     

    메모리 오버헤드 문제 없나  ? 

    GPU에 이 모든 모델을 로드할 필요가 없어서 괜찮다

    추가적인 GPU memory 오버헤드는 4096 × 768 = 3.15M 정도라서,  BERTbase의 3.15M / 110M = 2.86% 임
    이정도는 받아줄만 하지 않녜 

     

     

    Experiments

     

     

    아래의 벤치마크 데이터셋들 활용 

     

    • MNLI 
    • QQP 
    • SST-2 
    • SQuAD

     

     

    Main Results

    세 가지의 model compression 방식을 비교합니다 ~

    • Knowledge Distilation
    • Structure Pruning 
    • Unstructure Pruning 

     

     

     

    어떤 pruning 기법을 썼냐에 따라서 세 가지 버전 

    • CAP-f
      • structed pruning / the most standard Firstorder pruning (Molchanov et al. 2017) 기반
    • CAP-m
      • unstructed pruning / the state-of-the-art Movement pruning인 Soft-movement pruning (Sanh, Wolf, and Rush 2020) 기반 : Top-K selection strategy
    • CAP-soft
      • unstructed pruning / Soft-movement pruning (Sanh, Wolf, and Rush 2020, 위와 동일 논문) 기반 : pre-defined threshold 

     

    결과 요약

    1. CAP는 상당한 성능을 유지하면서도 상당 부분의 BERT 파라미터(무려 97%나 삭제)들을 제거한다 

    2. CAP는 Sparsity가 커질 수록 다른 Pruning 방법의 성능을 능가한다. 

    3. CAP는 knowledge distillation 방법의 성능도 능가한다 

     

     

     

    Generalization Ability

    다른 Pruning 기법에 비해 이만큼 성능 올랐다 ~ 

    즉, 우리 방식은 Task agnostic한 knowledge를 잘 보존해서 성능 오른것임 ~

     

     

     

    Discussions

     

     

    Understanding Different Contrastive Modules

     

    모델을 구성하는 세 개의 파츠들 각각 하나씩 빼봤을 때 아래와 같이 성능이 감소했다 

    -> 즉 모두 필요한 구성품들이다 ~

     

     

     

    Understanding Supervised and Unsupervised Contrastive Objectives

     

    the same example encoded by different models are considered as positive examples for unsupervised contrastive learning (unsup).

    If the sentence level label annotations are available, we can also conduct supervised contrastive learning (sup) by considering examples with the same labels as positive examples.

     

     

     

    CAP에는 supervised objectives  과 unsupervised objectives이 모두 필수적

     

     

     

    Performance under Various Sparsity Ratios

     

    sparisity 비율에 따른 성능 변화 

     

    Exploration of Pooling Methods and Temperatures

     

    pooling 방식에 따른 변화 

     

     

    Exploration of Learning From Fine-tuned Model

     

    이 논문에서는 task-specific knowledge를 얻어오기 위해서 fine-tuned model (FiC) 부분을 사용했는데,

    task-specific knowledge를 얻어오기 위한 방법으로는 knowledge distillation (KD)이 있단다

     

    따라서 이 실험은 knowledge distillation (KD)을 했을 때와 하지 않았을 때 성능을 비교한다 

     

     

    Conclusion

     

    "가지쳤더니 더 좋아졌다"는 비유가 여러모로 찰떡이다

     

     

     

    • 논문의 장점
      • 사실 처음 보는 분야라, 설명이 잘 안 되어 있을까봐 논문 읽기 전에 따로 공부했는데,
        그럴 필요 없이 이 논문만 읽어도 pruning이나 contrastive learning에 대해서 대충 이해할 수 있을 정도로 써놓은 듯 🙌
      • 분석... 많아서 좋은데 .... 읽는 건 힘들었음 ㅋㅋ쿠ㅜ
    • 논문의 단점

     

     

     

     

    Code

     

    https://github.com/RunxinXu/ContrastivePruning

     

    GitHub - RunxinXu/ContrastivePruning: Source code for our AAAI'22 paper 《From Dense to Sparse: Contrastive Pruning for Better

    Source code for our AAAI'22 paper 《From Dense to Sparse: Contrastive Pruning for Better Pre-trained Language Model Compression》 - GitHub - RunxinXu/ContrastivePruning: Source code for our AAAI&...

    github.com

     

     

     

     

    서두에서 말했듯이 Pruning 기법은 크게 아래와 같이 나눠지고, 

    • 무엇을 잘라낼 것인가?(What to prune?)
      • unstructured : 무작위로 각각의 weight들을 잘라내기
      • structured : 단위를 잡아서 한번에 잘라내기

     

     

    본 논문에서는 structured pruning과 unstructured pruning 방식을 모두 제공 중이며,

    어떤 pruning 기법을 썼냐에 따라서 세 가지 버전으로,,, 

    • CAP-f
      • structured pruning / the most standard Firstorder pruning (Molchanov et al. 2017) 기반
    • CAP-m
      • unstructured pruning / the state-of-the-art Movement pruning인 Soft-movement pruning (Sanh, Wolf, and Rush 2020) 기반 : Top-K selection strategy
    • CAP-soft
      • unstructured pruning / Soft-movement pruning (Sanh, Wolf, and Rush 2020, 위와 동일 논문) 기반 : pre-defined threshold 

     

     

     

    우선 structured pruning 코드를 통해 전체적인 이해를 해보도록 한다 

    (다음에는  structured pruning 과  unstructured pruning 이 코드상에서 어떻게 다른지 보자 -> 생각보다 코드가 엄청 많아서 다음에 ..)

     

     

    설명을 위해 ppt로 열심히 만든 figure

     

    먼저, 우리에게는 pretrained model (task agnostic)과, 

    fine-tuned model (task specific)이 있다는 전제 하에 진행된다 (여기서는 huggingface에서 제공되는 bert를 사용했음) 

     

    - custom model을 돌릴려면 pretrained + fine-tuned model까지 필수 지참 ~

     

     

     

     

    아래는 main 코드 중에서도 training 하는 부분이다

     

     

     

    크게 두 가지 부분으로 구성되는데,

    첫번째는 finetuned model에서 임베딩 값들을 뽑아오는 부분이다 

     

     

     

    두번째는 본격적으로 매 step마다 1) pruning + 2) retrain을 반복한다 

    구체적인 순서는 다음과 같다

     

    1. pruning할 단위인 sequence를 정하고,

    2. 중요도를 어떻게 계산하고,

    3. 어떤 임계치에서 자를 것인지에 대해서 결정한 뒤 

    4. 모델의 head mask와 intermediate mask를 업데이트 시켜준 후 

    5. retrain한다  ( trainer.train() )

     

     

     

    step 0은 원래의 pretrained model에서 시작하여, step 1~N까지는 pruning 중인 snap shot 모델들이다.

     

     

     

     

    if training_args.do_prune:
        model = trainer._wrap_model(trainer.model, training=False)
        model = model.module if hasattr(model, 'module') else model
    
        # Determine the number of heads to prune
        prune_percent = training_args.prune_percent
        prune_percent = None if prune_percent == '' else [float(x) for x in prune_percent.split(',')]
    
        prune_sequence_head, prune_sequence_intermediate = determine_pruning_sequence(
            prune_percent,
            config.num_hidden_layers,
            config.num_attention_heads,
            config.intermediate_size,
            training_args.at_least_x_heads_per_layer,
        )
        prune_sequence = zip(prune_sequence_head, prune_sequence_intermediate)
    
        for step, (n_to_prune_head, n_to_prune_intermediate) in enumerate(prune_sequence):
            logger.info("We are going to prune {} heads and {} intermediate !!!".format(n_to_prune_head, n_to_prune_intermediate))
            head_importance, intermediate_importance = calculate_head_and_intermediate_importance(
                model,
                train_dataset,
                old_head_mask=model.head_mask,
                old_intermediate_mask=model.intermediate_mask,
                trainer=trainer,
                normalize_scores_by_layer=training_args.normalize_pruning_by_layer,
                subset_size=training_args.subset_ratio
            ) 
            for layer in range(len(head_importance)):
                layer_scores = head_importance[layer].cpu().data
                logger.info("head importance score")
                logger.info("\t".join(f"{x:.5f}" for x in layer_scores))
            # Determine which heads to prune
            new_head_mask = what_to_prune_head(
                head_importance,
                n_to_prune=n_to_prune_head,
                old_head_mask=model.head_mask,
                at_least_x_heads_per_layer=training_args.at_least_x_heads_per_layer,
            )
            new_intermediate_mask = what_to_prune_mlp(
                intermediate_importance,
                n_to_prune=n_to_prune_intermediate,
                old_intermediate_mask=model.intermediate_mask
            )
            for layer in range(len(new_head_mask)):
                y = new_head_mask[layer].cpu().data
                logger.info("head mask")
                logger.info("\t".join("{}".format(int(x)) for x in y))
            logger.info("intermediate mask")
            for layer in range(len(new_intermediate_mask)):
                y = new_intermediate_mask[layer]
                logger.info("Layer {} has {} intermediate active.".format(layer, torch.sum(y)))
    
            # calculate and store example representations and labels (for verification)
            if training_args.use_contrastive_loss:
                representations_bank = None
                labels_bank = None
                dataloader = DataLoader(
                    train_dataset,
                    batch_size=trainer.args.train_batch_size,
                    shuffle=False,
                    collate_fn=trainer.data_collator,
                    drop_last=trainer.args.dataloader_drop_last,
                    num_workers=trainer.args.dataloader_num_workers,
                    pin_memory=trainer.args.dataloader_pin_memory,
                )
                with torch.no_grad():
                    for inputs in tqdm(dataloader):
                        inputs = trainer._prepare_inputs(inputs)
                        labels = inputs['labels'].cpu()
                        representations = model(encode_example=True, **inputs).cpu()
                        if representations_bank is None:
                            representations_bank = representations
                            labels_bank = labels
                        else:
                            representations_bank = torch.cat((representations_bank, representations), dim=0)
                            labels_bank = torch.cat((labels_bank, labels), dim=0)
    
                if step == 0:
                    # add to global representations bank for pretrained
                    global_representations_bank_pretrained = representations_bank
                    global_representations_bank_pretrained = global_representations_bank_pretrained.unsqueeze(1)
                else:
                    # add to global representations bank for snaps
                    if global_representations_bank_snaps is None:
                        global_representations_bank_snaps = representations_bank.unsqueeze(1)
                    else:
                        global_representations_bank_snaps = torch.cat((global_representations_bank_snaps, representations_bank.unsqueeze(1)), dim=1)
    
                # update bank
                model.global_representations_bank_finetuned = global_representations_bank_finetuned
                model.global_representations_bank_pretrained = global_representations_bank_pretrained
                model.global_representations_bank_snaps = global_representations_bank_snaps
                model.global_labels_bank = labels_bank
    
            # apply structured pruing
            model.head_mask[:] = new_head_mask.clone()
            model.intermediate_mask[:] = new_intermediate_mask.clone()
    
            # re-train
            trainer.optimizer = trainer.lr_scheduler = None
            trainer.args.num_train_epochs = training_args.retrain_num_train_epochs
            trainer.train()
    
            # re-eval
            tasks = [data_args.task_name]
            eval_datasets = [eval_dataset]
            if data_args.task_name == "mnli":
                tasks.append("mnli-mm")
                eval_datasets.append(datasets["validation_mismatched"])
    
            for eval_d, task in zip(eval_datasets, tasks):
                metrics = trainer.evaluate(eval_dataset=eval_d)
                max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_d)
                metrics["eval_samples"] = min(max_val_samples, len(eval_d))
                new_metrics = {}
                for k in metrics:
                    new_metrics["{}_{}_{}".format(task, k, step+1)] = metrics[k]
                metrics = new_metrics
                trainer.log_metrics("{}_eval_{}".format(task, step+1), metrics)
                trainer.save_metrics("{}_eval_{}".format(task, step+1), metrics)

     

     

     

    위 코드에서

    1. pruning할 단위인 sequence를 정하고,

    2. 중요도를 어떻게 계산하고,

    3. 어떤 임계치에서 자를 것인지에 대한 함수들은 아래와 같다 

     

     

    def determine_pruning_sequence(
        prune_percents,
        n_heads,
        n_layers,
        n_intermediate,
        at_least_x_heads_per_layer=1,
    ):
        '''
        Same ratio for attention heads and MLPs
        '''
    
        # Compute the number of heads to prune on percentage if needed
        all_n_to_prune = []
        for prune_percent in prune_percents:
            total_heads = n_heads * n_layers
            n_to_prune = int(total_heads * prune_percent / 100)
            # Make sure we keep at least one head per layer
            if at_least_x_heads_per_layer > 0:
                if n_to_prune > total_heads - at_least_x_heads_per_layer * n_layers:
                    assert False
            all_n_to_prune.append(n_to_prune)
    
        # We'll incrementally prune layers and evaluate
        all_n_to_prune = sorted(all_n_to_prune)
        n_to_prune_sequence_head = all_n_to_prune[:]
        for idx in range(1, len(all_n_to_prune)):
            n_to_prune_sequence_head[idx] = all_n_to_prune[idx] - all_n_to_prune[idx-1]
        # Verify that the total number of heads pruned stayed the same
        assert all_n_to_prune[-1] == sum(n_to_prune_sequence_head)
    
        # MLP
        all_n_to_prune = []
        for prune_percent in prune_percents:
            total_intermediate = n_layers * n_intermediate
            n_to_prune = int(total_intermediate * prune_percent / 100)
            all_n_to_prune.append(n_to_prune)
        n_to_prune_sequence_intermediate  = [0 for _ in range(len(all_n_to_prune))]
        n_to_prune_sequence_intermediate[0] = all_n_to_prune[0]
        for idx in range(1, len(all_n_to_prune)):
            n_to_prune_sequence_intermediate[idx] = all_n_to_prune[idx] - all_n_to_prune[idx-1]
        assert len(n_to_prune_sequence_head) == len(n_to_prune_sequence_intermediate)
        return n_to_prune_sequence_head, n_to_prune_sequence_intermediate

     

    pruning percent에 따라 몇 개의 sequence를 날릴지 결정 

    (위에서 언급했던 것처럼, 20% -> 40% -> 60% ...)

     

    def calculate_head_and_intermediate_importance(
        model, 
        dataset,
        old_head_mask,
        old_intermediate_mask,
        trainer,
        normalize_scores_by_layer=True,
        disable_progress_bar=False,
        subset_size=1.0,
    
    ):
        training_flag = model.training
        model = model.module if hasattr(model, 'module') else model
        model.eval() 
    
        n_layers, n_heads, n_intermediate = model.config.num_hidden_layers, model.config.num_attention_heads, model.config.intermediate_size
        head_importance = torch.zeros(n_layers, n_heads).to(old_head_mask)
        head_mask = torch.ones(n_layers, n_heads).to(old_head_mask)[:] = old_head_mask.clone()
        head_mask.requires_grad_(requires_grad=True)
        intermediate_importance = torch.zeros(n_layers, n_intermediate).to(old_intermediate_mask)
        intermediate_mask = torch.ones(n_layers, n_intermediate).to(old_intermediate_mask)[:] = old_intermediate_mask.clone()
        intermediate_mask.requires_grad_(requires_grad=True)
    
        batch_size = trainer.args.train_batch_size
        if subset_size <= 1:
            subset_size *= len(dataset)
        n_prune_steps = int(np.ceil(int(subset_size) / batch_size))
    
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=trainer.data_collator,
            drop_last=trainer.args.dataloader_drop_last,
            num_workers=trainer.args.dataloader_num_workers,
            pin_memory=trainer.args.dataloader_pin_memory,
        )
        dataloader = islice(dataloader, n_prune_steps)
        prune_iterator = tqdm(
            dataloader,
            desc="Iteration",
            disable=disable_progress_bar,
            total=n_prune_steps
        )
    
        for inputs in prune_iterator:
            # key: add head mask and intermediate mask, so we can get the gradients from them
            inputs['head_mask'] = head_mask 
            inputs['intermediate_mask'] = intermediate_mask
            inputs = trainer._prepare_inputs(inputs)
            loss = trainer.compute_loss(model, inputs)
            loss.backward()
            head_importance += head_mask.grad.abs().detach()
            intermediate_importance += intermediate_mask.grad.abs().detach()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0)
            torch.nn.utils.clip_grad_norm_(head_mask, 0)
            torch.nn.utils.clip_grad_norm_(intermediate_mask, 0)
        
        if normalize_scores_by_layer:
            exponent = 2
            norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent)
            head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
            norm_by_layer = torch.pow(torch.pow(intermediate_importance, exponent).sum(-1), 1/exponent)
            intermediate_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
        
        if training_flag:
            model.train()
    
        return head_importance, intermediate_importance

     

    # key: add head mask and intermediate mask, so we can get the gradients from them

    head mask와 intermediate mask를 추가해서 

     

     

    def what_to_prune_head(
        head_importance,
        n_to_prune,
        old_head_mask,
        at_least_x_heads_per_layer=1,
    ):
        head_importance = head_importance.clone()
        n_layers, n_heads = head_importance.size()
    
        already_prune = {}
        for layer in range(old_head_mask.size(0)):
            for head in range(old_head_mask.size(1)):
                if old_head_mask[layer][head].item() == 0:
                    if layer not in already_prune:
                        already_prune[layer] = []
                    already_prune[layer].append(head)
    
        # Sort heads by score
        heads_and_score = [
            ((layer, head), head_importance[layer][head].item())
            for layer in range(n_layers)
            for head in range(n_heads)
        ]
        heads_and_score = sorted(heads_and_score, key=lambda x: x[1])
        sorted_heads = [head_and_score[0]
                        for head_and_score in heads_and_score]
        # Ensure we don't delete all heads in a layer
        if at_least_x_heads_per_layer:
            # Remove the top scoring head in each layer
            to_protect = {l: 0 for l in range(n_layers)}
            filtered_sorted_heads = []
            for layer, head in reversed(sorted_heads):
                if layer in to_protect:
                    if to_protect[layer] < at_least_x_heads_per_layer:
                        to_protect[layer] += 1
                        continue
                    else:
                        to_protect.pop(layer)
                filtered_sorted_heads.insert(0, (layer, head))
            sorted_heads = filtered_sorted_heads
        # layer/heads that were already pruned
        # Prune the lowest scoring heads
        sorted_heads = [
            (layer, head)
            for (layer, head) in sorted_heads
            if layer not in already_prune or head not in already_prune[layer]
        ]
    
        old_head_mask = old_head_mask.clone()
        new_head_mask = old_head_mask.clone()
        # Update heads to prune
        for layer, head in sorted_heads[:n_to_prune]:
            new_head_mask[layer][head] = 0
        return new_head_mask
        
     def what_to_prune_mlp(
        intermediate_importance,
        n_to_prune,
        old_intermediate_mask
    ):
        intermediate_importance = intermediate_importance.clone()
        n_layers, n_intermediate = intermediate_importance.size()
    
        already_prune = defaultdict(list)
        for layer in range(n_layers):
            for intermediate_idx in range(n_intermediate):
                if old_intermediate_mask[layer][intermediate_idx].item() == 0:
                    already_prune[layer].append(intermediate_idx)
    
        score = [
            ((layer, intermediate_idx), intermediate_importance[layer][intermediate_idx].item()) 
            for layer in range(n_layers) for intermediate_idx in range(n_intermediate)
        ]
        score.sort(key=lambda x:x[-1])
        filter_score = [
            ((layer, intermediate_idx), score)
            for ((layer, intermediate_idx), score) in score
            if layer not in already_prune or intermediate_idx not in already_prune[layer]
        ]
    
        old_intermediate_mask = old_intermediate_mask.clone()
        new_intermediate_mask = old_intermediate_mask.clone()
        for (layer, intermediate_idx), _ in filter_score[:n_to_prune]:
            new_intermediate_mask[layer][intermediate_idx] = 0
        return new_intermediate_mask

     

     

    4. 모델의 head mask와 intermediate mask를 업데이트 시켜준 후

    5. retrain한다  ( trainer.train() )

     

     

    contrastive loss에 대한 계산은 PruneBert.py에서 확인할 수 있었다 

     

    근데 너무 길어서 loss 계산하는 것만 보자 

     

     

        def forward(
            self,
            idx=None,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            intermediate_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            encode_example=False,
        ):
    
    	##  생략 ## 
        
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                intermediate_mask=intermediate_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
    
            all_output, pooled_output = outputs[0], outputs[1]
    
            # just for encode example
            if encode_example:
                if self.alignrep == 'mean-pooling':
                    result = torch.sum(all_output * (attention_mask.unsqueeze(-1) == 1), dim=1) / torch.sum(attention_mask, dim=1, keepdim=True)
                else:
                    result = pooled_output
                result = torch.nn.functional.normalize(result, p=2, dim=-1)
                return result # bsz * hidden_state
    
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
    
            loss = None
            if labels is not None:
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    
            # cross entropy loss
            loss *= self.ce_loss_weight
    
    		## 생략 ##
    
            loss += self.calculate_contrastive_loss(idx, self.global_representations_bank_pretrained, self.global_labels_bank, all_output, pooled_output, attention_mask, labels)
            loss += self.calculate_contrastive_loss(idx, self.global_representations_bank_finetuned, self.global_labels_bank, all_output, pooled_output, attention_mask, labels)
            loss += self.calculate_contrastive_loss(idx, self.global_representations_bank_snaps, self.global_labels_bank, all_output, pooled_output, attention_mask, labels)
    
            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output
    
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

     

     

    self.calculate_contrastive_loss 함수는 그럼? 

     

    def calculate_contrastive_loss(
            self, 
            idx, 
            global_representations_bank, 
            global_labels_bank,
            all_output, 
            pooled_output, 
            attention_mask,
            labels
        ):
            loss = 0
            if idx is not None and global_representations_bank is not None:
                # representations: bsz * rep_num * hidden_state
                # labels: bsz
                representations = torch.index_select(global_representations_bank, dim=0, index=idx.cpu()).to(pooled_output)
                bsz, rep_num, hid_size = representations.size()
    
                if self.alignrep == 'mean-pooling':
                    pooled_output = torch.sum(all_output * (attention_mask.unsqueeze(-1) == 1), dim=1) / torch.sum(attention_mask, dim=1, keepdim=True)
                else:
                    pooled_output = pooled_output
                pooled_output = torch.nn.functional.normalize(pooled_output, p=2, dim=-1) # bsz * hidden_state
                
                # also add current data
                representations = torch.cat((pooled_output.unsqueeze(1).detach(), representations), dim=1) # bsz * (rep_num+1) * hidden_state
                representations = representations.reshape(bsz*(rep_num+1), hid_size)
    
                # sample more examples
                extra = self.extra_examples // global_representations_bank.size(1) 
                extra_idx = torch.LongTensor(random.sample(range(global_representations_bank.size(0)), k=extra))
                extra_labels = torch.index_select(global_labels_bank, dim=0, index=extra_idx).to(pooled_output) # extra
                extra_representations = torch.index_select(global_representations_bank, dim=0, index=extra_idx).view(-1, hid_size).to(pooled_output) # (extra * rep_num) * hidden_state
                extra_idx = extra_idx.to(pooled_output)
    
                representations = torch.cat((representations, extra_representations), dim=0) 
                contrastive_score = torch.mm(pooled_output, representations.t()) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
    
                # exclude choosing myself
                # contrastive_mask: choosing myself -> 1,e.g., contrastive_mask[0,0] = contrastive_mask[1, rep_num] = contrastive_mask[2, 2*rep_num] = 1
                contrastive_mask = torch.unbind(torch.eye(bsz).to(contrastive_score), dim=1)
                contrastive_mask = [torch.cat((m.unsqueeze(1), torch.zeros(bsz, rep_num).to(contrastive_score)), dim=1) for m in contrastive_mask]
                contrastive_mask = torch.cat(contrastive_mask, dim=1) # bsz * (bsz*(rep_num+1))
                contrastive_mask = torch.cat((contrastive_mask, torch.zeros(bsz, extra * rep_num).to(contrastive_mask)), dim=1) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
                contrastive_score /= self.contrastive_temperature
                contrastive_score = contrastive_score.masked_fill(contrastive_mask==1, -1e6)
                contrastive_score = torch.nn.functional.log_softmax(contrastive_score, dim=-1)
    
                # calculate unsupervised_mask, only maintain the positive positives (belonging to the same instance) log_softmax
                all_idx = torch.cat((idx.unsqueeze(1).repeat(1, rep_num+1).view(-1), extra_idx.unsqueeze(1).repeat(1, rep_num).view(-1)), dim=0) # (bsz*(rep_num+1)+(extra * rep_num))
                all_idx = all_idx.unsqueeze(0).expand(bsz, -1) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
                positive_mask = (idx.unsqueeze(1) == all_idx) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
                mask_contrastive_score = contrastive_score.masked_fill( (positive_mask==0) | (contrastive_mask==1), 0)
                positive_num = torch.sum(positive_mask, dim=1, keepdim=True) - 1
                loss += - self.cl_unsupervised_loss_weight * torch.sum(mask_contrastive_score / positive_num) / torch.sum(mask_contrastive_score!=0)
                
                # supervised_mask
                all_labels = torch.cat((labels.unsqueeze(1).repeat(1, rep_num+1).view(-1), extra_labels.unsqueeze(1).repeat(1, rep_num).view(-1)), dim=0) # (bsz*(rep_num+1)+(extra * rep_num))
                all_labels = all_labels.unsqueeze(0).expand(bsz, -1) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
                positive_mask = (labels.unsqueeze(1) == all_labels) # bsz * (bsz*(rep_num+1)+(extra * rep_num))
                mask_contrastive_score = contrastive_score.masked_fill( (positive_mask==0) | (contrastive_mask==1), 0)
                positive_num = torch.sum(positive_mask, dim=1, keepdim=True) - 1
                loss += - self.cl_supervised_loss_weight * torch.sum(mask_contrastive_score / positive_num) / torch.sum(mask_contrastive_score!=0)
            
            return loss

     

     

     

    • 실제 사용할 수 있는 모듈

     

    1. pytorch의 prune 모듈  

     

    import torch
    from torch import nn
    import torch.nn.utils.prune as prune
    import torch.nn.functional as F
    
    
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
            # 1개 채널 수의 이미지를 입력값으로 이용하여 6개 채널 수의 출력값을 계산하는 방식
            # Convolution 연산을 진행하는 커널(필터)의 크기는 3x3 을 이용
            self.conv1 = nn.Conv2d(1, 6, 3)
            self.conv2 = nn.Conv2d(6, 16, 3)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)  # Convolution 연산 결과 5x5 크기의 16 채널 수의 이미지
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, int(x.nelement() / x.shape[0]))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    model = LeNet().to(device=device)
    
    
    module = model.conv1
    print(list(module.named_parameters()))
    
    prune.random_unstructured(module, name="weight", amount=0.3)

     

    • 모듈 내 같은 파라미터값에 대해 가지치기 기법이 여러번 적용될 수 있으며, 다양한 가지치기 기법의 조합이 적용된 것과 동일하게 적용될 수 있다. 새로운 마스크와 이전의 마스크의 결합은 PruningContainer 의 compute_mask 메소드를 통해 처리할 수 있다.

     

    2. Neural Network Distiller

    •  network compression을 위한 open-source, 실무에 들어가면 쓰게 되지 않을까 ? 

     

    https://intellabs.github.io/distiller/pruning.html#han-et-al-2015

     

    Pruning - Neural Network Distiller

    A common methodology for inducing sparsity in weights and activations is called pruning. Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned element

    intellabs.github.io

     

     

     

     


     

    Reference

     

    # about pruning

     

    https://arxiv.org/abs/1510.00149

     

    Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding

    Neural networks are both computationally intensive and memory intensive, making them difficult to deploy on embedded systems with limited hardware resources. To address this limitation, we introduce "deep compression", a three stage pipeline: pruning, trai

    arxiv.org

     

    https://velog.io/@woojinn8/LightWeight-Deep-Learning-3.-Deep-Compression-%EB%A6%AC%EB%B7%B0

     

    [CNN Networks] 9. Deep Compression 리뷰

    Pruning과 Quantization을 활용해 모델 압축을 하는 Deep compression에 대해 정리한 내용입니다.

    velog.io

     

    https://wandb.ai/katia/GLUE-aws-sweep/reports/Pruning-BERT-on-a-GLUE-task--Vmlldzo4NTYxNjI

     

    Pruning BERT on a GLUE task

    A quick presentation on fine-tuning and pruning BERT model from Hugging-Face on the CoLA task. Made by KaKTaK using Weights & Biases

    wandb.ai

     

    https://blogik.netlify.app/BoostCamp/U_stage/45_pruning/

     

    모델 경량화 1 - Pruning(가지치기)

    가지치기 by 홍원의 마스터님, BoostCamp AI Tech 8주차

    blogik.netlify.app

     

    https://tutorials.pytorch.kr/intermediate/pruning_tutorial.html

     

    가지치기 기법(Pruning) 튜토리얼

    저자: Michela Paganini 번역: 안상준 최첨단 딥러닝 모델들은 굉장히 많은 수의 파라미터값들로 구성되기 때문에, 쉽게 배포되기 어렵습니다. 이와 반대로, 생물학적 신경망들은 효율적으로 희소하

    tutorials.pytorch.kr

    https://arxiv.org/pdf/2003.03033.pdf

     

     

     

     

    # 본 논문 

    https://arxiv.org/pdf/2112.07198.pdf

     

     

    # Contrastive Learning 

    https://velog.io/@sjinu/Similarity-Learning-Contrastive-Learning

     

    Similarity Learning & Contrastive Learning

    Similarity Leanring & Contrastive Learning(1)

    velog.io

     

    # 기타

     

    https://yukyunglee.github.io/Transformer-to-T5-1/

     

    Transformer to T5 - 1

    학부 졸업 프로젝트와 논문이 Object detection과 Linear Proramming을 주제로 했기 때문에, 연구실 지원당시에도 NLP를 공부할것이라 미처 생각하지 못했다. 그래서 입학이 확정난 이후부터 스터디를 하

    yukyunglee.github.io

    http://dsba.korea.ac.kr/seminar/?mod=document&uid=1789 

     

     

    728x90
Designed by Tistory.