Merge pull request #88 from CentML/johncalesp/add-pl
[Pytorch Lightning support] added pytorch lightning
Showing 6 changed files with 460 additions and 10 deletions.
@@ -0,0 +1,158 @@

import json


def convert(message):
    new_message = {}
    # Keep a debug copy of the raw analysis message next to the report.
    with open("message.json", "w") as fp:
        json.dump(message, fp, indent=4)

    new_message["ddp"] = {}
    new_message["message_type"] = message["message_type"]
    new_message["project_root"] = message["project_root"]
    new_message["project_entry_point"] = message["project_entry_point"]

    new_message["hardware_info"] = {
        "hostname": message["hardware_info"]["hostname"],
        "os": message["hardware_info"]["os"],
        "gpus": message["hardware_info"]["gpus"],
    }

    new_message["throughput"] = {
        "samples_per_second": message["throughput"]["samples_per_second"],
        "predicted_max_samples_per_second": message["throughput"][
            "predicted_max_samples_per_second"
        ],
        # Linear models arrive as {slope, bias} dicts and are flattened to
        # [slope, bias] pairs; [0, 0] stands in when a prediction is missing.
        "run_time_ms": (
            [
                message["throughput"]["run_time_ms"]["slope"],
                message["throughput"]["run_time_ms"]["bias"],
            ]
            if "run_time_ms" in message["throughput"]
            else [0, 0]
        ),
        "peak_usage_bytes": (
            [
                message["throughput"]["peak_usage_bytes"]["slope"],
                message["throughput"]["peak_usage_bytes"]["bias"],
            ]
            if "peak_usage_bytes" in message["throughput"]
            else [0, 0]
        ),
        "batch_size_context": None,
        "can_manipulate_batch_size": False,
    }

    new_message["utilization"] = message["utilization"]

    # Recursively rename snake_case keys ("cpu_forward", "gpu_backward_span",
    # ...) to the camelCase keys the report format expects ("cpuForward",
    # "gpuBackwardSpan", ...), defaulting missing entries to 0.
    def fix(a):
        for d in ["cpu", "gpu"]:
            for s in ["Forward", "Backward"]:
                if f"{d}_{s.lower()}" in a:
                    a[f"{d}{s}"] = a[f"{d}_{s.lower()}"]
                    del a[f"{d}_{s.lower()}"]
                else:
                    a[f"{d}{s}"] = 0

                if f"{d}_{s.lower()}_span" in a:
                    a[f"{d}{s}Span"] = a[f"{d}_{s.lower()}_span"]
                    del a[f"{d}_{s.lower()}_span"]
                else:
                    a[f"{d}{s}Span"] = 0

        if "children" not in a:
            a["children"] = []
            return

        for c in a["children"]:
            fix(c)
    if new_message["utilization"].get("rootNode"):
        fix(new_message["utilization"]["rootNode"])

    try:
        new_message["utilization"]["tensor_core_usage"] = message["utilization"][
            "tensor_utilization"
        ]
    except KeyError:
        new_message["utilization"]["tensor_core_usage"] = 0
    new_message["habitat"] = {
        "predictions": [
            (
                [prediction["device_name"], prediction["runtime_ms"]]
                if prediction["device_name"] != "unavailable"
                else ["default_device", 0]
            )
            for prediction in message["habitat"]["predictions"]
        ]
    }

    new_message["breakdown"] = {
        "peak_usage_bytes": int(message["breakdown"]["peak_usage_bytes"]),
        "memory_capacity_bytes": int(message["breakdown"]["memory_capacity_bytes"]),
        "iteration_run_time_ms": message["breakdown"]["iteration_run_time_ms"],
        # TODO: change these hardcoded numbers
        "batch_size": 48,
        "num_nodes_operation_tree": len(message["breakdown"]["operation_tree"]),
        "num_nodes_weight_tree": 0,
        "operation_tree": [
            {
                "name": op["name"],
                "num_children": op["num_children"] if "num_children" in op else 0,
                "forward_ms": op["operation"]["forward_ms"],
                "backward_ms": op["operation"]["backward_ms"],
                "size_bytes": (
                    int(op["operation"]["size_bytes"])
                    if "size_bytes" in op["operation"]
                    else 0
                ),
                "file_refs": (
                    [
                        {
                            "path": "/".join(ctx["context"]["file_path"]["components"]),
                            "line_no": ctx["context"]["line_number"],
                            "run_time_ms": ctx["run_time_ms"],
                            "size_bytes": (
                                int(ctx["size_bytes"]) if "size_bytes" in ctx else 0
                            ),
                        }
                        for ctx in op["operation"]["context_info_map"]
                    ]
                    if "context_info_map" in op["operation"]
                    else list()
                ),
            }
            for op in message["breakdown"]["operation_tree"]
        ],
    }

    def fix_components(m):
        for c in m["components"]:
            if "consumption_joules" not in c:
                c["consumption"] = 0
            else:
                c["consumption"] = c["consumption_joules"]
                del c["consumption_joules"]
            c["type"] = c["component_type"]
            if c["type"] == "ENERGY_NVIDIA":
                c["type"] = "ENERGY_GPU"
            del c["component_type"]

    new_message["energy"] = {
        "current": {
            "total_consumption": message["energy"]["total_consumption"],
            "components": message["energy"]["components"],
            "batch_size": 48,
        },
        "past_measurements": message["energy"].get("past_measurements", None),
    }

    fix_components(new_message["energy"]["current"])
    if new_message["energy"].get("past_measurements", None):
        for m in new_message["energy"]["past_measurements"]:
            fix_components(m)

    return new_message
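
For context on what the converter does to throughput data: the profiler reports each linear model as a {slope, bias} dict, and convert flattens it to a [slope, bias] pair, substituting [0, 0] when the prediction is absent. A minimal self-contained sketch of that rule, with invented values purely for illustration:

# Hypothetical fragment of a profiler message; the values are invented.
throughput = {"run_time_ms": {"slope": 0.5, "bias": 12.0}}

# The same flattening rule convert() applies above.
run_time_ms = (
    [throughput["run_time_ms"]["slope"], throughput["run_time_ms"]["bias"]]
    if "run_time_ms" in throughput
    else [0, 0]
)
print(run_time_ms)  # prints [0.5, 12.0]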
@@ -0,0 +1,92 @@

from typing import Callable, Tuple

import time
import os
import json
import torch
import sys

try:
    import pytorch_lightning as pl
except ImportError:
    sys.exit(
        "Please install pytorch-lightning:\n"
        "use: pip install pytorch-lightning\n"
        "Exiting..."
    )

from termcolor import colored
from deepview_profile.pl.deepview_interface import trigger_profiling


class DeepViewProfilerCallback(pl.Callback):
    def __init__(self, profile_name: str):
        super().__init__()
        self.profiling_triggered = False
        self.output_filename = f"{profile_name}_{int(time.time())}.json"
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
    ):
        # only do this once
        if self.profiling_triggered:
            return

        print(colored("DeepViewProfiler: Running profiling.", "green"))

        """
        need 3 things:
            input_provider: just return batch
            model_provider: just return pl_module
            iteration_provider: a function that (a) calls
                pl_module.training_step and (b) calls loss.backward
        """
        initial_batch_size = batch[0].shape[0]

        def input_provider(batch_size: int = initial_batch_size) -> Tuple:
            model_inputs = list()
            for elem in batch:
                # we assume the first dimension is the batch dimension
                model_inputs.append(
                    elem[:1].repeat([batch_size] + [1 for _ in elem.shape[1:]])
                )
            # (batch, batch_idx) positional args for training_step
            return (tuple(model_inputs), 0)

        model_provider = lambda: pl_module

        def iteration_provider(module: torch.nn.Module) -> Callable:
            def iteration(*args, **kwargs):
                loss = module.training_step(*args, **kwargs)
                loss.backward()

            return iteration

        project_root = os.getcwd()

        output = trigger_profiling(
            project_root,
            "entry_point.py",
            initial_batch_size,
            input_provider,
            model_provider,
            iteration_provider,
        )

        with open(self.output_filename, "w") as fp:
            json.dump(output, fp, indent=4)

        print(
            colored(
                "DeepViewProfiler: Profiling complete! Report written to ", "green"
            )
            + colored(self.output_filename, "green", attrs=["bold"])
        )
        print(
            colored(
                "DeepViewProfiler: View your report at https://deepview.centml.ai",
                "green",
            )
        )
        self.profiling_triggered = True
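
For reviewers trying the callback out: a minimal, self-contained sketch of how it would be attached to a trainer. The toy module, dataset, and profile name below are invented for illustration and are not part of this diff; any LightningModule whose training_step returns a loss tensor should behave the same way, and the profiling itself presumably needs a CUDA-capable GPU.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[DeepViewProfilerCallback("toy_profile")],
)
trainer.fit(ToyModule(), DataLoader(dataset, batch_size=8))
# After the first training batch, the callback writes
# toy_profile_<timestamp>.json to the working directory.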
@@ -0,0 +1,141 @@

import sys
from typing import Callable
import platform

from deepview_profile.analysis.session import AnalysisSession
from deepview_profile.exceptions import AnalysisError
from deepview_profile.nvml import NVML

# from deepview_profile.utils import release_memory, next_message_to_dict, files_encoded_unique
from deepview_profile.utils import release_memory, files_encoded_unique
from deepview_profile.error_printing import print_analysis_error

from google.protobuf.json_format import MessageToDict


def measure_breakdown(session, nvml):
    print("analysis: running measure_breakdown()")
    yield session.measure_breakdown(nvml)
    release_memory()


def measure_throughput(session):
    print("analysis: running measure_throughput()")
    yield session.measure_throughput()
    release_memory()


def habitat_predict(session):
    print("analysis: running deepview_predict()")
    yield session.habitat_predict()
    release_memory()


def measure_utilization(session):
    print("analysis: running measure_utilization()")
    yield session.measure_utilization()
    release_memory()


def energy_compute(session):
    print("analysis: running energy_compute()")
    yield session.energy_compute()
    release_memory()


def ddp_analysis(session):
    print("analysis: running ddp_computation()")
    yield session.ddp_computation()
    release_memory()


def hardware_information(nvml):
    hardware_info = {
        "hostname": platform.node(),
        "os": " ".join(list(platform.uname())),
        "gpus": nvml.get_device_names(),
    }
    return hardware_info


class DummyStaticAnalyzer:
    def batch_size_location(self):
        return None


def next_message_to_dict(a):
    message = next(a)
    return MessageToDict(message, preserving_proto_field_name=True)


def trigger_profiling(
    project_root: str,
    entry_point: str,
    initial_batch_size: int,
    input_provider: Callable,
    model_provider: Callable,
    iteration_provider: Callable,
):
    try:
        data = {
            "analysis": {
                "message_type": "analysis",
                "project_root": project_root,
                "project_entry_point": entry_point,
                "hardware_info": {},
                "throughput": {},
                "breakdown": {},
                "habitat": {},
                "additionalProviders": "",
                "energy": {},
                "utilization": {},
                "ddp": {},
            },
            "epochs": 50,
            "iterations": 1000,
            "encodedFiles": [],
        }

        session = AnalysisSession(
            project_root,
            entry_point,
            project_root,
            model_provider,
            input_provider,
            iteration_provider,
            initial_batch_size,
            DummyStaticAnalyzer(),
        )
        release_memory()

        exclude_source = False
        with NVML() as nvml:
            data["analysis"]["hardware_info"] = hardware_information(nvml)
            data["analysis"]["breakdown"] = next_message_to_dict(
                measure_breakdown(session, nvml)
            )

            operation_tree = data["analysis"]["breakdown"]["operation_tree"]
            if not exclude_source and operation_tree is not None:
                data["encodedFiles"] = files_encoded_unique(operation_tree)

            data["analysis"]["throughput"] = next_message_to_dict(
                measure_throughput(session)
            )
            data["analysis"]["habitat"] = next_message_to_dict(
                habitat_predict(session)
            )
            data["analysis"]["utilization"] = next_message_to_dict(
                measure_utilization(session)
            )
            data["analysis"]["energy"] = next_message_to_dict(energy_compute(session))
            # data['analysis']['ddp'] = next_message_to_dict(ddp_analysis(session))

            from deepview_profile.export_converter import convert

            data["analysis"] = convert(data["analysis"])

            return data

    except AnalysisError as ex:
        print_analysis_error(ex)
        sys.exit(1)
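
For completeness, trigger_profiling can also be exercised outside of Lightning. A sketch under the assumption that input_provider returns the positional-argument tuple consumed by the iteration function, which is the contract the callback above follows; the model, providers, and sizes here are invented:

import os
import torch

model = torch.nn.Linear(16, 2)

def model_provider():
    return model

def input_provider(batch_size: int = 8):
    # Positional args handed to the iteration function.
    return (torch.randn(batch_size, 16),)

def iteration_provider(module):
    def iteration(x):
        loss = module(x).sum()
        loss.backward()
    return iteration

report = trigger_profiling(
    os.getcwd(),
    "entry_point.py",
    8,
    input_provider,
    model_provider,
    iteration_provider,
)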