Skip to content

Commit

Permalink
weakref proxy trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jun 18, 2021
1 parent f34b68f commit d4e6969
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 18 deletions.
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Optional
from weakref import proxy
from weakref import proxy, ProxyType

from deprecate import void

Expand Down Expand Up @@ -45,7 +44,7 @@ class Loop(ABC):

def __init__(self) -> None:
    """Initializes the loop's run state.

    Attributes:
        iteration_count: counter incremented as the loop advances.
        trainer: reference to the owning trainer; ``None`` until ``connect``
            is called.
    """
    self.iteration_count: int = 0
    # The whole annotation is a string on purpose: PEP 526 evaluates
    # annotations on attribute targets at runtime, and ``weakref.ProxyType``
    # is not subscriptable — ``ProxyType['pl.Trainer']`` would raise
    # ``TypeError`` on every instantiation. A forward-reference string keeps
    # the documented intent (a weakref proxy to the trainer) without the
    # runtime evaluation.
    self.trainer: Optional["ProxyType[pl.Trainer]"] = None

@property
@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def predictions(self):
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
    """Connects the loop to everything necessary (like trainer and accelerators)"""
    # Delegate to the base ``Loop.connect`` first so the shared trainer
    # reference is established before the child loop is wired up.
    super().connect(trainer, *args, **kwargs)
    # TODO: Make the trainer a weakref/proxy
    # NOTE(review): the child loop is connected with the trainer only —
    # *args/**kwargs are presumably not needed downstream; confirm.
    self.evaluation_loop.connect(trainer)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def done(self) -> bool:
def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
    """Connects the loop with all necessary things (like trainer)"""
    # Let the base ``Loop`` implementation store the trainer reference.
    super().connect(trainer, *args, **kwargs)
    # The child loop only needs the trainer; extra positional/keyword
    # arguments are deliberately not forwarded.
    self.prediction_loop.connect(trainer)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""
Expand Down
5 changes: 1 addition & 4 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,7 @@ def skip(self) -> bool:

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
    """Connects the loop with necessary arguments like the trainer"""
    # The base ``Loop.connect`` stores the trainer reference; no local
    # ``self.trainer = trainer`` assignment is needed here.
    super().connect(trainer, *args, **kwargs)
    # The inner epoch loop only needs the trainer itself.
    self.training_loop.connect(trainer)

def reset(self) -> None:
Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/loops/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ def optimizer_freq_cumsum(self) -> int:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
    """Stores a reference to the trainer on this loop."""
    # TODO(@justusschock): can we make this a weakref/proxy?
    # ``void`` discards the unused extra positional/keyword arguments.
    void(*args, **kwargs)
    # NOTE(review): this keeps a strong reference and bypasses the base
    # class's ``connect`` — unlike the sibling loops, which delegate to
    # ``super().connect``. Confirm whether this override is still needed.
    self.trainer = trainer

def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
Expand Down
5 changes: 1 addition & 4 deletions pytorch_lightning/loops/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ def done(self) -> bool:

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
    """Connects the loop with all necessary parts like trainer and accelerators"""
    # The base ``Loop.connect`` stores the trainer reference; assigning
    # ``self.trainer`` here as well would be redundant.
    super().connect(trainer, *args, **kwargs)
    # A fresh batch loop is built on every connect and wired to the trainer.
    self.batch_loop = TrainingBatchLoop()
    self.batch_loop.connect(trainer)

Expand Down

0 comments on commit d4e6969

Please sign in to comment.