Migrate to ruff #38

Merged · 1 commit · Apr 11, 2024
7 changes: 3 additions & 4 deletions .github/workflows/python_lint.yml
@@ -16,12 +16,11 @@ jobs:
 
     - name: Install dependencies
       run: |
-        pip install flake8
-        pip install pyre-check
+        pip install -r requirements-dev.txt
         pip install .
 
-    - name: Run Flake8
-      run: flake8 .
+    - name: Run ruff
+      run: ruff format --check --diff .
 
     - name: Run Pyre Check
       run: pyre check
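
To reproduce the new CI check locally, the same commands can be run from the repository root. This is a minimal sketch based only on what the workflow installs; it assumes ruff is pinned in requirements-dev.txt, which is not part of this diff:

    pip install -r requirements-dev.txt
    pip install .
    # Verify formatting without modifying files (non-zero exit if changes are needed)
    ruff format --check --diff .
    # Apply the formatting in place
    ruff format .

Note that ruff format only covers formatting; replacing flake8's lint rules as well would additionally need a ruff check . step and a [tool.ruff] configuration (for example in pyproject.toml), neither of which is shown in this change.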
63 changes: 18 additions & 45 deletions et_converter/et_converter.py
@@ -9,10 +9,9 @@
 from .text2chakra_converter import Text2ChakraConverter
 from .pytorch2chakra_converter import PyTorch2ChakraConverter
 
+
 def get_logger(log_filename: str) -> logging.Logger:
-    formatter = logging.Formatter(
-        "%(levelname)s [%(asctime)s] %(message)s",
-        datefmt="%m/%d/%Y %I:%M:%S %p")
+    formatter = logging.Formatter("%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
 
     file_handler = FileHandler(log_filename, mode="w")
     file_handler.setLevel(logging.DEBUG)
@@ -29,44 +28,23 @@ def get_logger(log_filename: str) -> logging.Logger:
 
     return logger
 
+
 def main() -> None:
-    parser = argparse.ArgumentParser(
-        description="Execution Trace Converter")
-    parser.add_argument(
-        "--input_type",
-        type=str,
-        default=None,
-        required=True,
-        help="Input execution trace type")
+    parser = argparse.ArgumentParser(description="Execution Trace Converter")
+    parser.add_argument("--input_type", type=str, default=None, required=True, help="Input execution trace type")
     parser.add_argument(
-        "--input_filename",
-        type=str,
-        default=None,
-        required=True,
-        help="Input execution trace filename")
+        "--input_filename", type=str, default=None, required=True, help="Input execution trace filename"
+    )
     parser.add_argument(
-        "--output_filename",
-        type=str,
-        default=None,
-        required=True,
-        help="Output Chakra execution trace filename")
+        "--output_filename", type=str, default=None, required=True, help="Output Chakra execution trace filename"
+    )
     parser.add_argument(
-        "--num_npus",
-        type=int,
-        default=None,
-        required="Text" in sys.argv,
-        help="Number of NPUs in a system")
+        "--num_npus", type=int, default=None, required="Text" in sys.argv, help="Number of NPUs in a system"
+    )
     parser.add_argument(
-        "--num_passes",
-        type=int,
-        default=None,
-        required="Text" in sys.argv,
-        help="Number of training passes")
-    parser.add_argument(
-        "--log_filename",
-        type=str,
-        default="debug.log",
-        help="Log filename")
+        "--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes"
+    )
+    parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename")
     args = parser.parse_args()
 
     logger = get_logger(args.log_filename)
@@ -75,17 +53,11 @@ def main() -> None:
     try:
         if args.input_type == "Text":
             converter = Text2ChakraConverter(
-                args.input_filename,
-                args.output_filename,
-                args.num_npus,
-                args.num_passes,
-                logger)
+                args.input_filename, args.output_filename, args.num_npus, args.num_passes, logger
+            )
             converter.convert()
         elif args.input_type == "PyTorch":
-            converter = PyTorch2ChakraConverter(
-                args.input_filename,
-                args.output_filename,
-                logger)
+            converter = PyTorch2ChakraConverter(args.input_filename, args.output_filename, logger)
             converter.convert()
         else:
             logger.error(f"{args.input_type} unsupported")
Expand All @@ -95,5 +67,6 @@ def main() -> None:
logger.debug(traceback.format_exc())
sys.exit(1)


if __name__ == "__main__":
main()
117 changes: 53 additions & 64 deletions et_converter/pytorch2chakra_converter.py
@@ -108,12 +108,7 @@ class PyTorch2ChakraConverter:
     dependencies.
     """
 
-    def __init__(
-        self,
-        input_filename: str,
-        output_filename: str,
-        logger: logging.Logger
-    ) -> None:
+    def __init__(self, input_filename: str, output_filename: str, logger: logging.Logger) -> None:
         """
         Initializes the PyTorch to Chakra converter. It sets up necessary
         attributes and prepares the environment for the conversion process.
@@ -157,8 +152,9 @@ def convert(self) -> None:
         self.open_chakra_execution_trace()
 
         for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
-            if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\
-                    or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL):
+            if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP) or (
+                pytorch_node.get_op_type() == PyTorchNodeType.LABEL
+            ):
                 chakra_node = self.convert_to_chakra_node(pytorch_node)
                 self.chakra_nodes[chakra_node.id] = chakra_node
 
@@ -167,11 +163,12 @@ def convert(self) -> None:
 
                     if chakra_node.type == COMM_COLL_NODE:
                         collective_comm_type = self.get_collective_comm_type(pytorch_node.name)
-                        chakra_gpu_node.attr.extend([
-                            ChakraAttr(name="comm_type",
-                                       int64_val=collective_comm_type),
-                            ChakraAttr(name="comm_size",
-                                       int64_val=pytorch_gpu_node.comm_size)])
+                        chakra_gpu_node.attr.extend(
+                            [
+                                ChakraAttr(name="comm_type", int64_val=collective_comm_type),
+                                ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
+                            ]
+                        )
 
                     self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node
 
@@ -229,14 +226,10 @@ def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None:
         self.pytorch_finish_ts = pytorch_et_data["finish_ts"]
 
         pytorch_nodes = pytorch_et_data["nodes"]
-        pytorch_node_objects = {
-            node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes
-        }
+        pytorch_node_objects = {node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes}
         self._establish_parent_child_relationships(pytorch_node_objects)
 
-    def _establish_parent_child_relationships(
-        self, pytorch_node_objects: Dict[int, PyTorchNode]
-    ) -> None:
+    def _establish_parent_child_relationships(self, pytorch_node_objects: Dict[int, PyTorchNode]) -> None:
         """
         Establishes parent-child relationships among PyTorch nodes and counts
         the node types.
@@ -252,7 +245,7 @@ def _establish_parent_child_relationships(
             "gpu_op": 0,
             "record_param_comms_op": 0,
             "nccl_op": 0,
-            "root_op": 0
+            "root_op": 0,
         }
 
         # Establish parent-child relationships
@@ -271,8 +264,10 @@ def _establish_parent_child_relationships(
                 if pytorch_node.is_nccl_op():
                     parent_node.nccl_node = pytorch_node
 
-            if pytorch_node.name in ["[pytorch|profiler|execution_graph|thread]",
-                                     "[pytorch|profiler|execution_trace|thread]"]:
+            if pytorch_node.name in [
+                "[pytorch|profiler|execution_graph|thread]",
+                "[pytorch|profiler|execution_trace|thread]",
+            ]:
                 self.pytorch_root_nids.append(pytorch_node.id)
                 node_type_counts["root_op"] += 1
 
@@ -333,17 +328,19 @@ def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode:
         chakra_node.outputs.values = str(pytorch_node.outputs)
         chakra_node.outputs.shapes = str(pytorch_node.output_shapes)
         chakra_node.outputs.types = str(pytorch_node.output_types)
-        chakra_node.attr.extend([
-            ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id),
-            ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent),
-            ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id),
-            ChakraAttr(name="scope", int64_val=pytorch_node.scope),
-            ChakraAttr(name="tid", int64_val=pytorch_node.tid),
-            ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid),
-            ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema),
-            ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()),
-            ChakraAttr(name="ts", int64_val=pytorch_node.ts)
-        ])
+        chakra_node.attr.extend(
+            [
+                ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id),
+                ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent),
+                ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id),
+                ChakraAttr(name="scope", int64_val=pytorch_node.scope),
+                ChakraAttr(name="tid", int64_val=pytorch_node.tid),
+                ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid),
+                ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema),
+                ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()),
+                ChakraAttr(name="ts", int64_val=pytorch_node.ts),
+            ]
+        )
         return chakra_node
 
     def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType:
@@ -356,9 +353,7 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType:
         Returns:
             int: The corresponding Chakra node type.
         """
-        if pytorch_node.is_gpu_op() and (
-            "ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name
-        ):
+        if pytorch_node.is_gpu_op() and ("ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name):
             return COMM_COLL_NODE
         elif ("c10d::" in pytorch_node.name) or ("nccl:" in pytorch_node.name):
             return COMM_COLL_NODE
@@ -392,8 +387,10 @@ def get_collective_comm_type(self, name: str) -> int:
             if key.lower() in name.lower():
                 return value
 
-        raise ValueError(f"'{name}' not found in collective communication mapping. "
-                         "Please add this collective communication name to the mapping.")
+        raise ValueError(
+            f"'{name}' not found in collective communication mapping. "
+            "Please add this collective communication name to the mapping."
+        )
 
     def is_root_node(self, node):
         """
Expand All @@ -412,8 +409,7 @@ def is_root_node(self, node):
Returns:
bool: True if the node is a root node, False otherwise.
"""
if node.name in ["[pytorch|profiler|execution_graph|thread]",
"[pytorch|profiler|execution_trace|thread]"]:
if node.name in ["[pytorch|profiler|execution_graph|thread]", "[pytorch|profiler|execution_trace|thread]"]:
return True

def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None:
@@ -591,9 +587,7 @@ def dfs(node_id: int, path: List[int]) -> bool:
                 bool: True if a cycle is detected, False otherwise.
             """
             if node_id in stack:
-                cycle_nodes = " -> ".join(
-                    [self.chakra_nodes[n].name for n in path + [node_id]]
-                )
+                cycle_nodes = " -> ".join([self.chakra_nodes[n].name for n in path + [node_id]])
                 self.logger.error(f"Cyclic dependency detected: {cycle_nodes}")
                 return True
             if node_id in visited:
@@ -611,10 +605,7 @@ def dfs(node_id: int, path: List[int]) -> bool:
 
         for node_id in self.chakra_nodes:
             if dfs(node_id, []):
-                raise Exception(
-                    f"Cyclic dependency detected starting from node "
-                    f"{self.chakra_nodes[node_id].name}"
-                )
+                raise Exception(f"Cyclic dependency detected starting from node " f"{self.chakra_nodes[node_id].name}")
 
     def write_chakra_et(self) -> None:
         """
@@ -642,7 +633,7 @@ def _write_global_metadata(self) -> None:
                 ChakraAttr(name="pid", uint64_val=self.pytorch_pid),
                 ChakraAttr(name="time", string_val=self.pytorch_time),
                 ChakraAttr(name="start_ts", uint64_val=self.pytorch_start_ts),
-                ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts)
+                ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts),
             ]
         )
         encode_message(self.chakra_et, global_metadata)
@@ -684,21 +675,18 @@ def simulate_execution(self) -> None:
         execution based on the readiness determined by dependency resolution.
         A simplistic global clock is used to model the execution time.
         """
-        self.logger.info("Simulating execution of Chakra nodes based on data "
-                         "dependencies.")
+        self.logger.info("Simulating execution of Chakra nodes based on data " "dependencies.")
 
         # Initialize queues for ready CPU and GPU nodes
         ready_cpu_nodes = [
             (node_id, self.chakra_nodes[node_id])
             for node_id in self.chakra_nodes
-            if not self.chakra_nodes[node_id].data_deps and
-            not self.pytorch_nodes[node_id].is_gpu_op()
+            if not self.chakra_nodes[node_id].data_deps and not self.pytorch_nodes[node_id].is_gpu_op()
         ]
         ready_gpu_nodes = [
             (node_id, self.chakra_nodes[node_id])
             for node_id in self.chakra_nodes
-            if not self.chakra_nodes[node_id].data_deps and
-            self.pytorch_nodes[node_id].is_gpu_op()
+            if not self.chakra_nodes[node_id].data_deps and self.pytorch_nodes[node_id].is_gpu_op()
         ]
         ready_cpu_nodes.sort(key=lambda x: x[1].id)
         ready_gpu_nodes.sort(key=lambda x: x[1].id)
@@ -709,8 +697,7 @@ def simulate_execution(self) -> None:
 
         current_time: int = 0  # Simulated global clock in microseconds
 
-        while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node,
-                   current_gpu_node]):
+        while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, current_gpu_node]):
             if ready_cpu_nodes and not current_cpu_node:
                 cpu_node_id, cpu_node = ready_cpu_nodes.pop(0)
                 current_cpu_node = (cpu_node_id, current_time)
@@ -731,16 +718,18 @@ def simulate_execution(self) -> None:
 
             current_time += 1
 
-            if current_cpu_node and current_time - current_cpu_node[1] >= \
-               self.chakra_nodes[current_cpu_node[0]].duration_micros:
-                self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed "
-                                 f"at {current_time}us")
+            if (
+                current_cpu_node
+                and current_time - current_cpu_node[1] >= self.chakra_nodes[current_cpu_node[0]].duration_micros
+            ):
+                self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed " f"at {current_time}us")
                 current_cpu_node = None
 
-            if current_gpu_node and current_time - current_gpu_node[1] >= \
-               self.chakra_nodes[current_gpu_node[0]].duration_micros:
-                self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed "
-                                 f"at {current_time}us")
+            if (
+                current_gpu_node
+                and current_time - current_gpu_node[1] >= self.chakra_nodes[current_gpu_node[0]].duration_micros
+            ):
+                self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed " f"at {current_time}us")
                 current_gpu_node = None
 
             for node_id in list(issued_nodes):
16 changes: 8 additions & 8 deletions et_converter/pytorch_node.py
@@ -29,11 +29,11 @@ def __init__(self, node_data: Dict[str, Any]) -> None:
         PyTorch node.
         """
         self.node_data = node_data
-        self.data_deps: List['PyTorchNode'] = []
-        self.children: List['PyTorchNode'] = []
-        self.gpu_children: List['PyTorchNode'] = []
-        self.record_param_comms_node: Optional['PyTorchNode'] = None
-        self.nccl_node: Optional['PyTorchNode'] = None
+        self.data_deps: List["PyTorchNode"] = []
+        self.children: List["PyTorchNode"] = []
+        self.gpu_children: List["PyTorchNode"] = []
+        self.record_param_comms_node: Optional["PyTorchNode"] = None
+        self.nccl_node: Optional["PyTorchNode"] = None
 
     def __repr__(self) -> str:
         """
@@ -527,7 +527,7 @@ def is_gpu_op(self) -> bool:
         """
        return self.has_cat()
 
-    def add_data_dep(self, parent_node: 'PyTorchNode') -> None:
+    def add_data_dep(self, parent_node: "PyTorchNode") -> None:
         """
         Adds a data-dependent parent node to this node.
 
@@ -536,7 +536,7 @@ def add_data_dep(self, parent_node: 'PyTorchNode') -> None:
         """
         self.data_deps.append(parent_node)
 
-    def add_child(self, child_node: 'PyTorchNode') -> None:
+    def add_child(self, child_node: "PyTorchNode") -> None:
         """
         Adds a child node to this node.
 
@@ -545,7 +545,7 @@ def add_child(self, child_node: 'PyTorchNode') -> None:
         """
         self.children.append(child_node)
 
-    def add_gpu_child(self, gpu_child_node: 'PyTorchNode') -> None:
+    def add_gpu_child(self, gpu_child_node: "PyTorchNode") -> None:
         """
         Adds a child GPU node for this node.
 