ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [torch] torch.einsum 함수 이해하기
    Data miner/Development log 2022. 11. 15. 21:39
    728x90

    torch.einsum함수는 Transformer 논문에서 query의 값과 query에 position embedding값을 연계시키는 부분에서 사용되었다. 이에 대해서 한 번 정리하고 가는 것이 좋을 것 같아 포스팅을 하게 되었다. 

    relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)

     

    공식문서에 따르면, einsum() 함수는 input operands의 곱을 합하는 방식이며, Einstein 합 방식에 기반하여 output matrix 차원을 만드는 방식이다. 아래의 그림과 같이 operands에 속하는 텐서들의 차원이 순서대로 첫번째 요소에 대응하며, '->' 이후는 연산 이후의 output차원을 나타낸다. 여러 텐서에서 동일한 차원에서 수행되는 경우, 동일한 기호를 사용해야 한다. 

     

    이러한 표기는 배치 행렬 곱셈을 적용 후 축을 바꾸는 수고스러움(permute 함수 사용)을 덜게 한다.

     

    공식 문서 및 다른 블로그를 서치해보아도, einsum 작동원리의 응용된 연산에 대한 설명을 찾아볼 수 없어, 직접 간단한 예시를 통해 연산하는 방식으로 작동 원리를 이해하고자 하였다. 

     

    relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 부분에서, "bhld,lrd->bhlr" 실제 연산 과정은 여러번의 시행착오를 통해 다음과 같다는 것을 확인할 수 있었다. 

     

    먼저, positional_embedding의 lrd 차원은 permute/transpose를 통해 lrd -> ldr 로 바꾼 후 bhld, ldr 차원이 동일한 부분에 해당하는 ld부분 원소끼리 일반적인 matmul 방식으로 곱해지지 않고 앞 query_layer에 ld해당하는 부분에 transpose된 positional_embedding이 각각 대응되어 곱하는 방식으로 연산이 되었다.

     

    또한, 이는 상대적인 위치를 query layer에 더하기 위해서, 각 상대적인 위치 임베딩 벡터값을 곱하여 query layer에 반영하고자 하기 위함이다. 

     

    하나의 예시가 아니라, 전반적인 작동원리에 대한 글을 찾거나 알게 되면 추후 수정하여 포스팅할 예정이다.

     

     

    단순한 예시에 대한 설명은 다음의 링크를 따라가면 되겠다.

    https://baekyeongmin.github.io/dev/einsum/

     

Designed by Tistory.