본문 바로가기
ML 관련/자연어 처리 관련

[논문 리뷰] Attention is all you need

by 탶선 2019. 12. 5.
반응형

 

 

논문 제목 : Attention is all you need

 

구글 브레인, 구글 리서치에서 쓴 논문으로 이 블로그의 첫 글을 쓰고자 한다.

 

논문 제목에 나오는대로 attention - 특정 정보(단어)에 좀 더 주의를 기울이면 된다는 논문이다.

 

예를들어

model이 수행해야 하는 task가 번역일 경우 source는 영어이고 target은 한국어일 때

“Hi, my name is tapseon.” 문장과 대응되는 “안녕, 내 이름은 탶선이야.”라는 문장이 있다.

model이 '이름은' 이라는 token을 decode할 때, source에서 가장 중요한 것은 name이다

 

즉 source의 모든 token이 비슷한 중요도를 갖기 보다는 name이 더 큰 중요도를 갖게 만드는 방법이 attention이다.

 

1. 배경

기존의 language model들은 series data가 순차적으로 신경망의 입력으로 들어가는 형식인 RNN을 기반으로 하여 문장이 길어질 경우 거리가 먼 단어들(예 : 1,2,3,4,5,6에서 1과 6을 뜻함)의 경우 서로에 대한 상관관계가 약하게 처리되며 RNN은 순차적으로 진행이 되는 특성상 병렬 처리 불가능하다는 단점을 안고 있다.

 

Attention 기법은 문장 전체를 한번에 입력하여 거리가 먼 단어라 할지라도 서로에 대한 상관관계가 약하지 않으며 RNN으로는 불가능한 병렬 처리가 가능하게 함으로써 약 30배가량의 학습 속도를 개선했다

 

 

2. Introduction

RNN기반의 인코더와 디코더를 그림으로 나타낸다

RNN기반의 인코더와 디코더

 

 

 

 

위의 그림들 처럼 RNN기반으로 기계번역시 가중치를 두는 

 

transformer model architecture

attention의 모델 구조는 위 그림과 같다

attention모델은 크게 왼쪽의 encoder와 오른쪽의 decoder로 나뉜다

 

encoder

Encoder

encoder는 N개의 identical layer로 구성되어 있으며 input AxA 가 첫 번째 layer에 들어가게 되고, layer(x)가 다시 layer에 들어가는 식이다.

그리고 각각의 layer는 2개의 sub layer를 동반하며 multi-head self-attention mechanism position-wise fully connected feed-forward network로 구성되어 있다.

이때 두 개의 sub-layer에 residual connection(input을 output으로 그대로 전달하는 것)을 이용한다.

 

이때 x+Sublayer(x)를 하기 위해 sub-layer의 output dimension을 embedding dimension과 맞춘다.

 

즉 residual connection을 하기 위해서는 두 값의 차원을 맞춰줄 필요가 있다.  그 후 layer normalization을 적용한다

 

 

Decoder 

decoder 또한 N개의 동일한 layer로 이루어져 있으며 encoder와 달리 encoder의 경과에 multi-head attention 을 수행할 sub-layer를 추가 , sub-layer에 residual connection을 사용 후 layer normalization 수행한다.

 

decoder는 encoder와 다르게 순차적으로 결과를 만들어야 하기 떄문에 self-attention을 변형한다 -> masking 사용

masking을 통해서 position 보다 이후에 있는 position에 attention을 주지 못하도록 한다.

 

position 에 대한 예측은 이미 알고 있는 output들에 의존한다.

 

masking에 대한 예

(예시)

1행의 a 예측시

  a 뒤에 있는 b,c는 masking되어 attention되지 않는다.

2행의 b 예측시

  b 전에 있는 a를 attention가능, c는 masking되어 attention 불가능하다.

 

 

 

Attention

Scaled Dot-Product Attention

본 논문에서는 attention을 Scaled Dot-Product Attention이라 부른다

input은 $d_{k}$ dimension의 query와 key들, $d_{v}$  dimension의 value들로 이루어지며

이 때 모든 query와 key에 대한 dot-product를 계산하고 각각을 $\sqrt{d_{k}}$ 로 나누어준다.

 

dot-product를 하고 $\sqrt{d_{k}}$ 로 scaling을 해주기 때문에 Scaled Dot-Product Attention 이라 부른다.

그리고 여기에 softmax를 적용해 value들에 대한 weights를 얻어낼 수 있다.

 

-query와 key, value에 대한 설명-

 

query가 어떤 단어와 관련되어 있는지 찾기 위해서 모든 key들과 연산하며 실제 연산을 보면 query와 key를

dot-product하여 softmax를 취한다.

dot-product의 값이 커질수록 softmax함수에서 기울기의 변화가 거의 없는 부분으로 가기 때문에 , $\sqrt{d_{k}}$로 scaling을 해준다.

여기서 의미하는 것은 하나의 query가 모든 key들과 연관성을 계산, 그 값들을 확률 값으로 만들어 주는 것이다.

softmax를 거친 값을 value에 곱해주어 querty와 유사한 value(중요한 value) 일수록 더 높은 값을 갖게 된다.

따라서 query가 어떤 key와 높은 확률로 연관성을 가지는지 알게 되는 것이다.

이제 구한 확률값을 value에 곱해서 value에 대해 scaling한다고 생각하면된다.

 

Q = Query : t-1 시점의 decoder의 hidden states # Query (Q): 영향을 받는 단어 A를 나타내는 변수

K = Keys : 모든 시점의 encoder의 hidden states # Key (K) : 영향을 주는 단어 B를 나타내는 변수

V = Values : 모든 시점의 encoder의 hidden states # Value (V) : 그 영향에 대한 가중치




 

 

Multi-Head Attention

기존의 attention - 전체 dimension에 대해 하나의 attention을 적용

Multi-head attention - 전체 dimension을 h로 나눠 attention을 h번 적용

 

각 query,key,vector를 linearly하게 h로 project 한 후 각각을 attention 한다.

이 후 생성된 h개의 vector를 concat하여 생성된 vector의 dimension을 $d_{model}$로 다시 맞춰주도록 matrix를 곱하는 과정

$$MultiHead(Q,K,V) = Concat(head_{1},...,head_{h})W^O$$

$$where head_{i} = Attention(QW_{i}^Q,KW_{i}^K,VW_{i}^V)$$

 

각 파라미터의 shape

$$W_{i}^Q,W_{i}^K \in \mathbb{R}^{d_{model}\times d_{k}} ,W_{i}^V\in\mathbb{R}^{d_{model}\times d_{k}},W^O \in\mathbb{R}^{hd_{v}\times d_{model}} $$

 

Regularization

3가지 방법

1. Residual dropout

2. Attention dropout

3. label smoothing

 

Conclusion


Transformer 
인코더 디코더 구조에서 가장 일반적으로 사용되는 반복 레이어를 multihead self attention 으로 대체하여 기반한 모델 제시
• WMT 2014 영어 독일어 및 WMT 2014 영어 프랑스어 번역 작업
모두에서 SOTA 달성
• 이전의 모든 앙상블 모델 보다 우수한 성능
• 텍스트 이외의 입력 및 출력 방식과 관련된 문제로 확장 계획

 

 

 

 

 

본 게시물은 포자랩스의 게시물을 참고하여 작성했습니다.

https://pozalabs.github.io/transformer/

반응형

댓글