einsum
-
[torch] torch.einsum 함수 이해하기Data miner/Development log 2022. 11. 15. 21:39
torch.einsum함수는 Transformer 논문에서 query의 값과 query에 position embedding값을 연계시키는 부분에서 사용되었다. 이에 대해서 한 번 정리하고 가는 것이 좋을 것 같아 포스팅을 하게 되었다. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 공식문서에 따르면, einsum() 함수는 input operands의 곱을 합하는 방식이며, Einstein 합 방식에 기반하여 output matrix 차원을 만드는 방식이다. 아래의 그림과 같이 operands에 속하는 텐서들의 차원이 순서대로 첫번째 요소에 대응하며, '->' 이후는 연산 이후의 out..