less than 1 minute read

torch.gather() 설명

torch.gather(input, dim, index, *, sparse_grad=False, out=None)  Tensor

특정 인덱스를 뽑고자 할 때 사용되며, dim으로 지정된 축을 따라 값을 수집한다. 그리고 index의 차원의 수와 matrix의 차원의 수를 맞춰줘야 한다.

  • dim 부분을 제외하고 나머지 차원은 동일해야 한다.
  • 인덱스 텐서와 출력 텐서의 크기는 동일하다.


예제1

import torch

origin = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
result = torch.gather(origin, 1, index)
print("origin: \n", origin)
print("index: \n", index)
print("result: \n", result)
output:
origin: 
 tensor([[1, 2],
        [3, 4]])
index: 
 tensor([[0, 0],
        [1, 0]])
result: 
 tensor([[1, 1],
        [4, 3]])
  • dim의 매개변수에 1을 전달했으므로 열을 따라 값을 추출하게 되므로 행이 나오게 된다. (행에서 지정된 열의 값을 추출하게 된다.)
  • 첫 번째 행에서 인덱스 ‘[0, 0]’을 사용하므로, 원본 텐서의 첫 번째 행에서 0번째 위치의 값을 두 번 호출한다. 첫 번째 행의 0번째 위치의 값은 ‘1’이므로 결과는 [1, 1]이 된다.
  • 두 번째 행에서 인덱스 [1, 0]을 호출하면, 두 번째 행에서 1번째 위치의 값 4, 0번째 위치의 값 3을 가리킨다. 즉, [4, 3]이 된다.


예제2

https://data-newbie.tistory.com/709 블로그 내용 참고하여 내용 추가하기


출처

https://data-newbie.tistory.com/709
https://pytorch.org/docs/stable/generated/torch.gather.html

Categories:

Updated: