Torch - nn - Save/Load & STATE_DICT

The learnable parameters (i.e. weights and biases) of a torch.nn.Module model are accessed with model.parameters()

state_dict

  1. a Python dictionary object

  2. maps each parameter name to its tensor

  3. contains entries only for layers with learnable parameters (and registered buffers, e.g. batchnorm running stats)

  4. Optimizer objects (torch.optim) also have a state_dict, holding the optimizer's state and hyperparameters

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Printing the state_dict shows each parameter name and its tensor's shape.
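A sketch of what a state_dict contains, using a small stand-in model instead of TheModelClass above (the layer names and shapes here are illustrative):

```python
import torch.nn as nn
import torch.optim as optim

# Small stand-in model; the document's TheModelClass behaves the same way
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Each entry maps a parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())
# 0.weight   torch.Size([3, 4])
# 0.bias     torch.Size([3])
# 2.weight   torch.Size([2, 3])
# 2.bias     torch.Size([2])

# The optimizer's state_dict holds hyperparameters and per-parameter state
print(optimizer.state_dict()["param_groups"][0]["lr"])  # 0.001
```

Note that the ReLU layer does not appear: it has no learnable parameters.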

Saving the model's learned parameters means saving its state_dict with torch.save()

Save (torch.save serializes the state_dict with Python's pickle module):
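A minimal sketch (the filename model_weights.pth is an assumption; .pt and .pth are the common conventions):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the trained model

# torch.save serializes the state_dict with Python's pickle module
torch.save(model.state_dict(), "model_weights.pth")  # filename is an assumption
```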

Load:
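Loading mirrors saving: instantiate the model class first, then restore the parameters into it (filename assumed as in the save sketch):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "model_weights.pth")  # assumed filename

# Load: create an instance of the same architecture, then fill in the weights
new_model = nn.Linear(4, 2)
new_model.load_state_dict(torch.load("model_weights.pth"))
new_model.eval()  # evaluation mode before inference
```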

model.eval()

Call model.eval() before running inference to set dropout and batch normalization layers to evaluation mode; skipping this step yields inconsistent inference results.
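The effect is easiest to see with a dropout layer: in training mode dropout randomly zeroes elements, while eval() makes it pass the input through unchanged. A small sketch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(5)

drop.train()   # training mode: elements randomly zeroed, survivors scaled by 1/(1-p)
print(drop(x))

drop.eval()    # evaluation mode: dropout disabled, identity function
print(drop(x))  # tensor([1., 1., 1., 1., 1.])
```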

Drawbacks of saving/loading the entire model

  • Pickle does not serialize the model class itself; it saves only a reference to the file that defines the class.

  • The serialized data is therefore bound to the specific classes and the exact directory structure used when the model was saved, so loading can break after refactoring or in a different project.
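For contrast, a sketch of saving/loading the entire module (filename assumed; weights_only=False is needed on newer PyTorch versions, where torch.load defaults to loading plain tensors only):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Pickles a reference to the class, not the class definition itself,
# so the same class must be importable when loading
torch.save(model, "model_full.pth")  # assumed filename

# weights_only=False allows unpickling arbitrary objects (PyTorch >= 2.6 defaults to True)
loaded = torch.load("model_full.pth", weights_only=False)
loaded.eval()
```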
