-
Notifications
You must be signed in to change notification settings - Fork 9.5k
/
replace_op.py
63 lines (52 loc) · 2.38 KB
/
replace_op.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from torch.fx import symbolic_trace
import operator
"""
How to Replace One Op With Another
1. Iterate through all Nodes in your GraphModule's Graph.
2. Determine if the current Node should be replaced. (Suggested: match
on the Node's ``target`` attribute).
3. Create a replacement Node and add it to the Graph.
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of
the current Node with the replacement.
5. Delete the old Node from the graph.
6. Call ``recompile`` on the GraphModule. This updates the generated
Python code to reflect the new Graph state.
Currently, FX does not provide any way to guarantee that replaced
operators are syntactically valid. It's up to the user to confirm that
any new operators will work with the existing operands.
The following code demonstrates an example of replacing any instance of
addition with a bitwise AND.
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, call `traced.graph.print_tabular()` to see the IR in a
tabular format.
"""
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y, torch.add(x, y), x.add(y)
# Symbolically trace an instance of the module
traced = symbolic_trace(M())
# As demonstrated in the above example, there are several different ways
# to denote addition. The possible cases are:
# 1. `x + y` - A `call_function` Node with target `operator.add`.
# We can match for equality on that `operator.add` directly.
# 2. `torch.add(x, y)` - A `call_function` Node with target
# `torch.add`. Similarly, we can match this function directly.
# 3. `x.add(y)` - The Tensor method call, whose target we can match
# as a string.
patterns = set([operator.add, torch.add, "add"])
# Go through all the nodes in the Graph
for n in traced.graph.nodes:
# If the target matches one of the patterns
if any(n.target == pattern for pattern in patterns):
# Set the insert point, add the new node, and replace all uses
# of `n` with the new node
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
n.replace_all_uses_with(new_node)
# Remove the old node from the graph
traced.graph.erase_node(n)
# Don't forget to recompile!
traced.recompile()