Saving a model’s learned parameters means saving its state_dict with torch.save().
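To make "state_dict" concrete, here is a minimal sketch, assuming PyTorch is installed. The small nn.Sequential model is illustrative and not from these notes. It lists what the state_dict holds: an ordered mapping from each parameter (or buffer) name to its tensor. Layers without parameters, such as ReLU, contribute no entries.

```python
import torch.nn as nn

# Illustrative model (hypothetical, not from the notes).
model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# state_dict() maps each parameter/buffer name to its tensor;
# names are "<child index>.<parameter name>" for nn.Sequential.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (2, 3)
# 0.bias (2,)
# 2.weight (1, 2)
# 2.bias (1,)
```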
Save/Load state_dict (Recommended)
Save (torch.save serializes the object with Python's pickle module):
torch.save(model.state_dict(), PATH)
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
Call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference; otherwise inference results will be inconsistent.
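The save/load steps above can be sketched end to end as follows. This assumes PyTorch is installed. TheModelClass here is a hypothetical stand-in that includes dropout and batch norm, the layers model.eval() affects. An in-memory io.BytesIO buffer stands in for PATH, since torch.save and torch.load also accept file-like objects. After loading and calling eval() on both models, they produce identical outputs.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for TheModelClass: includes dropout and batch norm,
# the two layer types whose behavior model.eval() changes.
class TheModelClass(nn.Module):
    def __init__(self, in_features: int = 4, hidden: int = 8):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.bn = nn.BatchNorm1d(hidden)
        self.drop = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        return self.fc2(self.drop(torch.relu(self.bn(self.fc1(x)))))


torch.manual_seed(0)
model = TheModelClass()
# A training-mode forward pass updates the batch-norm running statistics,
# so the state_dict carries these buffers as well as the learned weights.
with torch.no_grad():
    model(torch.randn(16, 4))

# Save: only the state_dict (parameters and buffers) is serialized.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load: rebuild the architecture in code, then copy the saved tensors in.
buffer.seek(0)
restored = TheModelClass()
# weights_only=True restricts unpickling to tensors and plain containers.
restored.load_state_dict(torch.load(buffer, weights_only=True))

# Evaluation mode: dropout is disabled, batch norm uses running statistics.
model.eval()
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```

Only the tensors are stored, so loading requires the class definition to be available in code.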
Drawback of saving/loading the entire model
## Just an example, DON'T USE IT.
## save
# torch.save(model, PATH)
## load
# model = torch.load(PATH)
# model.eval()
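Saving the entire model pickles the whole module object, but pickle does not save the model class itself. Instead, it saves a path to the file containing the class, which is used at load time.
The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved, so the code can break when it is used in other projects or after refactoring.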
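A minimal sketch of why this happens uses the standard-library pickle module directly rather than torch.save, which builds on pickle. The module name my_models and the class Net are hypothetical and are created on the fly to stand in for the file that defines a model class. The pickle stores the class only as a reference by module and name. If the class is removed or renamed, or the module can no longer be imported under the same name, loading fails.

```python
import pickle
import sys
import types

# Stand-in for a file "my_models.py" that defines the model class.
# The module name and the Net class are hypothetical.
my_models = types.ModuleType("my_models")
sys.modules["my_models"] = my_models
exec(
    "class Net:\n"
    "    def __init__(self):\n"
    "        self.weights = [0.1, 0.2]\n",
    my_models.__dict__,
)

# The pickle holds the instance's data plus a reference to the class
# by module and name ("my_models", "Net"); the class code is not stored.
blob = pickle.dumps(my_models.Net())
print(b"my_models" in blob, b"Net" in blob)  # True True


def try_load(data: bytes):
    """Unpickle data, returning None if the class can no longer be found."""
    try:
        return pickle.loads(data)
    except (AttributeError, ImportError, pickle.UnpicklingError) as exc:
        print("load failed:", type(exc).__name__)
        return None


# 1) Same code layout as at save time: loading works.
loaded = try_load(blob)
print(loaded.weights)  # [0.1, 0.2]

# 2) Class renamed/removed from its module: the name lookup fails.
del my_models.Net
missing_class = try_load(blob)

# 3) Module moved/renamed so "my_models" is no longer importable.
del sys.modules["my_models"]
missing_module = try_load(blob)
```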