
squeeze()

import torch

tensor = torch.tensor([[[5,3,0], [2,0,1]]])
print(tensor.shape)

print(tensor.squeeze().shape)
output:
torch.Size([1, 2, 3])
torch.Size([2, 3])
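  • squeeze() removes every axis whose size is 1. Here the leading axis of size 1 is dropped, so the shape goes from [1, 2, 3] to [2, 3].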
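squeeze() also accepts a dim argument. The sketch below uses a made-up zero tensor of shape [1, 2, 1, 3] (not from the example above) to show that, with dim, only that one axis is removed, and nothing changes if that axis does not have size 1:

```python
import torch

# Illustrative tensor with two size-1 axes: dim 0 and dim 2
t = torch.zeros(1, 2, 1, 3)

print(t.squeeze().shape)   # torch.Size([2, 3])        all size-1 axes removed
print(t.squeeze(0).shape)  # torch.Size([2, 1, 3])     only dim 0 removed
print(t.squeeze(2).shape)  # torch.Size([1, 2, 3])     only dim 2 removed
print(t.squeeze(1).shape)  # torch.Size([1, 2, 1, 3])  dim 1 has size 2, so no change
```

Passing dim explicitly is often the safer choice, since a bare squeeze() would also drop, for example, a batch axis that happens to have size 1.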

unsqueeze()

unsqueeze() inserts a new axis of size 1 at the position given by its dim argument.

# For an n-D tensor, dim can range from 0 to n (n + 1 possible positions)
# 1-D tensor -> dim can be 0 or 1 (2 positions)
# 3-D tensor -> dim can be 0 to 3 (4 positions)

tensor2 = torch.tensor([5, 2, 3])
print(tensor2)
print(tensor2.shape)

print("")

tensor2 = tensor2.unsqueeze(1)
print(tensor2)
print(tensor2.shape)
output:
tensor([5, 2, 3])
torch.Size([3])

tensor([[5],
        [2],
        [3]])
torch.Size([3, 1])
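The valid range of dim can be checked directly. This small sketch reuses the values of tensor2 from above and loops over every dim a 1-D tensor accepts, including negative ones; PyTorch maps a negative dim to dim + tensor2.dim() + 1. It also tries an out-of-range value, which PyTorch rejects with an IndexError:

```python
import torch

tensor2 = torch.tensor([5, 2, 3])  # 1-D tensor, shape [3]

# Accepted range is [-tensor2.dim() - 1, tensor2.dim()], i.e. -2..1 here
for d in range(-tensor2.dim() - 1, tensor2.dim() + 1):
    print(d, tensor2.unsqueeze(d).shape)
# -2 torch.Size([1, 3])
# -1 torch.Size([3, 1])
# 0 torch.Size([1, 3])
# 1 torch.Size([3, 1])

try:
    tensor2.unsqueeze(2)  # outside the valid range for a 1-D tensor
except IndexError as err:
    print("error:", err)
```

Applying squeeze with the same dim afterwards undoes the unsqueeze and restores the original shape [3].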
