[Pytorch Lightning support] added pytorch lightning #88

Merged · 2 commits · Nov 26, 2024
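In short, the PR adds a `DeepViewProfilerCallback` that runs a one-shot DeepView profiling pass at the end of the first training batch and writes the report to `<profile_name>_<timestamp>.json`. A minimal usage sketch — `MyModel` and `train_loader` are hypothetical placeholders, not part of this PR:

```python
import pytorch_lightning as pl

from deepview_profile.pl.deepview_callback import DeepViewProfilerCallback

# Attach the callback like any other Lightning callback; profiling fires
# once, after the first training batch, then training continues normally.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[DeepViewProfilerCallback(profile_name="resnet_run")],
)
trainer.fit(MyModel(), train_loader)
```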
158 changes: 158 additions & 0 deletions deepview_profile/export_converter.py
@@ -0,0 +1,158 @@
import json


def convert(message):
new_message = {}
with open("message.json", "w") as fp:
json.dump(message, fp, indent=4)

new_message["ddp"] = {}
new_message["message_type"] = message["message_type"]
new_message["project_root"] = message["project_root"]
new_message["project_entry_point"] = message["project_entry_point"]

new_message["hardware_info"] = {
"hostname": message["hardware_info"]["hostname"],
"os": message["hardware_info"]["os"],
"gpus": message["hardware_info"]["gpus"],
}

new_message["throughput"] = {
"samples_per_second": message["throughput"]["samples_per_second"],
"predicted_max_samples_per_second": message["throughput"][
"predicted_max_samples_per_second"
],
"run_time_ms": (
[
message["throughput"]["run_time_ms"]["slope"],
message["throughput"]["run_time_ms"]["bias"],
]
if "run_time_ms" in message["throughput"]
else [0, 0]
),
"peak_usage_bytes": (
[
message["throughput"]["peak_usage_bytes"]["slope"],
message["throughput"]["peak_usage_bytes"]["bias"],
]
if "peak_usage_bytes" in message["throughput"]
else [0, 0]
),
"batch_size_context": None,
"can_manipulate_batch_size": False,
}

new_message["utilization"] = message["utilization"]

    def fix(a):
        # Rename snake_case timing keys (e.g. "cpu_forward", "gpu_backward_span")
        # to the camelCase keys the exported report expects, defaulting missing
        # entries to 0.
        for d in ["cpu", "gpu"]:
            for s in ["Forward", "Backward"]:
                if f"{d}_{s.lower()}" in a:
                    a[f"{d}{s}"] = a.pop(f"{d}_{s.lower()}")
                else:
                    a[f"{d}{s}"] = 0

                if f"{d}_{s.lower()}_span" in a:
                    a[f"{d}{s}Span"] = a.pop(f"{d}_{s.lower()}_span")
                else:
                    a[f"{d}{s}Span"] = 0

        # Leaf nodes carry no "children" key; normalize to an empty list.
        if "children" not in a:
            a["children"] = []
            return

        for c in a["children"]:
            fix(c)

    if new_message["utilization"].get("rootNode"):
        fix(new_message["utilization"]["rootNode"])
    try:
        new_message["utilization"]["tensor_core_usage"] = message["utilization"][
            "tensor_utilization"
        ]
    except KeyError:
        new_message["utilization"]["tensor_core_usage"] = 0

new_message["habitat"] = {
"predictions": [
(
[prediction["device_name"], prediction["runtime_ms"]]
if prediction["device_name"] != "unavailable"
else ["default_device", 0]
)
for prediction in message["habitat"]["predictions"]
]
}

new_message["breakdown"] = {
"peak_usage_bytes": int(message["breakdown"]["peak_usage_bytes"]),
"memory_capacity_bytes": int(message["breakdown"]["memory_capacity_bytes"]),
"iteration_run_time_ms": message["breakdown"]["iteration_run_time_ms"],
# TODO change these hardcoded numbers
"batch_size": 48,
"num_nodes_operation_tree": len(message["breakdown"]["operation_tree"]),
"num_nodes_weight_tree": 0,
"operation_tree": [
{
"name": op["name"],
"num_children": op["num_children"] if "num_children" in op else 0,
"forward_ms": op["operation"]["forward_ms"],
"backward_ms": op["operation"]["backward_ms"],
"size_bytes": (
int(op["operation"]["size_bytes"])
if "size_bytes" in op["operation"]
else 0
),
"file_refs": (
[
{
"path": "/".join(ctx["context"]["file_path"]["components"]),
"line_no": ctx["context"]["line_number"],
"run_time_ms": ctx["run_time_ms"],
"size_bytes": (
int(ctx["size_bytes"]) if "size_bytes" in ctx else 0
),
}
for ctx in op["operation"]["context_info_map"]
]
if "context_info_map" in op["operation"]
else list()
),
}
for op in message["breakdown"]["operation_tree"]
],
}

    def fix_components(m):
        # Normalize each energy component: rename "consumption_joules" to
        # "consumption" (default 0) and "component_type" to "type", mapping
        # the NVIDIA-specific label to the generic GPU one.
        for c in m["components"]:
if "consumption_joules" not in c:
c["consumption"] = 0
else:
c["consumption"] = c["consumption_joules"]
del c["consumption_joules"]
c["type"] = c["component_type"]
if c["type"] == "ENERGY_NVIDIA":
c["type"] = "ENERGY_GPU"
del c["component_type"]

new_message["energy"] = {
"current": {
"total_consumption": message["energy"]["total_consumption"],
"components": message["energy"]["components"],
"batch_size": 48,
},
"past_measurements": message["energy"].get("past_measurements", None),
}

fix_components(new_message["energy"]["current"])
if new_message["energy"].get("past_measurements", None):
for m in new_message["energy"]["past_measurements"]:
fix_components(m)

return new_message
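Note that `convert` flattens the profiler's run-time and memory models into `[slope, bias]` pairs (or `[0, 0]` when absent). A sketch of how a consumer might evaluate one of these pairs — the linear interpretation follows the `slope`/`bias` field names and is an assumption here; `predict_run_time_ms` is a hypothetical helper:

```python
def predict_run_time_ms(throughput: dict, batch_size: int) -> float:
    # throughput["run_time_ms"] is the [slope, bias] pair built by convert();
    # it is [0, 0] when the profiler reported no run-time model.
    slope, bias = throughput["run_time_ms"]
    return slope * batch_size + bias
```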
92 changes: 92 additions & 0 deletions deepview_profile/pl/deepview_callback.py
@@ -0,0 +1,92 @@
from typing import Callable, Tuple

import time
import os
import json
import torch
import sys

try:
import pytorch_lightning as pl
except ImportError:
sys.exit("Please install pytorch-lightning:\nuse: pip install lightning\nExiting...")

from termcolor import colored
from deepview_profile.pl.deepview_interface import trigger_profiling


class DeepViewProfilerCallback(pl.Callback):
def __init__(self, profile_name: str):
super().__init__()
self.profiling_triggered = False
self.output_filename = f"{profile_name}_{int(time.time())}.json"

def on_train_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs,
batch,
batch_idx,
):

# only do this once
if self.profiling_triggered:
return

print(colored("DeepViewProfiler: Running profiling.", "green"))

"""
need 3 things:

input_provider: just return batch
model_provider: just return pl_module
iteration_provider: a lambda function that (a) calls pl_module.forward_step and (b) calls loss.backward
"""
initial_batch_size = batch[0].shape[0]

def input_provider(batch_size: int = initial_batch_size) -> Tuple:
model_inputs = list()
for elem in batch:
# we assume the first dimension is the batch dimension
model_inputs.append(
elem[:1].repeat([batch_size] + [1 for _ in elem.shape[1:]])
)
return (tuple(model_inputs), 0)

model_provider = lambda: pl_module

def iteration_provider(module: torch.nn.Module) -> Callable:
def iteration(*args, **kwargs):
loss = module.training_step(*args, **kwargs)
loss.backward()

return iteration

project_root = os.getcwd()

output = trigger_profiling(
project_root,
"entry_point.py",
initial_batch_size,
input_provider,
model_provider,
iteration_provider,
)

with open(self.output_filename, "w") as fp:
json.dump(output, fp, indent=4)

        print(
            colored(
                "DeepViewProfiler: Profiling complete! Report written to ", "green"
            )
            + colored(self.output_filename, "green", attrs=["bold"])
        )
        print(
            colored(
                "DeepViewProfiler: View your report at https://deepview.centml.ai",
                "green",
            )
        )
self.profiling_triggered = True
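The `input_provider` above synthesizes a batch of arbitrary size by tiling the first sample along the batch dimension. A standalone illustration of that `repeat` call, with made-up shapes:

```python
import torch

elem = torch.randn(32, 3, 224, 224)  # original batch of 32 images
batch_size = 8

# Keep one sample, then tile it batch_size times along dim 0 and once
# along every remaining dimension.
resized = elem[:1].repeat([batch_size] + [1 for _ in elem.shape[1:]])
assert resized.shape == (8, 3, 224, 224)
```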
141 changes: 141 additions & 0 deletions deepview_profile/pl/deepview_interface.py
@@ -0,0 +1,141 @@
import sys
from typing import Callable
import platform

from deepview_profile.analysis.session import AnalysisSession
from deepview_profile.exceptions import AnalysisError
from deepview_profile.nvml import NVML

# from deepview_profile.utils import release_memory, next_message_to_dict, files_encoded_unique
from deepview_profile.utils import release_memory, files_encoded_unique
from deepview_profile.error_printing import print_analysis_error

from google.protobuf.json_format import MessageToDict


def measure_breakdown(session, nvml):
print("analysis: running measure_breakdown()")
yield session.measure_breakdown(nvml)
release_memory()


def measure_throughput(session):
print("analysis: running measure_throughput()")
yield session.measure_throughput()
release_memory()


def habitat_predict(session):
print("analysis: running deepview_predict()")
yield session.habitat_predict()
release_memory()


def measure_utilization(session):
print("analysis: running measure_utilization()")
yield session.measure_utilization()
release_memory()


def energy_compute(session):
print("analysis: running energy_compute()")
yield session.energy_compute()
release_memory()


def ddp_analysis(session):
print("analysis: running ddp_computation()")
yield session.ddp_computation()
release_memory()


def hardware_information(nvml):
hardware_info = {
"hostname": platform.node(),
"os": " ".join(list(platform.uname())),
"gpus": nvml.get_device_names(),
}
return hardware_info


class DummyStaticAnalyzer:
def batch_size_location(self):
return None


def next_message_to_dict(a):
message = next(a)
return MessageToDict(message, preserving_proto_field_name=True)


def trigger_profiling(
project_root: str,
entry_point: str,
initial_batch_size: int,
input_provider: Callable,
model_provider: Callable,
iteration_provider: Callable,
):
try:
data = {
"analysis": {
"message_type": "analysis",
"project_root": project_root,
"project_entry_point": entry_point,
"hardware_info": {},
"throughput": {},
"breakdown": {},
"habitat": {},
"additionalProviders": "",
"energy": {},
"utilization": {},
"ddp": {},
},
"epochs": 50,
"iterations": 1000,
"encodedFiles": [],
}

session = AnalysisSession(
project_root,
entry_point,
project_root,
model_provider,
input_provider,
iteration_provider,
initial_batch_size,
DummyStaticAnalyzer(),
)
release_memory()

exclude_source = False

with NVML() as nvml:
data["analysis"]["hardware_info"] = hardware_information(nvml)
data["analysis"]["breakdown"] = next_message_to_dict(
measure_breakdown(session, nvml)
)

operation_tree = data["analysis"]["breakdown"]["operation_tree"]
if not exclude_source and operation_tree is not None:
data["encodedFiles"] = files_encoded_unique(operation_tree)

data["analysis"]["throughput"] = next_message_to_dict(
measure_throughput(session)
)
data["analysis"]["habitat"] = next_message_to_dict(habitat_predict(session))
data["analysis"]["utilization"] = next_message_to_dict(
measure_utilization(session)
)
data["analysis"]["energy"] = next_message_to_dict(energy_compute(session))
# data['analysis']['ddp'] = next_message_to_dict(ddp_analysis(session))

from deepview_profile.export_converter import convert

data["analysis"] = convert(data["analysis"])

return data

except AnalysisError as ex:
print_analysis_error(ex)
sys.exit(1)
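For reference, `trigger_profiling` can also be driven without the Lightning callback. A minimal sketch under the same provider conventions the callback establishes (inputs passed as `(batch, batch_idx)`); every name below is hypothetical, and the exact calling contract of `AnalysisSession` is assumed rather than confirmed by this diff:

```python
import torch

from deepview_profile.pl.deepview_interface import trigger_profiling

model = torch.nn.Linear(16, 2)

def model_provider():
    return model

def input_provider(batch_size: int = 32):
    # (batch, batch_idx), mirroring the Lightning callback's convention
    return ((torch.randn(batch_size, 16),), 0)

def iteration_provider(module):
    def iteration(batch, _batch_idx):
        loss = module(batch[0]).sum()
        loss.backward()
    return iteration

report = trigger_profiling(
    ".", "entry_point.py", 32, input_provider, model_provider, iteration_provider
)
```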