.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/transformer_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_transformer_tutorial.py:


Language Modeling with ``nn.Transformer`` and torchtext
===============================================================

This is a tutorial on training a model to predict the next word in a sequence using the
`nn.Transformer `__ module.

The PyTorch 1.2 release includes a standard transformer module based on the
paper `Attention is All You Need `__.
Compared to Recurrent Neural Networks (RNNs), the transformer model has proven
to be superior in quality for many sequence-to-sequence tasks while being more
parallelizable. The ``nn.Transformer`` module relies entirely on an attention
mechanism (implemented as
`nn.MultiheadAttention `__)
to draw global dependencies between input and output. The ``nn.Transformer``
module is highly modular, so a single component (e.g.,
`nn.TransformerEncoder `__)
can easily be adapted or composed.

.. image:: ../_static/img/transformer_architecture.jpg


Define the model
----------------

In this tutorial, we train a ``nn.TransformerEncoder`` model on a
language modeling task. Note that this tutorial does not cover
the training of `nn.TransformerDecoder `__,
which is depicted in the right half of the diagram above.
The language modeling task is to assign a probability to the likelihood of a
given word (or a sequence of words) following a sequence of words.
A sequence of tokens is passed to the embedding layer first, followed by a
positional encoding layer that accounts for the order of the words (the
``PositionalEncoding`` module is described in more detail below).
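
As a rough, self-contained illustration of the first step described above, the
sketch below feeds a toy batch of token ids through an ``nn.Embedding`` layer
(scaled by :math:`\sqrt{d_{model}}`, as ``TransformerModel`` does) and then,
purely for illustration, projects straight to one score per vocabulary word and
applies a softmax to get a next-word distribution at each position. The toy
sizes and the ``to_logits`` layer, which stands in for the positional encoding
and the ``nn.TransformerEncoder`` stack, are hypothetical assumptions and not
part of the tutorial's model.

.. code-block:: default

    import math

    import torch
    from torch import nn

    torch.manual_seed(0)  # make the random toy example repeatable

    # Hypothetical toy sizes, chosen only to keep the shapes readable
    ntoken, d_model = 50, 8        # vocabulary size, embedding dimension
    seq_len, batch_size = 4, 2

    # A batch of token ids laid out as ``[seq_len, batch_size]``, as in this tutorial
    tokens = torch.randint(0, ntoken, (seq_len, batch_size))

    # Each id becomes a ``d_model``-dimensional vector, scaled by sqrt(d_model)
    embedding = nn.Embedding(ntoken, d_model)
    x = embedding(tokens) * math.sqrt(d_model)          # [seq_len, batch_size, d_model]

    # Hypothetical stand-in for the rest of the model: one score per vocabulary word
    to_logits = nn.Linear(d_model, ntoken)
    logits = to_logits(x)                               # [seq_len, batch_size, ntoken]

    # Softmax over the last dimension gives a next-word distribution at each position
    probs = torch.softmax(logits, dim=-1)

    print(x.shape)       # torch.Size([4, 2, 8])
    print(logits.shape)  # torch.Size([4, 2, 50])

Each slice ``probs[t, b]`` sums to 1, so it can be read as a guess for the word
that follows position ``t`` in column ``b``.
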
The ``nn.TransformerEncoder`` consists of multiple layers of
`nn.TransformerEncoderLayer `__.
Along with the input sequence, a square attention mask is required because, for
language modeling, the self-attention layers in ``nn.TransformerEncoder`` may only
attend to earlier positions in the sequence: any tokens in future positions must be
masked. If no mask is passed to ``forward()``, the model below builds this causal
mask itself. To produce a probability distribution over output words, the output
of the ``nn.TransformerEncoder`` model is passed through a linear layer that
outputs unnormalized logits. The log-softmax function isn't applied here because
`CrossEntropyLoss `__,
used later, expects unnormalized logits as its input.

.. code-block:: default

    import math
    import os
    from tempfile import TemporaryDirectory
    from typing import Optional, Tuple

    import torch
    from torch import nn, Tensor
    from torch.nn import TransformerEncoder, TransformerEncoderLayer
    from torch.utils.data import dataset

    class TransformerModel(nn.Module):

        def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                     nlayers: int, dropout: float = 0.5):
            super().__init__()
            self.model_type = 'Transformer'
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
            self.embedding = nn.Embedding(ntoken, d_model)
            self.d_model = d_model
            self.linear = nn.Linear(d_model, ntoken)

            self.init_weights()

        def init_weights(self) -> None:
            initrange = 0.1
            self.embedding.weight.data.uniform_(-initrange, initrange)
            self.linear.bias.data.zero_()
            self.linear.weight.data.uniform_(-initrange, initrange)

        def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
            """
            Arguments:
                src: Tensor, shape ``[seq_len, batch_size]``
                src_mask: Tensor, shape ``[seq_len, seq_len]``; if ``None``,
                    a causal (square subsequent) mask is generated

            Returns:
                output Tensor of shape ``[seq_len, batch_size, ntoken]``
            """
            src = self.embedding(src) * math.sqrt(self.d_model)
            src = self.pos_encoder(src)
            if src_mask is None:
                # Causal mask: ``-inf`` above the diagonal hides future positions,
                # zeros on and below it leave the attention scores unchanged
                seq_len = src.size(0)
                src_mask = torch.triu(
                    torch.full((seq_len, seq_len), float('-inf'), device=src.device),
                    diagonal=1)
            output = self.transformer_encoder(src, src_mask)
            output = self.linear(output)
            return output

The ``PositionalEncoding`` module injects information about the relative or
absolute position of the tokens in the sequence. The positional encodings have
the same dimension as the embeddings so that the two can be summed. Here, we use
sine and cosine functions of different frequencies: even embedding dimensions
use sine, odd dimensions use cosine, and the frequency decreases geometrically
as the dimension index grows.

.. code-block:: default

    class PositionalEncoding(nn.Module):

        def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)

            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)

        def forward(self, x: Tensor) -> Tensor:
            """
            Arguments:
                x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
            """
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)


Load and batch data
-------------------

This tutorial uses ``torchtext`` to load the Wikitext-2 dataset.
To access torchtext datasets, please install torchdata following the instructions
at https://github.com/pytorch/data.

.. code-block:: bash

    pip install portalocker
    pip install torchdata

The vocab object is built based on the train dataset and is used to numericalize
tokens into tensors. Wikitext-2 represents rare tokens as ``<unk>``.

Given a 1-D vector of sequential data, ``batchify()`` arranges the data
into ``batch_size`` columns. If the data does not divide evenly into
``batch_size`` columns, then the data is trimmed to fit.
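
The trimming rule can be checked on a small tensor. The sketch below copies the
reshaping logic of ``batchify()`` into a hypothetical helper, ``batchify_cpu``
(it omits the ``.to(device)`` call so it runs on any machine), and applies it to
11 token ids with a batch size of 2, so one id does not fit and is dropped.

.. code-block:: default

    import torch

    def batchify_cpu(data: torch.Tensor, bsz: int) -> torch.Tensor:
        """Hypothetical copy of ``batchify()`` without the ``.to(device)`` call."""
        seq_len = data.size(0) // bsz                       # rows that fit completely
        data = data[:seq_len * bsz]                         # trim the leftover tail
        return data.view(bsz, seq_len).t().contiguous()     # shape ``[seq_len, bsz]``

    # 11 token ids (0..10) split into 2 columns: 11 // 2 = 5 rows, so id 10 is dropped
    batched = batchify_cpu(torch.arange(11), 2)
    print(batched.shape)      # torch.Size([5, 2])
    print(batched[:, 0])      # tensor([0, 1, 2, 3, 4])
    print(batched[:, 1])      # tensor([5, 6, 7, 8, 9])

Reading down a column walks forward through the original sequence, so each
column is a contiguous stream of tokens.
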
For instance, with the alphabet as the data (total length of 26)
and ``batch_size=4``, we would divide the alphabet into 4 sequences of length 6;
the last two letters, ``Y`` and ``Z``, do not fit and are trimmed:

.. math::

  \begin{bmatrix}
  \text{A} & \text{B} & \text{C} & \ldots & \text{X} & \text{Y} & \text{Z}
  \end{bmatrix}
  \Rightarrow
  \begin{bmatrix}
  \begin{bmatrix}\text{A} \\ \text{B} \\ \text{C} \\ \text{D} \\ \text{E} \\ \text{F}\end{bmatrix} &
  \begin{bmatrix}\text{G} \\ \text{H} \\ \text{I} \\ \text{J} \\ \text{K} \\ \text{L}\end{bmatrix} &
  \begin{bmatrix}\text{M} \\ \text{N} \\ \text{O} \\ \text{P} \\ \text{Q} \\ \text{R}\end{bmatrix} &
  \begin{bmatrix}\text{S} \\ \text{T} \\ \text{U} \\ \text{V} \\ \text{W} \\ \text{X}\end{bmatrix}
  \end{bmatrix}

Batching enables more parallelizable processing. However, batching means that
the model treats each column independently; for example, the dependence of
``G`` and ``F`` cannot be learned in the example above.

.. code-block:: default

    from torchtext.datasets import WikiText2
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator

    train_iter = WikiText2(split='train')
    tokenizer = get_tokenizer('basic_english')
    vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
    vocab.set_default_index(vocab['<unk>'])

    def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
        """Converts raw text into a flat Tensor."""
        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    # ``train_iter`` was "consumed" by the process of building the vocab,
    # so we have to create it again
    train_iter, val_iter, test_iter = WikiText2()
    train_data = data_process(train_iter)
    val_data = data_process(val_iter)
    test_data = data_process(test_iter)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def batchify(data: Tensor, bsz: int) -> Tensor:
        """Divides the data into ``bsz`` separate sequences, removing extra elements
        that wouldn't cleanly fit.

        Arguments:
            data: Tensor, shape ``[N]``
            bsz: int, batch size

        Returns:
            Tensor of shape ``[N // bsz, bsz]``
        """
        seq_len = data.size(0) // bsz
        data = data[:seq_len * bsz]
        data = data.view(bsz, seq_len).t().contiguous()
        return data.to(device)

    batch_size = 20
    eval_batch_size = 10
    train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
    val_data = batchify(val_data, eval_batch_size)
    test_data = batchify(test_data, eval_batch_size)


Functions to generate input and target sequences
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``get_batch()`` generates a pair of input-target sequences for
the transformer model. It subdivides the source data into chunks of
length ``bptt``. For the language modeling task, the target for each input
token is the token that follows it, so the target sequence is the input
sequence shifted forward by one position. For example, with a ``bptt`` value
of 2, we'd get the following two tensors for ``i`` = 0:

.. image:: ../_static/img/transformer_input_target.png

Note that the chunks are along dimension 0, consistent with the
``S`` (sequence) dimension in the Transformer model. The batch dimension
``N`` is along dimension 1.

.. code-block:: default

    bptt = 35
    def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
        """
        Args:
            source: Tensor, shape ``[full_seq_len, batch_size]``
            i: int

        Returns:
            tuple (data, target), where data has shape ``[seq_len, batch_size]`` and
            target has shape ``[seq_len * batch_size]``
        """
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].reshape(-1)
        return data, target


Initiate an instance
--------------------

The model hyperparameters are defined below. The ``vocab`` size is
equal to the length of the vocab object.

.. code-block:: default

    ntokens = len(vocab)  # size of vocabulary
    emsize = 200  # embedding dimension
    d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
    nlayers = 2  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
    nhead = 2  # number of heads in ``nn.MultiheadAttention``
    dropout = 0.2  # dropout probability
    model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)


Run the model
-------------

We use `CrossEntropyLoss `__
with the `SGD `__
(stochastic gradient descent) optimizer. The learning rate is initially set to
5.0 and follows a `StepLR `__
schedule that multiplies it by 0.95 after every epoch. During training, we use
`nn.utils.clip_grad_norm\_ `__
to prevent gradients from exploding.

.. code-block:: default

    import time

    criterion = nn.CrossEntropyLoss()
    lr = 5.0  # learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    def train(model: nn.Module) -> None:
        model.train()  # turn on train mode
        total_loss = 0.
        log_interval = 200
        start_time = time.time()

        num_batches = len(train_data) // bptt
        for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
            data, targets = get_batch(train_data, i)
            output = model(data)
            output_flat = output.view(-1, ntokens)
            loss = criterion(output_flat, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss += loss.item()
            if batch % log_interval == 0 and batch > 0:
                lr = scheduler.get_last_lr()[0]
                ms_per_batch = (time.time() - start_time) * 1000 / log_interval
                cur_loss = total_loss / log_interval
                ppl = math.exp(cur_loss)
                print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                      f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                      f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
                total_loss = 0
                start_time = time.time()

    def evaluate(model: nn.Module, eval_data: Tensor) -> float:
        model.eval()  # turn on evaluation mode
        total_loss = 0.
        with torch.no_grad():
            for i in range(0, eval_data.size(0) - 1, bptt):
                data, targets = get_batch(eval_data, i)
                seq_len = data.size(0)
                output = model(data)
                output_flat = output.view(-1, ntokens)
                total_loss += seq_len * criterion(output_flat, targets).item()
        return total_loss / (len(eval_data) - 1)

Loop over epochs. Save the model if the validation loss is the best
we've seen so far. Adjust the learning rate after each epoch.

.. code-block:: default

    best_val_loss = float('inf')
    epochs = 3

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            train(model)  # ``train()`` reads the global ``epoch`` for its progress log
            val_loss = evaluate(model, val_data)
            val_ppl = math.exp(val_loss)
            elapsed = time.time() - epoch_start_time
            print('-' * 89)
            print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
                f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
            print('-' * 89)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), best_model_params_path)

            scheduler.step()
        model.load_state_dict(torch.load(best_model_params_path))  # load best model states


Evaluate the best model on the test dataset
-------------------------------------------

.. code-block:: default

    test_loss = evaluate(model, test_data)
    test_ppl = math.exp(test_loss)
    print('=' * 89)
    print(f'| End of training | test loss {test_loss:5.2f} | '
          f'test ppl {test_ppl:8.2f}')
    print('=' * 89)