Merge pull request #88 from CentML/johncalesp/add-pl
[Pytorch Lightning support] added pytorch lightning
jimgao1 authored Nov 26, 2024
2 parents 5d4fd07 + 346e099 commit c8b1aa5
Showing 6 changed files with 460 additions and 10 deletions.
158 changes: 158 additions & 0 deletions deepview_profile/export_converter.py
@@ -0,0 +1,158 @@
import json


def convert(message):
new_message = {}
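    # Dump the raw analysis message to disk for offline inspection/debugging.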
with open("message.json", "w") as fp:
json.dump(message, fp, indent=4)

new_message["ddp"] = {}
new_message["message_type"] = message["message_type"]
new_message["project_root"] = message["project_root"]
new_message["project_entry_point"] = message["project_entry_point"]

new_message["hardware_info"] = {
"hostname": message["hardware_info"]["hostname"],
"os": message["hardware_info"]["os"],
"gpus": message["hardware_info"]["gpus"],
}

new_message["throughput"] = {
"samples_per_second": message["throughput"]["samples_per_second"],
"predicted_max_samples_per_second": message["throughput"][
"predicted_max_samples_per_second"
],
"run_time_ms": (
[
message["throughput"]["run_time_ms"]["slope"],
message["throughput"]["run_time_ms"]["bias"],
]
if "run_time_ms" in message["throughput"]
else [0, 0]
),
"peak_usage_bytes": (
[
message["throughput"]["peak_usage_bytes"]["slope"],
message["throughput"]["peak_usage_bytes"]["bias"],
]
if "peak_usage_bytes" in message["throughput"]
else [0, 0]
),
"batch_size_context": None,
"can_manipulate_batch_size": False,
}

new_message["utilization"] = message["utilization"]

def fix(a):
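        # Rename snake_case timing keys to camelCase ("cpu_forward" -> "cpuForward",
        # "gpu_backward_span" -> "gpuBackwardSpan"), defaulting missing keys to 0,
        # then recurse into the node's children.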
for d in ["cpu", "gpu"]:
for s in ["Forward", "Backward"]:
if f"{d}_{s.lower()}" in a:
a[f"{d}{s}"] = a[f"{d}_{s.lower()}"]
del a[f"{d}_{s.lower()}"]
else:
a[f"{d}{s}"] = 0

if f"{d}_{s.lower()}_span" in a:
a[f"{d}{s}Span"] = a[f"{d}_{s.lower()}_span"]
del a[f"{d}_{s.lower()}_span"]
else:
a[f"{d}{s}Span"] = 0

if "children" not in a:
a["children"] = []
return

if a:
for c in a["children"]:
fix(c)

    if new_message["utilization"].get("rootNode"):
        fix(new_message["utilization"]["rootNode"])
    try:
        new_message["utilization"]["tensor_core_usage"] = message["utilization"][
            "tensor_utilization"
        ]
    except KeyError:
        new_message["utilization"]["tensor_core_usage"] = 0

new_message["habitat"] = {
"predictions": [
(
[prediction["device_name"], prediction["runtime_ms"]]
if prediction["device_name"] != "unavailable"
else ["default_device", 0]
)
for prediction in message["habitat"]["predictions"]
]
}

new_message["breakdown"] = {
"peak_usage_bytes": int(message["breakdown"]["peak_usage_bytes"]),
"memory_capacity_bytes": int(message["breakdown"]["memory_capacity_bytes"]),
"iteration_run_time_ms": message["breakdown"]["iteration_run_time_ms"],
# TODO change these hardcoded numbers
"batch_size": 48,
"num_nodes_operation_tree": len(message["breakdown"]["operation_tree"]),
"num_nodes_weight_tree": 0,
"operation_tree": [
{
"name": op["name"],
"num_children": op["num_children"] if "num_children" in op else 0,
"forward_ms": op["operation"]["forward_ms"],
"backward_ms": op["operation"]["backward_ms"],
"size_bytes": (
int(op["operation"]["size_bytes"])
if "size_bytes" in op["operation"]
else 0
),
"file_refs": (
[
{
"path": "/".join(ctx["context"]["file_path"]["components"]),
"line_no": ctx["context"]["line_number"],
"run_time_ms": ctx["run_time_ms"],
"size_bytes": (
int(ctx["size_bytes"]) if "size_bytes" in ctx else 0
),
}
for ctx in op["operation"]["context_info_map"]
]
if "context_info_map" in op["operation"]
else list()
),
}
for op in message["breakdown"]["operation_tree"]
],
}

def fix_components(m):
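        # Normalize each component: consumption_joules -> consumption (default 0),
        # component_type -> type, with ENERGY_NVIDIA mapped to ENERGY_GPU.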
for c in m["components"]:
if "consumption_joules" not in c:
c["consumption"] = 0
else:
c["consumption"] = c["consumption_joules"]
del c["consumption_joules"]
c["type"] = c["component_type"]
if c["type"] == "ENERGY_NVIDIA":
c["type"] = "ENERGY_GPU"
del c["component_type"]

new_message["energy"] = {
"current": {
"total_consumption": message["energy"]["total_consumption"],
"components": message["energy"]["components"],
"batch_size": 48,
},
"past_measurements": message["energy"].get("past_measurements", None),
}

fix_components(new_message["energy"]["current"])
if new_message["energy"].get("past_measurements", None):
for m in new_message["energy"]["past_measurements"]:
fix_components(m)

return new_message
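
For context, the [slope, bias] pairs written into "throughput" above encode a linear model (the [0, 0] fallback yields zero everywhere). Assuming the independent variable is batch size, a consumer of the converted report might evaluate them like this; predict_run_time_ms is a hypothetical helper, not part of this commit:

# Hypothetical helper (not part of this commit): evaluate the linear
# [slope, bias] model stored in the converted message. The independent
# variable is assumed to be batch size.
def predict_run_time_ms(converted_message, batch_size):
    slope, bias = converted_message["throughput"]["run_time_ms"]
    return slope * batch_size + bias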
92 changes: 92 additions & 0 deletions deepview_profile/pl/deepview_callback.py
@@ -0,0 +1,92 @@
from typing import Callable, Tuple

import time
import os
import json
import torch
import sys

try:
import pytorch_lightning as pl
except ImportError:
sys.exit("Please install pytorch-lightning:\nuse: pip install lightning\nExiting...")

from termcolor import colored
from deepview_profile.pl.deepview_interface import trigger_profiling


class DeepViewProfilerCallback(pl.Callback):
def __init__(self, profile_name: str):
super().__init__()
self.profiling_triggered = False
self.output_filename = f"{profile_name}_{int(time.time())}.json"

def on_train_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs,
batch,
batch_idx,
):

# only do this once
if self.profiling_triggered:
return

print(colored("DeepViewProfiler: Running profiling.", "green"))

"""
need 3 things:
input_provider: just return batch
model_provider: just return pl_module
iteration_provider: a lambda function that (a) calls pl_module.forward_step and (b) calls loss.backward
"""
initial_batch_size = batch[0].shape[0]

def input_provider(batch_size: int = initial_batch_size) -> Tuple:
model_inputs = list()
for elem in batch:
# we assume the first dimension is the batch dimension
model_inputs.append(
elem[:1].repeat([batch_size] + [1 for _ in elem.shape[1:]])
)
return (tuple(model_inputs), 0)

model_provider = lambda: pl_module

def iteration_provider(module: torch.nn.Module) -> Callable:
def iteration(*args, **kwargs):
loss = module.training_step(*args, **kwargs)
loss.backward()

return iteration

project_root = os.getcwd()

output = trigger_profiling(
project_root,
"entry_point.py",
initial_batch_size,
input_provider,
model_provider,
iteration_provider,
)

with open(self.output_filename, "w") as fp:
json.dump(output, fp, indent=4)

        print(
            colored(
                "DeepViewProfiler: Profiling complete! Report written to ", "green"
            )
            + colored(self.output_filename, "green", attrs=["bold"])
        )
        print(
            colored(
                "DeepViewProfiler: View your report at https://deepview.centml.ai",
                "green",
            )
        )
self.profiling_triggered = True
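
A minimal usage sketch for the new callback; MyLightningModule and train_loader are placeholders for user code:

import pytorch_lightning as pl
from deepview_profile.pl.deepview_callback import DeepViewProfilerCallback

# The callback profiles once, on the first completed training batch, and
# writes <profile_name>_<unix_timestamp>.json to the working directory.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[DeepViewProfilerCallback(profile_name="my_model")],
)
trainer.fit(MyLightningModule(), train_loader)  # placeholders for user code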
141 changes: 141 additions & 0 deletions deepview_profile/pl/deepview_interface.py
@@ -0,0 +1,141 @@
import sys
from typing import Callable
import platform

from deepview_profile.analysis.session import AnalysisSession
from deepview_profile.exceptions import AnalysisError
from deepview_profile.nvml import NVML

from deepview_profile.utils import release_memory, files_encoded_unique
from deepview_profile.error_printing import print_analysis_error

from google.protobuf.json_format import MessageToDict


def measure_breakdown(session, nvml):
print("analysis: running measure_breakdown()")
yield session.measure_breakdown(nvml)
release_memory()


def measure_throughput(session):
print("analysis: running measure_throughput()")
yield session.measure_throughput()
release_memory()


def habitat_predict(session):
print("analysis: running deepview_predict()")
yield session.habitat_predict()
release_memory()


def measure_utilization(session):
print("analysis: running measure_utilization()")
yield session.measure_utilization()
release_memory()


def energy_compute(session):
print("analysis: running energy_compute()")
yield session.energy_compute()
release_memory()


def ddp_analysis(session):
print("analysis: running ddp_computation()")
yield session.ddp_computation()
release_memory()


def hardware_information(nvml):
hardware_info = {
"hostname": platform.node(),
"os": " ".join(list(platform.uname())),
"gpus": nvml.get_device_names(),
}
return hardware_info


class DummyStaticAnalyzer:
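    # Minimal stand-in for the static analyzer: it reports no batch-size
    # location, consistent with can_manipulate_batch_size=False in the report.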
def batch_size_location(self):
return None


def next_message_to_dict(a):
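    # Advance the generator one step and convert the resulting protobuf
    # message to a dict, preserving snake_case field names.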
message = next(a)
return MessageToDict(message, preserving_proto_field_name=True)


def trigger_profiling(
project_root: str,
entry_point: str,
initial_batch_size: int,
input_provider: Callable,
model_provider: Callable,
iteration_provider: Callable,
):
try:
data = {
"analysis": {
"message_type": "analysis",
"project_root": project_root,
"project_entry_point": entry_point,
"hardware_info": {},
"throughput": {},
"breakdown": {},
"habitat": {},
"additionalProviders": "",
"energy": {},
"utilization": {},
"ddp": {},
},
"epochs": 50,
"iterations": 1000,
"encodedFiles": [],
}

session = AnalysisSession(
project_root,
entry_point,
project_root,
model_provider,
input_provider,
iteration_provider,
initial_batch_size,
DummyStaticAnalyzer(),
)
release_memory()

exclude_source = False

with NVML() as nvml:
data["analysis"]["hardware_info"] = hardware_information(nvml)
data["analysis"]["breakdown"] = next_message_to_dict(
measure_breakdown(session, nvml)
)

operation_tree = data["analysis"]["breakdown"]["operation_tree"]
if not exclude_source and operation_tree is not None:
data["encodedFiles"] = files_encoded_unique(operation_tree)

data["analysis"]["throughput"] = next_message_to_dict(
measure_throughput(session)
)
data["analysis"]["habitat"] = next_message_to_dict(habitat_predict(session))
data["analysis"]["utilization"] = next_message_to_dict(
measure_utilization(session)
)
data["analysis"]["energy"] = next_message_to_dict(energy_compute(session))
# data['analysis']['ddp'] = next_message_to_dict(ddp_analysis(session))

from deepview_profile.export_converter import convert

data["analysis"] = convert(data["analysis"])

return data

except AnalysisError as ex:
print_analysis_error(ex)
sys.exit(1)
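
The interface can also be driven without the Lightning callback. A sketch under the assumption that the providers follow the contract the callback establishes above, i.e. iteration receives the (inputs, batch_idx) pair produced by input_provider; the model and tensors here are placeholders:

import json
import torch
from deepview_profile.pl.deepview_interface import trigger_profiling

model = torch.nn.Linear(32, 2)  # placeholder model
sample = torch.randn(1, 32)     # one representative input row

def input_provider(batch_size: int = 48):
    # Same return shape as the callback's provider: (tuple_of_inputs, batch_idx).
    return ((sample.repeat(batch_size, 1),), 0)

def iteration_provider(module):
    def iteration(inputs, batch_idx):
        loss = module(*inputs).sum()  # toy loss, for illustration only
        loss.backward()
    return iteration

output = trigger_profiling(
    ".", "entry_point.py", 48,
    input_provider, lambda: model, iteration_provider,
)
with open("profile.json", "w") as fp:
    json.dump(output, fp, indent=4)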