hook이라는 것은 보통 프로그래밍에서 패키지화 된 코드 내에 customized code를 삽입할 수 있게해주는 인터페이스/공간 이라고 하는데, 간단히 말해 프로그램의 실행 로직을 분석하고 싶거나 추가적인 기능을 제공하고 싶을 때 사용할 수 있다.
내가 본 코드들(pytorch)에서는 보통 추가적인 기능 제공에 사용하는 것을 보았다.
hook
hook의 간단한 예시를 들어보자면
def add(a, b):
return a+b
class Package():
def __init__(self):
self.name = 'package'
self.hooks = []
def __call__(self, a, b):
x = add(a, b)
print(x)
for hook in self.hooks:
output = hook(x)
if output:
x = output
return x
Python
복사
이렇게 해두면
def square(x):
return x * x
package = Package()
package.hooks.append(square)
output = package(1, 2)
Python
복사
이렇게 square라는 hook 함수를 추가해주면! output은 (1+2)^2 = 9가 됩니다.
이제 pytorch에서 nn.Modules에 hook을 추가해주는 것을 살펴볼까요
register_foward_hook register_forward_pre_hook register_backward_hook register_full_backward_hook register_load_state_dict_post_hook ….
꽤 많은 hook을 추가하는 메서드들이 존재합니다.
이 중 forward에 관련된 두가지만 살펴보겠다. 이 두가지를 살펴보면 다른 hook들은 어떻게 수행되는지 감을 잡을 수 있을 것이다.
register_forward_hook() , register_forward_pre_hook
class Model(nn.Module):
def __init__(self):
def forward(self, a, b):
output = a + b
return output
# register_forward_hook에 등록되는 함수는 module과 입력값, 출력값을 인자로 받는다
def forward_hook(module, input, output):
return input+output
a, b = torch.Tensor(1), torch.Tensor(2)
model = Model()
model.register_forward_hook(forward_hook)
model(a, b) # 6
# register_forward_pre_hook에 등록되는 함수는 moduler과 입력값만을 인자로 받는다.
# 함수 이름처럼 forward를 수행하기 전에 수행되기 때문이다.
def pre_forward_hook(module, input):
return input+3
model2 = Model()
model2.register_forward_pre_hook(pre_forward_hook)
model2(a, b) # 6
Python
복사
register_hook 은 Tensor에 적용되며, backward hook만 가능하다. Tensor의 gradient가 계산될 때마다 hook이 호출된다. gradient를 바꿀 수는 없지만 새로운 gradient 생성이 가능하며 기존 grad를 대체할 수 있다.
register_full_backward_hook 은 module에 적용되며 input에 대한 gradient가 계산될 때마다 hook이 호출된다. hook(module, grad_input, grad_output)과 같은 형태로 grad_input과 grad_output은 각각 input과 output의 gradient를 포함한 튜플이다. 새로운 gradient를 return해 기존 gradient를 대체할 수 있다. (직접적인 수정은 불가능하다)