[Pytorch] 언제 CNN의 필터가 초기화 되나?
참고한 스택오버플로우 링크.
CNN에서는 커널을 통해 만들어진 피처맵들이 모두 학습가능한 weight들이다. 그러나, 코드상에서 이것을 딱히 초기화 해주는 코드는 없다.
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # in_channel=3, out_channel=64
해당 코드는 CNN의 첫 conv layer인데, 입력으로 RGB input channel에 3이 들어가고, output channel에는 64가 매개변수로 들어감으로 64개의 필터(피처맵)가 만들어진다. 근데, 64개의 필터가 역전파를 통해 손실 값을 최소화하는 방향으로 가중치를 조정하는 것은 알겠는데 64개의 필터가 각각 어떤 값으로 처음에 초기화되는지에 대한 의문이 있었다.
이는 공식문서의 코드를 보면 확인할 수 있는데, Conv2d 클래스는 _ConvNd 클래스를 상속받으며 ConvNd의 __init__
함수의 마지막에서 reset_paramenter()
를 호출한다.
reset_parameters()
의 코드는 다음과 같다
def reset_parameters(self) -> None:
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
# For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
파라미터를 초기화 할 때 카이밍 초기화를 사용하는 것을 확인할 수 있다.
kaiming 초기화에서 sqrt(5)로 초기화 하는 이유는 위 주석에 달려있는 github의 이슈에 들어가서 확인해보면 알 수 있다. 초기화 코드의 리팩토링 때문이었다고 하는데, kaiming 초기화 방법을 일반적인 활성화 함수에 적용할 수 있도록 만든 파라미터이다.
출처
https://stackoverflow.com/questions/53990652/what-is-the-first-initialized-weight-in-pytorch-convolutional-layer
https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#_ConvNd