NestedTensors

NestedTensors are similar to regular tensors, except for their shape:

  • for a regular tensor, each dimension has a size

  • for a nestedtensor, not all dimensions have regular sizes; some of them are jagged

Nestedtensors are a natural solution for representing sequential data within various domains:

  • in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor

  • in CV, images can have variable shapes, so a batch of images forms a nestedtensor

In this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness for operating on sequential data of varying lengths with a real-world example.

NestedTensors are currently a prototype feature and are subject to change.

import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NestedTensor Initialization

From the Python frontend, a nestedtensor can be created from a list of tensors. We denote nt[i] as the ith tensor component of a nestedtensor.

nt = torch.nested.nested_tensor([torch.arange(12).reshape(2, 6),
                                 torch.arange(18).reshape(3, 6)],
                                dtype=torch.float, device=device)
print(f"{nt=}")

By padding every underlying tensor to the same shape, a nestedtensor can be converted to a regular tensor.

padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")

All tensors possess an attribute for determining if they are nested:

print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")

It is common to construct nestedtensors from batches of irregularly shaped tensors, so dimension 0 is assumed to be the batch dimension. Indexing into dimension 0 gives back the corresponding underlying tensor component.

print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')

# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")

An important note is that slicing in dimension 0 is not yet supported, which means it is not currently possible to construct a view that combines the underlying tensor components.
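
A minimal sketch of this limitation (assuming the prototype surfaces it as a RuntimeError; exact error messages may differ). The underlying components can still be recovered as regular tensors via unbind:

try:
    nt[0:1]  # slicing along dimension 0 is expected to fail in the prototype
except RuntimeError as e:
    print(f"slicing dimension 0 raised: {e}")

# the components remain accessible as regular tensors
for i, component in enumerate(nt.unbind()):
    print(f"component {i} has shape {tuple(component.shape)}")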

Nested Tensor Operations

Because each operation must be explicitly implemented for nestedtensors, operation coverage is currently narrower than for regular tensors. For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered. However, coverage is being expanded. If you need certain operations, please file an issue to help us prioritize coverage.

reshape

The reshape op is for changing the shape of a tensor. Its full semantics for regular tensors can be found in the torch.reshape documentation. For regular tensors, when specifying the new shape, a single dimension may be -1, in which case it is inferred from the remaining dimensions and the number of elements.

The semantics for nestedtensors are similar, except that -1 no longer infers a size; instead, it inherits the old size (here 2 for nt[0] and 3 for nt[1]). -1 is the only legal size to specify for a jagged dimension.

nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")

transpose

The transpose op is for swapping two dimensions of a tensor. Its full semantics can be found in the torch.transpose documentation. Note that for nestedtensors dimension 0 is special; it is assumed to be the batch dimension, so transposes involving nestedtensor dimension 0 are not supported.

nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")

others

Other operations have the same semantics as for regular tensors. Applying the operation on a nestedtensor is equivalent to applying the operation to the underlying tensor components, with the result being a nestedtensor as well.

nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")

nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")

nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")

Why Nested Tensor

When data is sequential, it is often the case that each sample has a different length. For example, in a batch of sentences, each sentence has a different number of words. A common technique for handling varying sequences is to manually pad each data tensor to the same shape in order to form a batch. For example, given 2 sentences with different lengths and a vocabulary, we pad with 0 to the max length in the batch in order to represent them as a single tensor.

sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")

This technique of padding a batch of data to its max length is not optimal. The padded data is not needed for computation and wastes memory by allocating larger tensors than necessary. Further, not all operations have the same semantics when applied to padded data. For matrix multiplications, one needs to pad with 0 in order to ignore the padded entries, while for softmax one has to pad with -inf to ignore specific entries.

padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
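
The memory overhead of padding can also be made concrete. A rough sketch that counts stored elements only (ignoring metadata and alignment):

padded_elements = padded_sentences.numel()
nested_elements = sum(t.numel() for t in nested_sentences.unbind())
print(f"padded stores {padded_elements} elements, nested stores {nested_elements}")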

Let us take a look at a practical example: the multi-head attention component utilized in Transformers. The nestedtensor version is straightforward.

import math

def mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
               dropout_p: float = 0.0) -> torch.Tensor:
    """Compute multi-head attention with nested tensors.
    Args:
        query (torch.Tensor): query of shape (N, L_t, E_q)
        key (torch.Tensor): key of shape (N, L_s, E_k)
        value (torch.Tensor): value of shape (N, L_s, E_v)
        nheads (int): number of heads in multi-head attention
        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
        dropout_p (float, optional): Dropout probability. Defaults to 0.0.

        Where:
            N is the batch size
            L_t is the target sequence length (jagged)
            L_s is the source sequence length (jagged)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size
    Returns:
        torch.Tensor:  Output of shape (N, L_t, E_out)
    """

    N = query.size(0)
    E_total = W_q.size(0)
    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
    E_head = E_total // nheads

    # apply input projection
    # (N, L_t, E_q) -> (N, L_t, E_total)
    query = F.linear(query, W_q, b_q)
    # (N, L_s, E_k) -> (N, L_s, E_total)
    key = F.linear(key, W_k, b_k)
    # (N, L_s, E_v) -> (N, L_s, E_total)
    value = F.linear(value, W_v, b_v)

    # reshape query, key, value to separate by head
    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2)

    # query matmul key^T
    # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
    keyT = key.transpose(-1, -2)
    attn_weights = torch.matmul(query, keyT)

    # scale down
    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

    # softmax
    attn_weights = F.softmax(attn_weights, dim=-1)

    # dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # attention_weights matmul value
    # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
    attn_output = torch.matmul(attn_weights, value)

    # merge heads
    # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
    attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)

    # apply output projection
    # (N, L_t, E_total) -> (N, L_t, E_out)
    attn_output = F.linear(attn_output, W_out, b_out)

    return attn_output

The 0-padded tensor version additionally requires masks for more complicated treatment of the padded entries.

def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
               attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,
               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
               dropout_p: float = 0.0) -> torch.Tensor:
    """Compute multi-head attention for padded out dense tensors.

    Args:
        query (torch.Tensor): query of shape (N, L_t, E_q)
        key (torch.Tensor): key of shape (N, L_s, E_k)
        value (torch.Tensor): value of shape (N, L_s, E_v)
        nheads (int): number of heads in multi-head attention
        attn_mask_q (torch.Tensor): boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
        attn_mask_kv (torch.Tensor): boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
        dropout_p (float, optional): Dropout probability. Defaults to 0.0.

        Where:
            N is the batch size
            L_t is the target sequence length (padded)
            L_s is the source sequence length (padded)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size
    Returns:
        torch.Tensor: Output of shape (N, L_t, E_out)
    """
    N = query.size(0)
    L_t = query.size(1)
    L_s = key.size(1)
    E_total = W_q.size(0)
    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
    assert L_t == L_s, "This implementation assumes equal query and key sequence lengths"
    E_head = E_total // nheads

    # apply input projection
    # (N, L_t, E_q) -> (N, L_t, E_total)
    query = F.linear(query, W_q, b_q)
    # (N, L_s, E_k) -> (N, L_s, E_total)
    key = F.linear(key, W_k, b_k)
    # (N, L_s, E_v) -> (N, L_s, E_total)
    value = F.linear(value, W_v, b_v)

    # reshape query, key, value to separate by head
    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)

    # query bmm key^T
    # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
    keyT = key.transpose(-1, -2)
    attn_weights = torch.bmm(query, keyT)

    # scale down
    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

    # Have to manipulate masks in order to apply them to the attention weights
    key_padding_mask = attn_mask_kv.view(N, 1, 1, L_s).expand(-1, nheads, -1, -1).reshape(N * nheads, 1, L_s).to(device=device)
    attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)
    attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))

    # Mask out attention weights where the mask is True by adding -inf prior to softmax;
    # these positions become zero probability after the softmax
    attn_weights.add_(attn_mask)

    # softmax
    attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)

    # dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # attention_weights bmm value
    # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
    attn_output = attn_weights.bmm(value)

    # merge heads
    # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
    attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)

    # apply output projection
    # (N, L_t, E_total) -> (N, L_t, E_out)
    attn_output = F.linear(attn_output, W_out, b_out)

    # padding-specific step: remove output projection bias from padded entries
    attn_output[attn_mask_q, :] = 0.0

    return attn_output

Set hyperparameters following the Transformer paper

N = 512
E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
nheads = 8

Except for dropout probability: set to 0 for correctness check

dropout_p = 0.0

Let us generate some realistic fake data from Zipf’s law.

import numpy as np

def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return sentence_lengths

alpha = 1.2

sentence_lengths = zipf_sentence_lengths(alpha, N)
L_t = np.max(sentence_lengths)
L_s = L_t

Create inputs

# create parameters
W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

# create nested input
queries = []
keys = []
values = []
for i in range(N):
    l = sentence_lengths[i]
    s = l
    queries.append(torch.randn((l, E_q), device=device))
    keys   .append(torch.randn((s, E_k), device=device))
    values .append(torch.randn((s, E_v), device=device))
query = torch.nested.nested_tensor(queries)
key = torch.nested.nested_tensor(keys)
value = torch.nested.nested_tensor(values)

# pad input
padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
padded_key   = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))

# create attention masks
attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)

#  We need to mask out the padding entries in the attention weights.
for i, entry_length in enumerate(sentence_lengths):
    attn_mask_q[i, entry_length:] = True
    attn_mask_kv[i, entry_length:] = True

Check correctness and performance

import timeit

t0 = timeit.default_timer()
out_nested = mha_nested(
    query, key, value, nheads,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)

t1 = timeit.default_timer()
out_padded = mha_padded(
    padded_query, padded_key, padded_value, nheads,
    attn_mask_q, attn_mask_kv,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)
t2 = timeit.default_timer()

print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
print("nestedtensor multi-head attention takes", t1 - t0, "seconds")
print("padded tensor multi-head attention takes", t2 - t1, "seconds")

Although the nestedtensor version avoids wasted computation on padding, it is not faster than the equivalent padded tensor version. This is because the nestedtensor version implements a few of the kernels, such as softmax, in a suboptimal way.

There are plans to implement performance-critical operations using the new PyTorch 2.0 stack. For now, some performant kernels are provided for specific use cases, for example self-attention evaluation by the multi-head attention formula.

# embeddings are assumed to be the same
E = E_total
mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
mha_lib.eval()

Extract parameters for correctness check

mha_lib.in_proj_weight.requires_grad_(False)
mha_lib.in_proj_bias.requires_grad_(False)
mha_lib.out_proj.weight.requires_grad_(False)
mha_lib.out_proj.bias.requires_grad_(False)
W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias

Setting need_weights to False enables the fast path in the library. Under the hood this calls _scaled_dot_product_attention. If your tensors are on CUDA, a fused, efficient attention kernel will be used. For more detailed performance characteristics, look at the benchmark in pytorch/benchmarks/transformer/sdp.py.

with torch.inference_mode():
    t0 = timeit.default_timer()
    out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)

    t1 = timeit.default_timer()
    padded_out = mha_padded(
        padded_query, padded_query, padded_query, nheads,
        attn_mask_q, attn_mask_q,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)
    t2 = timeit.default_timer()

nested_time = t1 - t0
padded_time = t2 - t1
print("Nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())
print("Nested library multi-head attention takes", nested_time, "seconds")
print("Padded tensor multi-head attention takes", padded_time, "seconds")
print(f"Nested Speedup: {padded_time / nested_time:.3f}")
