Day 5: PyTorch Basics - Saving and Loading Models

Previously, we learned how to train a neural network model. But how do we actually make predictions with it? And how can we use open-source models that other people trained? In this post we learn how to store and load the weights that the model learned during training, and how to use open-source models and weights.
100 Days Of PyTorch
Author

Liam Groen

Published

May 12, 2025

Keywords

Save PyTorch model, Load PyTorch model, Open-Source Models, Pre Trained Model

This post is part of a series on deploying models.

How to Save a PyTorch Model?

Let’s reuse our model training code from the previous blog post.

import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

class OurNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # Define neural net structure here, so we can store weights in them.
        self.flatten = nn.Flatten()
        self.linear_relu_chain = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    
    def forward(self, input):
        # Use neural net structure to pass input data through

        input = self.flatten(input) # Shape: (28,28) -> shape: (784)

        predictions = self.linear_relu_chain(input) # Shape: (784) -> shape: (512) -> shape: (512) -> shape: (10)
        
        return predictions

def train_loop(dataloader, model, loss_func, optimizer):
    size = len(dataloader.dataset)
    model.train() # Set model to training mode

    # Update parameters each new batch
    for batch, (images, labels) in enumerate(dataloader):
        model_predictions = model(images)
        loss = loss_func(model_predictions, labels)

        # Compute gradient with backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Something to look at while model trains
        if batch % 100 == 0:
            # Read the batch size from the dataloader instead of relying on a global variable
            loss, current = loss.item(), batch * dataloader.batch_size + len(images)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def checking_loop(dataloader, model, loss_func):
    size = len(dataloader.dataset)
    number_of_batches = len(dataloader)
    test_loss, correct_amount = 0, 0

    model.eval() # Set model to check/test mode

    with torch.no_grad(): # Turn off gradient tracking: we only evaluate here, so skipping it saves memory and time.

        # This dataloader contains the test images
        for images, labels in dataloader:
            model_predictions = model(images)
            
            loss = loss_func(model_predictions, labels).item()
            test_loss += loss

            predicted_labels = nn.Softmax(dim=1)(model_predictions).argmax(1)
            correct = predicted_labels == labels
            # Turn every 'True' into a 1, and sum over them, converting the resulting tensor to a python integer
            correct_amount += correct.type(torch.float).sum().item()

    test_loss /= number_of_batches
    correct_amount /= size
    print(f"Test Error: \n Accuracy: {(100*correct_amount):>0.1f}%, Avg loss: {test_loss:>8f} \n")


model = OurNeuralNetwork().to("cpu")

learning_rate = 1e-3
batch_size = 64
epochs = 5

loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

Training the model again:

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_function, optimizer)
    checking_loop(test_dataloader, model, loss_function)
print("Done!")

Epoch 1
-------------------------------
loss: 1.469673  [   64/60000]
loss: 1.433510  [ 6464/60000]
loss: 1.416146  [12864/60000]
loss: 1.377274  [19264/60000]
loss: 1.353796  [25664/60000]
loss: 1.429164  [32064/60000]
loss: 1.270087  [38464/60000]
loss: 1.251621  [44864/60000]
loss: 1.361366  [51264/60000]
loss: 1.273360  [57664/60000]
Test Error: 
 Accuracy: 63.7%, Avg loss: 1.244470 

Epoch 2
-------------------------------
loss: 1.250782  [   64/60000]
loss: 1.358424  [ 6464/60000]
loss: 1.121522  [12864/60000]
loss: 1.036822  [19264/60000]
loss: 1.188575  [25664/60000]
loss: 1.127103  [32064/60000]
loss: 1.138114  [38464/60000]
loss: 1.059073  [44864/60000]
loss: 1.033557  [51264/60000]
loss: 1.098557  [57664/60000]
Test Error: 
 Accuracy: 65.0%, Avg loss: 1.083130 

Epoch 3
-------------------------------
loss: 1.070835  [   64/60000]
loss: 1.022191  [ 6464/60000]
loss: 1.005594  [12864/60000]
loss: 0.993012  [19264/60000]
loss: 1.047417  [25664/60000]
loss: 1.001495  [32064/60000]
loss: 1.095251  [38464/60000]
loss: 0.926997  [44864/60000]
loss: 0.960782  [51264/60000]
loss: 0.937367  [57664/60000]
Test Error: 
 Accuracy: 66.1%, Avg loss: 0.979457 

Epoch 4
-------------------------------
loss: 0.900051  [   64/60000]
loss: 1.099103  [ 6464/60000]
loss: 1.052053  [12864/60000]
loss: 0.843110  [19264/60000]
loss: 0.914962  [25664/60000]
loss: 1.017330  [32064/60000]
loss: 0.707650  [38464/60000]
loss: 0.890666  [44864/60000]
loss: 1.078490  [51264/60000]
loss: 0.758047  [57664/60000]
Test Error: 
 Accuracy: 67.3%, Avg loss: 0.909492 

Epoch 5
-------------------------------
loss: 0.935071  [   64/60000]
loss: 0.930360  [ 6464/60000]
loss: 0.886458  [12864/60000]
loss: 0.747989  [19264/60000]
loss: 0.919060  [25664/60000]
loss: 0.857149  [32064/60000]
loss: 0.808115  [38464/60000]
loss: 0.957309  [44864/60000]
loss: 0.915866  [51264/60000]
loss: 1.035016  [57664/60000]
Test Error: 
 Accuracy: 68.0%, Avg loss: 0.857042 

Epoch 6
-------------------------------
loss: 0.711309  [   64/60000]
loss: 0.731404  [ 6464/60000]
loss: 0.778495  [12864/60000]
loss: 0.826608  [19264/60000]
loss: 0.690381  [25664/60000]
loss: 0.793883  [32064/60000]
loss: 1.049005  [38464/60000]
loss: 0.860935  [44864/60000]
loss: 0.850578  [51264/60000]
loss: 0.894870  [57664/60000]
Test Error: 
 Accuracy: 69.2%, Avg loss: 0.820537 

Epoch 7
-------------------------------
loss: 0.751307  [   64/60000]
loss: 0.690765  [ 6464/60000]
loss: 0.885832  [12864/60000]
loss: 0.810388  [19264/60000]
loss: 0.656271  [25664/60000]
loss: 0.795354  [32064/60000]
loss: 0.873639  [38464/60000]
loss: 0.952544  [44864/60000]
loss: 0.621379  [51264/60000]
loss: 0.782824  [57664/60000]
Test Error: 
 Accuracy: 70.0%, Avg loss: 0.791091 

Epoch 8
-------------------------------
loss: 0.706885  [   64/60000]
loss: 0.791194  [ 6464/60000]
loss: 0.665691  [12864/60000]
loss: 0.586563  [19264/60000]
loss: 0.746921  [25664/60000]
loss: 0.670890  [32064/60000]
loss: 0.818113  [38464/60000]
loss: 0.725863  [44864/60000]
loss: 0.793836  [51264/60000]
loss: 0.689501  [57664/60000]
Test Error: 
 Accuracy: 71.0%, Avg loss: 0.764419 

Epoch 9
-------------------------------
loss: 0.579552  [   64/60000]
loss: 0.783948  [ 6464/60000]
loss: 0.766569  [12864/60000]
loss: 0.831361  [19264/60000]
loss: 0.964704  [25664/60000]
loss: 0.772870  [32064/60000]
loss: 0.836838  [38464/60000]
loss: 0.806005  [44864/60000]
loss: 0.795276  [51264/60000]
loss: 0.934505  [57664/60000]
Test Error: 
 Accuracy: 72.8%, Avg loss: 0.744880 

Epoch 10
-------------------------------
loss: 0.493929  [   64/60000]
loss: 0.790016  [ 6464/60000]
loss: 0.750857  [12864/60000]
loss: 0.762535  [19264/60000]
loss: 0.822756  [25664/60000]
loss: 0.723158  [32064/60000]
loss: 0.535035  [38464/60000]
loss: 0.708430  [44864/60000]
loss: 0.695287  [51264/60000]
loss: 0.616080  [57664/60000]
Test Error: 
 Accuracy: 73.5%, Avg loss: 0.724422 

Done!

The learned weights are available through the model's state_dict() method, which returns a dictionary that maps each parameter name to its tensor. We can save these weights to a file and reuse them later by passing that dictionary to torch.save().

torch.save(model.state_dict(), "model_weights.pth")

How to Load a Saved Model?

To load the saved model, we first need to create a new instance of the same model class. The weights alone won't do us any good: they have to match the structure of the model they are loaded into. After instantiating a new model, we can copy the weights into it.

weights = torch.load("model_weights.pth", weights_only=True)

model_from_weights = OurNeuralNetwork()
model_from_weights.load_state_dict(weights)
model_from_weights.eval()
OurNeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_chain): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

We now have an exact copy of the neural network in a new variable!
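To convince ourselves that the copy really behaves the same, we can compare the predictions of both models on the same input. The snippet below is a minimal sketch, not the code from this series: make_model is a hypothetical helper that rebuilds the same layer layout, the weights are untrained, the images are random stand-ins, and the file goes to a temporary directory. It also shows what happens when the weights are loaded into a model with a different structure.

```python
import os
import tempfile

import torch
from torch import nn


def make_model(hidden=512):
    # Hypothetical helper: same layer layout as OurNeuralNetwork,
    # rebuilt here so the sketch runs on its own
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 10),
    )


original = make_model()
original.eval()

# Save the weights to a temporary file and load them into a fresh instance
path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
torch.save(original.state_dict(), path)

restored = make_model()
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()

# Random stand-in for a batch of four 28x28 grayscale images
images = torch.rand(4, 1, 28, 28)

with torch.no_grad():
    same_outputs = torch.equal(original(images), restored(images))
    # The index of the largest output value is the predicted class
    predicted_labels = restored(images).argmax(dim=1)

print("Identical outputs:", same_outputs)
print("Predicted class indices:", predicted_labels.tolist())

# Loading into a different structure fails because the tensor shapes differ
try:
    make_model(hidden=256).load_state_dict(torch.load(path, weights_only=True))
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

print("Mismatch detected:", mismatch_detected)
```

With the trained model from this post, the same argmax call on real FashionMNIST images gives the predicted clothing class for each image.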

If we want, we can inspect all the parameters (e.g. to create visualizations that explain the model) through the model's state_dict().

layer_2_bias = model.state_dict()['linear_relu_chain.2.bias'] # We can access stored parameters through the state_dict keys
print("stored parameters in the form of 'layer_num.type': ", model.state_dict().keys(), '\n')
print("Amount of values in the layer 2 bias:", layer_2_bias.shape, '\n') 
print("First 3 biases in layer 2:", layer_2_bias[:3])
stored parameters in the form of 'layer_num.type': 
[
    'linear_relu_chain.0.weight',
    'linear_relu_chain.0.bias',
    'linear_relu_chain.2.weight',
    'linear_relu_chain.2.bias',
    'linear_relu_chain.4.weight',
    'linear_relu_chain.4.bias',
]
Amount of values in the layer 2 bias: torch.Size([512]) 

First 3 biases in layer 2: tensor([0.0234, 0.0045, 0.0241])
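Because the state_dict is an ordinary dictionary, we can also loop over it, for example to count how many values the network has learned. The following is a small sketch under the assumption that we only care about shapes: it uses a randomly initialized stand-in with the same layers as linear_relu_chain, so the numbers inside the tensors are not the trained ones.

```python
from torch import nn

# Stand-in with the same layers as linear_relu_chain (random, untrained weights)
layers = nn.Sequential(
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

total_values = 0
for name, tensor in layers.state_dict().items():
    # numel() returns the number of individual values stored in the tensor
    print(f"{name:<10} shape={tuple(tensor.shape)} values={tensor.numel()}")
    total_values += tensor.numel()

print("Total learned values:", total_values)
```

The ReLU layers do not show up in the loop because they have no parameters to learn. Only the three Linear layers contribute, which adds up to 669,706 values for this structure.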

How to Use Pre-Trained Models?

We don’t have to train every model that we want to use ourselves. Often, a better model that was trained for longer on more data is freely available online. The torchvision package ships with ready-made model structures and pre-trained weights for many of these models.

from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)  # Pass weights by keyword; the positional form is deprecated

model.eval() 

We can now use the ResNet50 model with the best available weights in our code.

Warmstarting / Transfer Learning

Suppose we have a dataset of domain-specific images and want to train an image recognition model on it. We don’t have to start from scratch. We can define the model structure that we want and use imported weights to initialize the training. This way the model does not have to learn again what an image looks like in general. That knowledge is already embedded in the downloaded weights, so the model only needs to learn to recognize the domain-specific images. This PyTorch article explains how to do this.

Further Reading