Torch - nn - Save/Load & STATE_DICT
Learnable parameters (i.e. weights & biases) of a torch.nn.Module are accessed via model.parameters().
state_dict
a Python dictionary object
maps each layer to its parameter tensor
only layers with learnable parameters get an entry
Optimizer objects (torch.optim) also have their own state_dict
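A minimal sketch of that mapping (the tiny Sequential model below is made up for illustration): a parameter-free layer such as ReLU contributes nothing to the state_dict.
import torch.nn as nn

tiny = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
print(list(tiny.state_dict().keys()))
# ['0.weight', '0.bias']  -> only the Linear layer has learnable parameters;
# the ReLU gets no entry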
# Imports needed for the example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Output:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [
{
'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]
}
]
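Since the optimizer has a state_dict of its own, a common pattern is to save both state_dicts together when checkpointing mid-training. A sketch, assuming the objects from the example above (the 'checkpoint.tar' filename and the dictionary keys are arbitrary choices, not a fixed API):
# Save a training checkpoint: both state_dicts in one file
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.tar')

# Resume later: rebuild the objects first, then restore their states
checkpoint = torch.load('checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])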
Saving the model’s learned parameters = saving its state_dict with torch.save()
Save/Load state_dict (Recommended)
Save (the state_dict is serialized with pickle):
torch.save(model.state_dict(), PATH)
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
Remember to call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference.
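A hedged side note, assuming the state_dict was saved on GPU and is being loaded on a CPU-only machine: torch.load accepts a map_location argument to remap the tensors at load time.
# Load a GPU-saved state_dict onto the CPU
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
model.eval()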
Drawback of save/load entire model
## Just an example, DON'T USE IT.
## save
# torch.save(model, PATH)
## load
# model = torch.load(PATH)
# model.eval()
Pickle does not save the model class itself; it saves a path to the file containing the class, which is looked up at load time. The serialized data is therefore bound to the specific classes and the exact directory structure used when the model was saved.
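To illustrate the drawback under an assumed scenario (a fresh script that never defines or imports TheModelClass; the exact error message may vary by version):
# In a new script that does NOT define or import TheModelClass:
import torch

model = torch.load(PATH)
# -> AttributeError: Can't get attribute 'TheModelClass' on <module '__main__'>
# Unpickling looks the class up by its saved module path, so the class
# definition must be importable at load time.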