Skip to content

Commit

Permalink
Compute cost-analysis on only one HLO module.
Browse files Browse the repository at this point in the history
There was historically a goal to support multiple HLOs in an executable, but this work was never finished and is no longer planned so we don't need this support.

This will soon enable us to return only a dict, instead of a list of dicts with only one item.

PiperOrigin-RevId: 707067448
  • Loading branch information
zacmustin authored and Google-ML-Automation committed Jan 2, 2025
1 parent 800f903 commit 9a5142a
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def as_text(self) -> str:
else:
raise

# TODO(skyewm): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed)
# TODO(b/384741132): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed).
def cost_analysis(self) -> list[dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()

Expand All @@ -266,9 +266,19 @@ def cost_analysis(self) -> list[dict[str, float]]:
# Try client method if executable cost_analysis method is unimplemented
if hasattr(xla_ext_exe, "client"):
try:
# TODO(b/384741132): We expect that the executable has only one
# HloModule. We should be able to remove this check once we update the
# Executable class to have only a single HloModule (see bug).
hlo_modules = xla_ext_exe.hlo_modules()
assert len(hlo_modules) == 1, (
f"Exectuable should have only one HloModule ({len(hlo_modules)})"
" were found)."
)

return [
xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
for m in xla_ext_exe.hlo_modules()
xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
Expand Down

0 comments on commit 9a5142a

Please sign in to comment.