Machine Learning

[Pytorch] model.train(), model.eval() 의미

https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch/60018731#60018731

 

What does model.eval() do in pytorch?

I am using this code, and saw model.eval() in some cases. I understand it is supposed to allow me to "evaluate my model", but I don't understand when I should and shouldn't use it, or how...

stackoverflow.com

nn.Module에는 train time과 evaluate time에 수행하는 다른 작업을 switching해줄 수 있도록하는 함수를 제공한다.

train time과 evaluate time에 서로 다르게 동작해야 하는 것들에는 대표적으로 아래와 같은 것들이 있다.

  • Dropout layer
  • BatchNorm layer

model.eval()을 수행하면 evaluation과정에서 사용하지 않을 layer들의 전원을 끈다.

# eval mode

model.eval()

with torch.no_grad():
	...
    out = model(val_data)
    ...

evaluation이 끝나면 다시 train mode로 변경을 해줘야한다.

# train mode
model.train()