Torch - nn - Save/Load & STATE_DICT

Learnable parameters (i.e. weights and biases) of a torch.nn.Module model are accessed with model.parameters()

state_dict

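  1. A Python dictionary object (an OrderedDict).

  2. Maps each parameter name (e.g. conv1.weight) to its parameter tensor.

  3. Has entries only for layers with learnable parameters, plus registered buffers such as batch normalization's running_mean.

  4. Optimizer objects (torch.optim) also have a state_dict, holding the optimizer's state and the hyperparameters used.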
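The points above can be checked on a small model. This is a minimal sketch, assuming a hypothetical `TinyNet` class that is not part of these notes; it only shows that the pooling layer, which has no learnable parameters, contributes no entries.

```python
import torch.nn as nn

# Hypothetical toy model (not from the notes): one conv layer,
# one parameter-free pooling layer, one linear layer.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)
        self.pool = nn.MaxPool2d(2, 2)  # no learnable parameters
        self.fc = nn.Linear(8, 4)

tiny = TinyNet()
sd = tiny.state_dict()

# Keys are "<layer name>.<parameter name>"; "pool" does not appear.
keys = list(sd.keys())
print(type(sd).__name__)
print(keys)

# Total number of learnable scalars, counted via model.parameters().
num_params = sum(p.numel() for p in tiny.parameters())
print(num_params)
```

Because the keys are plain strings, a state_dict can be inspected, filtered, or renamed like any other dictionary before it is saved or loaded.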

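# Imports needed by the snippets below
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model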
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Output:

Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [
    {
        'lr': 0.001, 
        'momentum': 0.9, 
        'dampening': 0, 
        'weight_decay': 0, 
        'nesterov': False, 
        'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]
    }
]
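The optimizer's `state` entry is empty in the output above because no optimization step has run yet. The sketch below, assuming a hypothetical single `nn.Linear` layer as a stand-in model, illustrates how SGD with momentum fills that entry with per-parameter momentum buffers after one step.

```python
import torch
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(4, 2)  # hypothetical stand-in model: weight + bias
opt = optim.SGD(layer.parameters(), lr=0.001, momentum=0.9)

# Before any step, the optimizer holds no per-parameter state.
state_before = len(opt.state_dict()["state"])

# One forward/backward pass and one update on random data.
loss = layer(torch.randn(3, 4)).sum()
loss.backward()
opt.step()

opt_sd = opt.state_dict()
state_after = len(opt_sd["state"])

# With momentum != 0, each parameter's state holds a momentum buffer.
has_buffers = all("momentum_buffer" in s for s in opt_sd["state"].values())
group_keys = set(opt_sd["param_groups"][0].keys())

print(state_before, state_after, has_buffers)
```

This is why the optimizer's state_dict matters when resuming training: the hyperparameters live in `param_groups`, while the accumulated buffers live in `state`.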

To save a model's learned parameters, save its state_dict with torch.save().

Save (torch.save serializes the object with pickle):

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

model.eval()
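The save and load steps above can be sketched as one round trip. This is an illustration under assumptions: `make_model` and the file name `model_state.pt` are hypothetical, and a temporary directory stands in for PATH.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Hypothetical small architecture; the saving and loading side
    # must construct the same architecture for the keys to match.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

source = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pt")  # assumed file name
    torch.save(source.state_dict(), path)

    restored = make_model()  # fresh instance with different random weights
    result = restored.load_state_dict(torch.load(path))

restored.eval()

# Both models now hold identical weights, so their outputs match.
x = torch.randn(5, 4)
with torch.no_grad():
    same_output = torch.equal(source(x), restored(x))

print(result.missing_keys, result.unexpected_keys)
print(same_output)
```

Note that load_state_dict takes a dictionary, not a path, which is why the file is first passed through torch.load.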

Call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Skipping this step yields inconsistent inference results.
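A short sketch of why the mode matters, using a standalone `nn.Dropout` layer as an assumed stand-in for a full model: in training mode the layer randomly zeroes elements and rescales the rest, while in evaluation mode it passes its input through unchanged.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
# Training mode: each element is zeroed with probability p,
# survivors are scaled by 1 / (1 - p) = 2.
train_out = drop(x)
train_values = set(train_out.tolist())

drop.eval()
# Evaluation mode: dropout is a no-op.
eval_out = drop(x)
eval_is_identity = torch.equal(eval_out, x)

print(sorted(train_values))
print(eval_is_identity)
```

Calling model.eval() on a full model applies the same switch recursively to every submodule.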

Drawbacks of saving/loading the entire model

## Just an example, DON'T USE IT.

## save
# torch.save(model, PATH)

## load
# model = torch.load(PATH)
# model.eval()
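
  • pickle does not save the model class itself. It saves a path to the file containing the class, which is used at load time.

  • As a result, the serialized data is bound to the specific classes and the exact directory structure used when the model was saved, so loading can break in other projects or after refactoring.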
