Skip to content

alpha0422/torch-graph

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

torch-graph

Simple PyTorch graph capturing.

Instructions

Please install graphviz first:

$ apt-get install graphviz

Clone and install this package:

$ pip install .

Examples:

from torchgraph import dispatch_capture, aot_capture, compile_capture

import torch
import torch.nn as nn
model = nn.Sequential(
    nn.Conv2d(16, 32, 3),
    nn.BatchNorm2d(32),
    nn.SiLU(),
).cuda()
x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")

# Capture joint forward and backward graph through dispatch
dispatch_capture(model, x)

# Capture separate forward and backward graphs through PyTorch AOTAutograd
aot_capture(model, x)

# Capture forward graph through PyTorch compile
compile_capture(model, x)

You'll find the captured graphs in .svg format under current folder.

About

Simple PyTorch graph capturing.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages