.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/basics/saveloadrun_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_saveloadrun_tutorial.py:


`Learn the Basics `_ ||
`Quickstart `_ ||
`Tensors `_ ||
`Datasets & DataLoaders `_ ||
`Transforms `_ ||
`Build Model `_ ||
`Autograd `_ ||
`Optimization `_ ||
**Save & Load Model**

Save and Load the Model
============================

In this section, we will look at how to persist model state by saving and loading it, and how to prepare a loaded model for running predictions.

.. GENERATED FROM PYTHON SOURCE LINES 17-22

.. code-block:: default


    import torch
    import torchvision.models as models

.. GENERATED FROM PYTHON SOURCE LINES 23-28

Saving and Loading Model Weights
--------------------------------
PyTorch models store their learned parameters in an internal state dictionary called ``state_dict``. It can be persisted with the ``torch.save`` function:

.. GENERATED FROM PYTHON SOURCE LINES 28-32

.. code-block:: default


    model = models.vgg16(weights='IMAGENET1K_V1')
    torch.save(model.state_dict(), 'model_weights.pth')

.. GENERATED FROM PYTHON SOURCE LINES 33-35

To load model weights, first create an instance of the same model, then load the parameters with the ``load_state_dict()`` method. Passing ``weights_only=True`` to ``torch.load`` restricts unpickling to tensors and other basic types, which is the safer choice when all you need is the weights.

.. GENERATED FROM PYTHON SOURCE LINES 35-40

.. code-block:: default


    model = models.vgg16()  # weights are not specified, so this creates an untrained model
    model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
    model.eval()

.. GENERATED FROM PYTHON SOURCE LINES 41-42

.. note:: Be sure to call ``model.eval()`` before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
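
As a minimal sketch of the save/load round trip described above, the snippet below uses a small hypothetical ``nn.Sequential`` stand-in instead of VGG16 (so nothing is downloaded) and an in-memory ``io.BytesIO`` buffer instead of a file on disk. It lists the keys of the ``state_dict``, copies the saved parameters into a freshly built model of the same architecture, and checks that both models produce identical outputs once they are in evaluation mode. The ``make_model`` helper and the layer sizes are illustrative assumptions, not part of the tutorial.

.. code-block:: python

    import io

    import torch
    from torch import nn


    def make_model() -> nn.Module:
        # Hypothetical stand-in architecture; the Dropout layer behaves
        # differently in training mode and evaluation mode.
        return nn.Sequential(
            nn.Linear(4, 8),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(8, 2),
        )


    torch.manual_seed(0)
    trained = make_model()

    # The state_dict maps parameter names to tensors; ReLU and Dropout have none.
    keys = list(trained.state_dict().keys())
    print(keys)  # ['0.weight', '0.bias', '3.weight', '3.bias']

    # Save the parameters to an in-memory buffer instead of a file.
    buffer = io.BytesIO()
    torch.save(trained.state_dict(), buffer)
    buffer.seek(0)

    # Rebuild the same architecture, then copy the saved parameters into it.
    restored = make_model()
    restored.load_state_dict(torch.load(buffer, weights_only=True))

    # Switch both models to evaluation mode so Dropout is disabled.
    trained.eval()
    restored.eval()

    x = torch.randn(3, 4)
    with torch.no_grad():
        outputs_match = torch.equal(trained(x), restored(x))
    print(outputs_match)  # True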

.. GENERATED FROM PYTHON SOURCE LINES 44-49

Saving and Loading Models with Shapes
-------------------------------------
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass ``model`` (and not ``model.state_dict()``) to the saving function:

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: default


    torch.save(model, 'model.pth')

.. GENERATED FROM PYTHON SOURCE LINES 53-54

We can then load the model like this. Because the file holds a pickled Python object rather than only tensors, ``weights_only=False`` is required on PyTorch 2.6 and later, where ``torch.load`` defaults to ``weights_only=True``. Only load such files from sources you trust, since unpickling can execute arbitrary code.

.. GENERATED FROM PYTHON SOURCE LINES 54-57

.. code-block:: default


    model = torch.load('model.pth', weights_only=False)

.. GENERATED FROM PYTHON SOURCE LINES 58-59

.. note:: This approach uses the Python `pickle `_ module when serializing the model, so it relies on the actual class definition being available when loading the model.

.. GENERATED FROM PYTHON SOURCE LINES 61-64

Related Tutorials
-----------------
`Saving and Loading a General Checkpoint in PyTorch `_


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_beginner_basics_saveloadrun_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: saveloadrun_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: saveloadrun_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_