torch.compile Tutorial
Author: William Wen
torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.
In this tutorial, we cover basic torch.compile usage, and demonstrate the advantages of torch.compile over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.
Contents
Basic Usage
Demonstrating Speedups
Comparison to TorchScript and FX Tracing
TorchDynamo and FX Graphs
Conclusion
Required pip Dependencies
torch >= 2.0
torchvision
numpy
scipy
tabulate
NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere.
import torch
import warnings
gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True
if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
Basic Usage
torch.compile is included in the latest PyTorch.
Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing torchtriton via pip (pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" for CUDA 11.7).
Arbitrary Python functions can be optimized by passing the callable to torch.compile. We can then call the returned optimized function in place of the original function.
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
Alternatively, we can decorate the function.
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
We can also optimize torch.nn.Module instances.
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
Demonstrating Speedups
Let’s now demonstrate that using torch.compile can speed up real models. We will compare standard eager mode and torch.compile by evaluating and training a torchvision model on random data.
Before we start, we need to define some utility functions.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )
N_ITERS = 10
from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()
First, let’s compare inference.
Note that in the call to torch.compile, we have the additional mode argument, which we will discuss below.
def evaluate(mod, inp):
    with torch.no_grad():
        return mod(inp)
model = init_model()
# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()
evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
inp = generate_data(16)[0]
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
Notice that torch.compile takes a lot longer to complete compared to eager. This is because torch.compile compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, and so recompilation is not needed. So if we run our optimized model several more times, we should see a significant improvement compared to eager.
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, eager_time = timed(lambda: evaluate(model, inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")
print("~" * 10)
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, compile_time = timed(lambda: evaluate_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU reads/writes, and so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
You may also see different speedup results depending on the chosen mode argument. Since our model and data are small, we want to reduce overhead as much as possible, and so we chose "reduce-overhead". For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes in the torch.compile documentation.
For general PyTorch benchmarking, you can try using torch.utils.benchmark instead of the timed function we defined above. We wrote our own timing function in this tutorial to show torch.compile’s compilation latency.
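A minimal sketch of the torch.utils.benchmark route, assuming we only care about steady-state time: the hypothetical helper median_seconds (not from this tutorial) makes one warm-up call, which is where compilation would happen for a compiled callable, and then lets Timer.blocked_autorange choose the number of repetitions. It runs on CPU with a small eager function here; a callable returned by torch.compile can be passed in the same way.

```python
import torch
from torch.utils.benchmark import Timer


# Hypothetical helper: median runtime of `fn(*args)` in seconds.
def median_seconds(fn, *args, min_run_time=0.2):
    # Warm-up call. For a compiled callable, compilation happens here,
    # so it is excluded from the measurement below.
    fn(*args)
    timer = Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    # blocked_autorange repeats the statement until at least
    # `min_run_time` seconds of measurements have been collected.
    return timer.blocked_autorange(min_run_time=min_run_time).median


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


x, y = torch.randn(10, 10), torch.randn(10, 10)
eager_median = median_seconds(foo, x, y)
print(f"eager median: {eager_median} s")
```

Unlike the timed function above, this deliberately hides the first-call compilation cost, so it suits comparisons of steady-state performance rather than demonstrations of compile latency.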
Now, let’s consider comparing training.
model = init_model()
opt = torch.optim.Adam(model.parameters())
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.
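The compile-once behavior described above can also be observed without a GPU. The sketch below is an illustration under assumptions, not part of the original benchmark: it passes a small recording backend (custom backends are covered in the TorchDynamo section of this tutorial) that notes each graph it is handed and returns the graph's unoptimized forward method. The names compiled_graphs, counting_backend, and f are hypothetical.

```python
import torch
import torch._dynamo

# Every graph handed to the backend is recorded here, so the length of
# this list is the number of compilations that took place.
compiled_graphs = []


def counting_backend(gm, example_inputs):
    compiled_graphs.append(gm)
    # Return the unoptimized forward; we only want to count compilations.
    return gm.forward


def f(x):
    return torch.sin(x) + torch.cos(x)


torch._dynamo.reset()
opt_f = torch.compile(f, backend=counting_backend)

# Five calls with inputs of identical shape and dtype.
for _ in range(5):
    opt_f(torch.randn(8))

print("backend invocations:", len(compiled_graphs))
```

Because all five inputs share the same shape and dtype, the backend is invoked only on the first call; the remaining calls reuse the compiled result.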
Comparison to TorchScript and FX Tracing
We have seen that torch.compile can speed up PyTorch code. Why else should we use torch.compile over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.
One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y
# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
TorchScript tracing f1 results in silently incorrect results, since only the control flow path actually taken during tracing is recorded.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
FX tracing f1 results in an error due to the presence of data-dependent control flow.
import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()
If we provide a value for x as we try to FX trace f1, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.
fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
Now we can see that torch.compile correctly handles data-dependent control flow.
# Reset since we are using a different mode.
torch._dynamo.reset()
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.
In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument y, an int, does not match with the default argument type, torch.Tensor.
def f2(x, y):
    return x + y
inp1 = torch.randn(5, 5)
inp2 = 3
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()
However, torch.compile is easily able to handle f2.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Another case that torch.compile handles well compared to previous compiler solutions is the usage of non-PyTorch functions.
import scipy
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x
TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
TorchScript scripting and FX tracing disallow non-PyTorch function calls.
try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()
try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()
In comparison, torch.compile is easily able to handle the non-PyTorch function call.
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
TorchDynamo and FX Graphs
One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.
Normally, TorchInductor, another component of torch.compile, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph’s unoptimized forward method.
from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward
# Reset since we are using a different backend.
torch._dynamo.reset()
opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
Using our custom backend, we can now see how TorchDynamo is able to handle data-dependent control flow. Consider the function below, where the line if b.sum() < 0 is the source of data-dependent control flow.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (order may differ from the output above):
1. x = a / (torch.abs(a) + 1)
2. b = b * -1; return x * b
3. return x * b
When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.
Let’s investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2. On the other hand, if not b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3.
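To check this stepping behavior programmatically instead of reading printed tables, one option is a backend that records each graph it receives. The sketch below is illustrative only: recording_backend and captured are hypothetical names, the inputs are fixed so that both branches are exercised deterministically, and the number of graphs recorded may depend on the PyTorch version.

```python
import torch
import torch._dynamo

captured = []  # FX GraphModules handed to the backend, in compile order


def recording_backend(gm, example_inputs):
    captured.append(gm)
    return gm.forward  # run the graph unoptimized


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


torch._dynamo.reset()
opt_bar = torch.compile(bar, backend=recording_backend)

a = torch.randn(10)
pos = torch.ones(10)  # b.sum() > 0, so the branch is skipped
out_pos = opt_bar(a, pos)
out_neg = opt_bar(a, -pos)  # b.sum() < 0, so the branch is taken

print("graphs captured:", len(captured))
for i, gm in enumerate(captured, start=1):
    print(f"--- graph {i} ---")
    print(gm.code)
```

The second call reuses the first graph (the code up to the conditional) and only triggers compilation of the branch that had not been seen yet, which is why both calls together are needed to capture every graph.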
This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, will break the computation graph.
We can see where TorchDynamo breaks the graph by using torch._dynamo.explain:
# Reset since we are using a different backend.
torch._dynamo.reset()
# Calling convention and return value as of PyTorch 2.0; later releases may differ.
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = torch._dynamo.explain(
    bar, torch.randn(10), torch.randn(10)
)
print(explanation_verbose)
In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:
opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
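One common way to avoid a graph break like the one in bar is to express the branch as a tensor operation. The sketch below is an assumption-labeled illustration and not part of the original tutorial: bar_no_branch is a hypothetical rewrite that uses torch.where to select between -b and b, and recording_backend is a hypothetical backend that records graphs and runs them unoptimized, so the example does not depend on TorchInductor.

```python
import torch
import torch._dynamo

graphs = []  # graphs handed to the backend


def recording_backend(gm, example_inputs):
    graphs.append(gm)
    return gm.forward


# Original function with a data-dependent Python branch.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


# Hypothetical rewrite: the condition stays a tensor and torch.where
# picks -b or b, so no Python-level branching on tensor data is needed.
def bar_no_branch(a, b):
    x = a / (torch.abs(a) + 1)
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


torch._dynamo.reset()
opt_bar_no_branch = torch.compile(
    bar_no_branch, backend=recording_backend, fullgraph=True
)

a = torch.randn(10)
pos = torch.ones(10)
out_pos = opt_bar_no_branch(a, pos)
out_neg = opt_bar_no_branch(a, -pos)
print("graphs captured:", len(graphs))
```

The trade-off is that both -b and b are computed on every call. For cheap elementwise work this is usually acceptable, but for expensive branches a graph break may still be the better choice.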
And below, we demonstrate that TorchDynamo does not break the graph on the model we used above for demonstrating speedups.
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
Finally, if we simply want TorchDynamo to output the FX graph for export, we can use torch._dynamo.export. Note that torch._dynamo.export, like fullgraph=True, raises an error if TorchDynamo breaks the graph.
# Calling convention as of PyTorch 2.0; later releases may differ.
try:
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))
Conclusion
In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give torch.compile a try!