Machine Learning
[Pytorch] model.train(), model.eval() 의미
MLra
2021. 10. 27. 17:35
https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch/60018731#60018731
What does model.eval() do in pytorch?
I am using this code, and saw model.eval() in some cases. I understand it is supposed to allow me to "evaluate my model", but I don't understand when I should and shouldn't use it, or how...
stackoverflow.com
nn.Module에는 train time과 evaluate time에 수행하는 다른 작업을 switching해줄 수 있도록하는 함수를 제공한다.
train time과 evaluate time에 서로 다르게 동작해야 하는 것들에는 대표적으로 아래와 같은 것들이 있다.
- Dropout layer
- BatchNorm layer
model.eval()을 수행하면 evaluation과정에서 사용하지 않을 layer들의 전원을 끈다.
# eval mode
model.eval()
with torch.no_grad():
...
out = model(val_data)
...
evaluation이 끝나면 다시 train mode로 변경을 해줘야한다.
# train mode
model.train()