-
Seq2Seq & Attention & Transformer 정리AI/NLP 2021. 12. 31. 14:37728x90
[논문 목록]
1. Seq2Seq : Sequence to Sequence Learning with Neural Networks (https://arxiv.org/abs/1409.3215)
2. Attention : Neural Machine Translation by Jointly Learning to Align and Translate (https://arxiv.org/abs/1409.0473)
3. Transformer : Attention is All You Need (https://arxiv.org/abs/1706.03762)
Seq2Seq
RNN 계열의 모델로, Encoder과 Decoder로 구성되어 있음
1) Encoder: 입력된 단어의 연속열(= Source sentence)을 고정된 크기의 벡터 하나로 압축하여, 입력된 문장에 대한 문맥 정보를 담음 → 하나의 Context Vector 생성!
2) Decoder: Encoder의 출력인 Context Vector를 참고하여, 출력하고자 하는 Target Sentence를 만들어냄
Seq2Seq을 훈련하는 과정에서 Teacher Forcing 이라는 개념을 이해해야 합니다. Teacher Forcing이란 seq2seq 모델을 훈련하는 과정에선 Decoder에 input으로 들어가는 값들을 Ground truth값으로 고정하는 것을 의미합니다. 이렇게 하는 이유는 만약 Decoder의 초반 부분에서 학습을 잘못하여 잘못된 값을 예측하게 된다면 그 다음으로 계속 영향을 미치는 것을 방지하기 위해서입니다.
Testing할 때는 Decoder에 input으로 들어가는 값을 이전 Step의 output을 입력합니다.
Attention
기존의 Seq2Seq, LSTM, GRU과 같은 RNN 계열에서의 고질적인 한계들이 존재!
1. 사용되던 고정된 크기의 Context Vector에는 한계가 있음 (FIxed-length Context Vector )
- 하나의 Context Vector가 Source Sentence의 모든 정보 가지고 있어야 하므로 성능 저하
- 입력된 Source Sentence의 길이가 가변적인데, Context Vector는 고정된 크기를 가지는 것이 전체적인 성능에 Bottleneck Problem 유발
2. Long Term Dependency Problem (Vanishing Gradient)
- 층이 깊어질수록 앞쪽의 정보가 제대로 전달되지 못하는 long term dependency 문제
3. Parallelization Problem
- 입력이 차례대로 들어가 계산을 병렬적으로 처리하지 못한다
=> 해결책 : "Seq2Seq 모델에 Attention weight 반영하자! "
"Attention Mechanism" : 매번 Encoder의 source sentence으로부터의 출력을 전부 Decoder에서 입력으로 받자!
- Decoder에서 매번 Encoder의 모든 output을 참고하여 그 중 어떤 정보가 중요한지 계산(Energy)하여 확률값(Weight)으로 반영!
이제부터 Attention Weight를 어떻게 얻는지 말씀드리겠습니다.
첫번쨰로 context vector와 encoder의 각 input vector들을 임의의 함수 f를 통해 중요도 a를 뽑아냅니다. 이때 이 함수 f는 작은 neural network를 사용합니다. 두번째로 이 중요도의 합들이 1이 되도록 softmax를 거쳐 weight를 만듭니다. 세번째로 이 weight와 input vector들을 곱한 후 모두 합하여 중요도가 반영된 vector 하나를 생성해냅니다. 이때 이 벡터가 attention score가 됩니다.
다음으로, 앞에서 얻은 Attention Weight을 Decoder에서 어떻게 반영하는지 말씀드리도록 하겠습니다. Decoder의 이전 step의 hidden state를 context vector로 하여 Encoder의 모든 input vector들과의 계산을 통해 Attention Weight을 뽑아낸 후, 그 attention score 값을 Decoder의 현재 step의 hidden state값과 concatenate하여 예측 값을 도출합니다.
Transformer
RNN, CNN 대신 Multi Head Attention + Positional Encoding 사용!
* Positional Encoding이란?
- RNN 계열 아키텍쳐를 사용한다면, RNN을 사용한다는 것만으로도 각각의 단어가 순서대로 들어가기 때문에, 자동으로 각각의 h_s 값은 순서 정보도 가지게 된다
- Transformer에서는 순서 정보를 반영하기 위해 Input Embedding Matrix에 Positional Encoding을 더해준다
- 이때 Positional Encoding은 순서 정보를 주기 함수(sin과 cos)를 통해 표현한다
* Multi Head Attention 이란?
: Scaled Dot-Product Attention( = Self Attention) 을 여러 개로 나누어(=head를 여러 개를 두어), 병렬 처리하여서 더 빠르게 계산하기 위해서!
* Self Attention (=Scaled Dot-Product Attention)
: Query, Key, Value 값을 통해 입력 문장의 각각의 단어가 입력 문장의 다른 단어들과 얼마나 밀접한 단어가 있는지에 대한 정보 추출
→ ex) 입력 문장 I am a teacher가 있을 때, I에 대하여 I am a teacher의 각각의 단어들이 어느 정도의 관련성이 있나?
: Query: 무언가 물어보는 주체 [I]
: Key: 물어보는 대상 [I/am/a/teacher]
: Value : 실제 value 값 [어느 정도의 관련성]
* Transformer 구조
1. Encoder
( Multi-Head Attention(Encoder Self Attention) + Add & Norm + Feed Forward Layer + Add & Norm ) X N개 층
2. Decoder
→ Encoder의 마지막 레이어의 출력( key vector, value vector )을 모든 디코더 레이어의 Encoder Decoder Attention에 입력
( Multi-Head Attention (Masked Decoder Attention) + Add & Norm + Multi-Head Attention (Encoder Decoder Attention) + Add & Norm + Feed Forward Layer + Add & Norm ) X N개 층
* Transformer에서의 Attention Module 3가지
1. Encoder Self Attention (Encoder에서 사용) : 각각의 출력 단어가 모든 출력 단어 전부를 참고
2. Masked Decoder Attention (Decoder에서 사용) : 각각의 출력 단어가 앞쪽에 등장한 출력 단어 참고
3. Encoder Decoder Attention (Decoder에서 사용) : Query가 Decoder에 있고 각각의 Key와 Value는 Encoder에 있는 상황
Reference
https://yngie-c.github.io/nlp/2020/06/30/nlp_seq2seq/
https://glee1228.tistory.com/3
https://yngie-c.github.io/nlp/2020/07/01/nlp_transformer/
728x90'AI > NLP' 카테고리의 다른 글
NLP Benchmark Datasets 정리 (GLUE / SQuAD/RACE) (0) 2022.02.04 RoBERTa: A Robustly Optimized BERT Pretraining Approach 정리 (0) 2022.02.04 Word Embedding 3 : Deep Contextualized Word Representations (ELMo) 정리 (0) 2022.01.17 Word Embedding 01 (One-hot Encoding / Word2Vec ) 정리 (0) 2021.12.31 Word Embedding 02 ( Glove / FastText ) 정리 (0) 2021.12.31