ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [코드리뷰] Transformer의 positional_embedding
    Data miner/Information Retrieval 2022. 11. 2. 21:59
    728x90

    논문을 구현한 코드에서 단어 토큰들의 위치 정보를 임베딩 하기 위한 여러 방식들이 있는데, "absolute", "relative_key", "relative_key_query"가 그에 해당한다. 

     

     - 'absolute'의 경우

        - position_embeddings이란 이름으로 토큰 시퀀스의 위치 인덱스를 나타내는 정수형 타입의 텐서값을 받아 임베딩한다. (max embedding 크기 X hidden size)

        - word_embeddings, token_type_embeddings, position_embeddings의 값은 추후에 합산되어 학습되므로 이들의 임베딩 hidden size는 동일하다. Embedding X by Y 에서 Y값이 동일하다는 의미. 

        - 한편 토큰 시퀀스의 위치 인덱스의 경우, 모델의 파라미터가 아니므로 self.register_buffer() 를 통해 매개변수가 아닌 상태로 사용한다. 

        - BertEmbeddings 안에서 사용됨

     

      
      class BertEmbeddings(nn.Module):
      		...
            ...
            
       	self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
            if version.parse(torch.__version__) > version.parse("1.6.0"):
                self.register_buffer(
                    "token_type_ids",
                    torch.zeros(self.position_ids.size(), dtype=torch.long),
                    persistent=False,
                )
                
       def forward:
       	...
        ...
            embeddings = inputs_embeds + token_type_embeddings
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings += position_embeddings
            embeddings = self.LayerNorm(embeddings)
            embeddings = self.dropout(embeddings)
            return embeddings

     

    - "relative_key" 혹은 "relative_key_query"의 경우

       - (2*max embedding 크기-1) by (self.attention_head_size) 의 distance embedding 사용

     - 토큰들 사이의 거리를 측정하며, 각 토큰 간의 거리를 측정하기 위해서 시퀀스 내의 최대 토큰 수(max_position_embeddings)를 사용함

       - 만약 seq_length = 10인 경우라면, 원래 아래와 같이 0~9라는 위치를 가지고 있는 토큰 정보들이 (하단 그림 참조)

         전체 문장 내에서 어떤 상대적 위치를 가지고 있는지 나타내기 위해서 아래와 같이 표현됨 (하단 그림 참조)

       상대적인 위치이기 때문에, i번째 행에서 0은 토큰 자기 자신의 위치를 나타냄. 각각의 상대적 위치 값 중에서 마이너스 값들은 시퀀스 내의 최대 토큰 수(self.max_position_embeddings - 1)의 값이 합산되어 보정되어서 표현됨. 그래야 self.distance_embedding 에서 lookup해서 쓸 수 있기 때문.  이 임베딩을 positional_embedding이라고 함

      - 이후, torch.einsum() 함수로 query_layer와 positional embedding을 연산해서 relative_position_scores을 구함

      - relative_position_scores (혹은 relative_position_scores_query + relative_position_scores_key)는 attention_scores과 더해짐

       - BertSelfAttention 안에서 사용됨

    class BertSelfAttention(nn.Module):
    
    	...
        ...
        
        def forward(
        	...
            ...
        
            if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
                    seq_length = hidden_states.size()[1]
                    position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
                    position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
                    distance = position_ids_l - position_ids_r
                    positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
                    positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
    
                    if self.position_embedding_type == "relative_key":
                        relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                        attention_scores = attention_scores + relative_position_scores
                    elif self.position_embedding_type == "relative_key_query":
                        relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                        relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                        attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

     

     

     

     

Designed by Tistory.