diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 403e98a094f90a..da31573332f97d 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -246,7 +246,7 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) - memory_map = {min_index: min_memory, max_index: max_memory} + memory_map = {'min_gpu_mem': min_memory, 'max_gpu_mem': max_memory} return memory_map