Skip to content

Commit

Permalink
Try to make a more illustrative example for LineProfiler (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangkuiyi authored Jul 25, 2024
1 parent cf92544 commit 563cfc4
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions test/test_line_profiler.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,40 @@
import re

import numpy as np
import pytest
import torch
import numpy as np
from pytorch_memlab import (LineProfiler, clear_global_line_profiler, profile,
profile_every, set_target_gpu)

from pytorch_memlab import LineProfiler, profile, profile_every, set_target_gpu, clear_global_line_profiler

def test_display():

def work():
# comment
def main():
linear = torch.nn.Linear(100, 100).cuda()
linear_2 = torch.nn.Linear(100, 100).cuda()
linear_3 = torch.nn.Linear(100, 100).cuda()
part1()
part2()

def work_3():
def part1():
lstm = torch.nn.LSTM(1000, 1000).cuda()
subpart11()

def work_2():
# comment
def part2():
linear_2 = torch.nn.Linear(100, 100).cuda()
linear_3 = torch.nn.Linear(100, 100).cuda()

def subpart11():
linear = torch.nn.Linear(100, 100).cuda()
linear_2 = torch.nn.Linear(100, 100).cuda()
linear_3 = torch.nn.Linear(100, 100).cuda()
work_3()

with LineProfiler(work, work_2) as prof:
work()
work_2()

return prof.display()
with LineProfiler(subpart11, part2) as prof:
main()

s = str(prof.display()) # cast from line_records.RecordsDisplay
assert re.search("## .*subpart11", s)
assert "def subpart11():" in s
assert re.search("## .*part2", s)
assert "def part2():" in s


def test_line_report():
Expand Down Expand Up @@ -56,6 +64,7 @@ def work_2():
line_profiler.disable()
line_profiler.print_stats()


def test_line_report_decorator():
clear_global_line_profiler()

Expand All @@ -72,11 +81,13 @@ def work2():
linear = torch.nn.Linear(100, 100).cuda()
linear_2 = torch.nn.Linear(100, 100).cuda()
linear_3 = torch.nn.Linear(100, 100).cuda()

work()
work2()
work()
work()


def test_line_report_method():
clear_global_line_profiler()

Expand All @@ -94,6 +105,7 @@ def forward(self, inp):
inp = torch.Tensor(50, 100).cuda()
net(inp)


def test_line_report_profile():
clear_global_line_profiler()

Expand All @@ -107,6 +119,7 @@ def work():
work()
work()


def test_line_report_profile_set_gpu():
clear_global_line_profiler()

Expand All @@ -122,6 +135,7 @@ def work():
work()
work()


def test_line_report_profile_interrupt():
clear_global_line_profiler()

Expand All @@ -139,4 +153,4 @@ def work2():

work()
work2()
raise KeyboardInterrupt
raise KeyboardInterrupt

0 comments on commit 563cfc4

Please sign in to comment.