-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathMLMemoryAnalysis.py
35 lines (30 loc) · 1.33 KB
/
MLMemoryAnalysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Inspired by https://pytorch.org/docs/stable/notes/faq.html#my-model-reports-cuda-runtime-error-2-out-of-memory
from collections import defaultdict
from .BaseAnalysis import BaseAnalysis
class MLMemoryAnalysis(BaseAnalysis):
    """Flag repeated gradient-tracking accumulations inside control flow.

    The hooks below cooperate as a small state machine:

    * ``enter_control_flow`` / ``exit_control_flow`` maintain a stack of the
      control-flow constructs currently being executed.
    * ``binary_operation`` remembers the most recent binary operation whose
      right operand has a truthy ``requires_grad`` attribute, provided it
      happened inside a control-flow construct.
    * ``write`` counts writes of a ``requires_grad`` value that directly
      follow such an operation; once one (write site, control-flow site)
      pair has been seen more than ``self.threshold`` times it reports the
      issue and terminates the process.

    NOTE(review): per the link at the top of the file this presumably targets
    the "accumulating history across the training loop" pattern from the
    PyTorch FAQ (e.g. ``total_loss += loss``) -- confirm against that page.
    """

    def __init__(self, **kwargs) -> None:
        """Initialise empty tracking state; ``kwargs`` go to the base class."""
        super().__init__(**kwargs)
        # Stack of iids of the control-flow constructs currently entered.
        self.in_ctrl_flow = []
        # A (write, control-flow) pair is reported once seen more than this
        # many times.
        self.threshold = 3
        # (write iid, innermost control-flow iid) -> number of flagged writes.
        self.memory_leak = defaultdict(int)
        # iid of the last qualifying binary operation, or None.
        self.last_opr = None

    @staticmethod
    def _tracks_grad(value) -> bool:
        """Return True if ``value`` has a truthy ``requires_grad`` attribute.

        Values without the attribute (ints, strings, ...) yield False
        instead of raising ``AttributeError``.
        """
        return bool(getattr(value, "requires_grad", False))

    def enter_control_flow(self, dyn_ast, iid, condition):
        """Record that the control-flow construct ``iid`` is being executed.

        The iid is pushed unless it is already on top of the stack, so
        repeated notifications for the same construct do not grow the stack.
        """
        self.last_opr = None
        # The original test required a non-empty stack before pushing, which
        # made it impossible to ever push the first element.
        if not self.in_ctrl_flow or self.in_ctrl_flow[-1] != iid:
            self.in_ctrl_flow.append(iid)

    def exit_control_flow(self, dyn_ast, iid):
        """Drop the innermost control-flow construct from the stack."""
        self.last_opr = None
        # Guard against an unbalanced exit notification.
        if self.in_ctrl_flow:
            self.in_ctrl_flow.pop()

    def binary_operation(self, dyn_ast, iid, opr, left, right, res):
        """Remember ``iid`` if a grad-tracking right operand is used in control flow."""
        if self.in_ctrl_flow and self._tracks_grad(right):
            self.last_opr = iid
        else:
            self.last_opr = None

    def write(self, dyn_ast, iid, left, right):
        """Count a grad-tracking write that follows a qualifying binary op.

        Prints ``Memory issue detected`` and raises ``SystemExit(1)`` once the
        same (write iid, innermost control-flow iid) pair has been counted
        more than ``self.threshold`` times.
        """
        follows_binary_op = self.last_opr is not None
        # Every write consumes the pending binary operation.
        self.last_opr = None

        if not (self.in_ctrl_flow and follows_binary_op and self._tracks_grad(right)):
            return

        site = (iid, self.in_ctrl_flow[-1])
        self.memory_leak[site] += 1
        if self.memory_leak[site] > self.threshold:
            print('Memory issue detected')
            # Same effect as exit(1) without relying on the site-provided helper.
            raise SystemExit(1)