From 4c2005d514a4d7a1bbca81d5eaf5007347b637b9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 13 Apr 2021 22:17:38 +0100 Subject: [PATCH] Ci fix for pytorch profiler log to file tests --- tests/test_profiler.py | 4 +++- tests/trainer/test_trainer.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 6abcf17a04893..267a2e4bce1e1 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -174,8 +174,10 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): # record at least one event with advanced_profiler.profile("test"): pass - # log to stdout and print to file + # logs to output file advanced_profiler.describe() + # ensures file is flushed before reading + advanced_profiler.output_file.close() data = Path(advanced_profiler.output_fname).read_text() assert len(data) > 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b96f7310b180a..4957f4093e20e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1548,8 +1548,10 @@ def test_pytorch_profiler_describe(pytorch_profiler): with pytorch_profiler.profile("test_step"): pass - # log to stdout and print to file + # logs to output file pytorch_profiler.describe() + # ensures file is flushed before reading + pytorch_profiler.output_file.close() data = Path(pytorch_profiler.output_fname).read_text() assert len(data) > 0