flops연산을 위해 구현된 git들을 보면 대부분이 이를 통해서 구현되어 있습니다. register_buffer()메소드는 무엇을 위해 쓰는 것인지 살펴보겠습니다.
buffer
computer science에서 buffer는 어떠한 임시 저장 공간을 의미합니다. A와 B가 서로 입출력을 수행할 때 속도 차이를 극복하기 위해 잠시 사용하는 임시공간 정도로 생각하면 되겠습니다.
동영상을 볼 때 동영상의 한 프레임프레임을 받는 속도와 우리가 영상을 보는 속도는 차이가 있기에, 버퍼라는 임시 공간 안에 동영상 프레임 데이터들을 최대한 빨리 받아 저장해두는 예시를 들 수 있겠네요.
pytorch의 nn.Module의 메소드인 register_buffer()
이 메소드를 통해서 nn.Module 인스턴스에 buffer를 추가할 수 있습니다. 파라미터가 아니기에 모델의 학습(최적화)에는 사용되지 않습니다. 그러나 state_dict에는 저장됩니다.
예를 들면
conv = nn.Conv2d(3, 128, 3)
conv.register_buffer('total_ops', torch.zeros(1))
Python
복사
다음과 같은 코드에서 conv라는 모듈에 ‘total_ops’라는 이름의 버퍼를 추가해주고, 1크기의 0값을 갖는 텐서를 그에 해당하는 value로 넣어줍니다.
conv.total_ops = torch.Tensor(1) 과 같이 변경해주는 것도 가능합니다.
이렇게 추가된 buffer는 gpu에서 동작 가능하기 때문에 이런 방식 외에도 네트워크에 업데이트 하지 않는 레이어를 추가해주고 싶을 때 사용도 가능합니다.
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.register_buffer('my_buffer', torch.randn(1))
self.my_param = nn.Parameter(torch.randn(1))
def forward(self, x):
return x
Python
복사