Skip to content

Commit

Permalink
[App] Improve LightningTrainerScript start-up time (#15751)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 21, 2022
1 parent bc797fd commit c2c1974
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions src/lightning_app/components/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,33 +147,28 @@ def __init__(
the ServableModule API
"""
super().__init__()
self.ws = structures.List()
self.has_initialized = False
self.script_path = script_path
self.script_args = script_args
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute # TODO: Add support for cloudCompute
self.sanity_serving = sanity_serving
self._script_runner = script_runner
self._script_runner_kwargs = script_runner_kwargs

def run(self, **run_kwargs):
if not self.has_initialized:
for node_rank in range(self.num_nodes):
self.ws.append(
self._script_runner(
script_path=self.script_path,
script_args=self.script_args,
cloud_compute=self._cloud_compute,
node_rank=node_rank,
sanity_serving=self.sanity_serving,
num_nodes=self.num_nodes,
**self._script_runner_kwargs,
)
self.ws = structures.List()
for node_rank in range(self.num_nodes):
self.ws.append(
self._script_runner(
script_path=self.script_path,
script_args=self.script_args,
cloud_compute=cloud_compute,
node_rank=node_rank,
sanity_serving=self.sanity_serving,
num_nodes=self.num_nodes,
**self._script_runner_kwargs,
)
)

self.has_initialized = True

def run(self, **run_kwargs):
for work in self.ws:
if all(w.internal_ip for w in self.ws):
internal_urls = [(w.internal_ip, w.port) for w in self.ws]
Expand Down

0 comments on commit c2c1974

Please sign in to comment.