Torch - nn - Save/Load & STATE_DICT

Learnable parameters (i.e. weights and biases) of a torch.nn.Module are accessed via model.parameters()

state_dict

  1. A Python dictionary object

  2. maps each layer to its learnable parameter tensors

  3. only layers with learnable parameters (and registered buffers) have entries

  4. Optimizer objects (torch.optim) also have a state_dict, containing the optimizer's state and hyperparameters

# Imports needed for the example
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Print the state_dict to inspect the learned parameters:
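Each state_dict can be printed entry by entry. The sketch below uses a small stand-in model (hypothetical; the document's TheModelClass is inspected the same way):

```python
import torch.nn as nn
import torch.optim as optim

# Minimal stand-in model; TheModelClass above works the same way.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Each entry maps "layer.parameter" to its tensor.
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())
# 0.weight     torch.Size([3, 4])
# 0.bias       torch.Size([3])
# 2.weight     torch.Size([2, 3])
# 2.bias       torch.Size([2])

# The optimizer's state_dict holds its state and hyperparameter groups.
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)
```

Note the ReLU layer has no entry: only layers with learnable parameters appear.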

To save a model’s learned parameters, save its state_dict with torch.save().

Save (the state_dict is serialized with pickle):
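A minimal save sketch; the model and PATH are placeholders:

```python
import torch
import torch.nn as nn

# Stand-in model; TheModelClass is saved the same way.
model = nn.Linear(4, 2)

PATH = "model_weights.pt"  # hypothetical filename
# torch.save pickles the state_dict (a dict of tensors) to disk.
torch.save(model.state_dict(), PATH)
```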

Load:
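Loading is a sketch in two steps: instantiate the model class, then fill in its parameters with load_state_dict(). The saved file is recreated here so the example runs standalone; PATH is a placeholder:

```python
import torch
import torch.nn as nn

PATH = "model_weights.pt"  # hypothetical filename

# Assume a state_dict was saved earlier; recreated so the sketch is standalone.
saved = nn.Linear(4, 2)
torch.save(saved.state_dict(), PATH)

# 1. Instantiate the model class first.
model = nn.Linear(4, 2)
# 2. load_state_dict expects a dict, so deserialize with torch.load first.
model.load_state_dict(torch.load(PATH))
model.eval()  # set dropout/batch-norm layers to evaluation mode
```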

model.eval()

Call model.eval() before running inference to set dropout and batch normalization layers to evaluation mode; skipping this step yields inconsistent results.
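The effect is easiest to see with a dropout layer: in evaluation mode it becomes a no-op (a small sketch):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()   # training mode: elements are randomly zeroed and rescaled
drop.eval()    # evaluation mode: dropout passes the input through unchanged
print(drop(x))  # tensor([1., 1., 1., 1.])
```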

Drawback of saving/loading the entire model

  • Pickle does not save the model class itself; it saves a path to the file containing the class.

  • The serialized data is therefore bound to the specific classes and the exact directory structure used when the model was saved, so loading can break in other projects or after refactoring.
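For contrast, a sketch of the whole-model approach (PATH is a placeholder; recent PyTorch versions require weights_only=False to unpickle a full module):

```python
import torch
import torch.nn as nn

PATH = "entire_model.pt"  # hypothetical filename

model = nn.Linear(4, 2)
# Saving the whole module pickles a *reference* to its class, not the code,
# so loading requires the same class to be importable at the same location.
torch.save(model, PATH)

loaded = torch.load(PATH, weights_only=False)
loaded.eval()
```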
