
[RFC] Re-design call_hook interface #8506

Closed
carmocca opened this issue Jul 21, 2021 · 23 comments · Fixed by #10575
Labels: design (Includes a design discussion), feature (Is an improvement or enhancement), let's do it! (approved to implement), priority: 1 (Medium priority task), refactor

Comments

@carmocca
Contributor

carmocca commented Jul 21, 2021

🚀 Feature

The current Trainer.call_hook implementation has several drawbacks:

  1. The call order cannot be modified. It forces Callback hooks to be called before LightningModule hooks of the same name; however, this is not always the correct order. An example is on_load_checkpoint, where the order is the opposite.
  2. Any trainer method with the same name as the hook to call will be silently called, even if it is not part of the Callback hook API. This is potentially dangerous if somebody names a trainer method the same as a hook.
  3. Any accelerator method with the same name as the hook to call will be silently called. This is unwanted for some hook names such as setup and teardown, and is also dangerous.
  4. It doesn't allow the Callback hook to return anything.

https://github.com/PyTorchLightning/pytorch-lightning/blob/e1442d247e0e4967dd2772bdcf5166226c974f89/pytorch_lightning/trainer/trainer.py#L1282-L1317
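
To make drawbacks (2) and (3) concrete, here is a minimal, hypothetical illustration (not the actual Trainer code) of how name-based getattr dispatch can silently call an unrelated method that merely shares the hook's name:

class Model:
    def setup(self, stage):
        print("Model.setup")


class Runner:
    """Hypothetical stand-in for the getattr-based dispatch in `call_hook`."""

    def setup(self, stage):
        # an internal method that happens to share the hook's name
        print("Runner.setup (called unintentionally)")

    def call_hook(self, module, hook_name, *args, **kwargs):
        # dispatch purely by name: anything with a matching attribute gets called
        if hasattr(self, hook_name):
            getattr(self, hook_name)(*args, **kwargs)  # drawback (2): silent "trainer" call
        if hasattr(module, hook_name):
            return getattr(module, hook_name)(*args, **kwargs)


Runner().call_hook(Model(), "setup", "fit")  # prints both lines because the names collide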

Motivation

The previously mentioned reasons mean that the current implementation is error-prone. Additionally, when we support loop customization, we will want users to be able to call these hooks themselves.

Pitch

Provide a lean call_hook function that just handles any necessary state maintenance around the hook call.

Option (a):

def call_hook(self, fn: Callable, hook_name: Optional[str], *args: Any, **kwargs: Any) -> Any:
    # placeholder for setting `pl_module._current_fx_name`, which is needed by `self.log()`
    _set_current_fx_name()

    if hook_name is None:
        # no name given: call without profiling
        output = fn(*args, **kwargs)
    else:
        with self.profiler.profile(hook_name):
            output = fn(*args, **kwargs)

    _unset_current_fx_name()

    return output

The state maintenance is currently just the _current_fx_name required for self.log().

This option makes the code more verbose, but the calls become explicit and the user takes on more responsibility.
That extra responsibility means the user might break things if the hooks are not called in the proper order - whatever the proper order happens to be for a particular hook. A hypothetical usage sketch follows.
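
For illustration, a hypothetical sketch of how a custom loop might invoke hooks under option (a); the call sites and the simplified hook signatures here are assumptions, not part of the proposal:

# inside a custom loop, with `trainer` and `pl_module` references available

# LightningModule hook:
output = trainer.call_hook(pl_module.on_train_batch_start, "on_train_batch_start", batch, batch_idx)

# Callback hooks, called explicitly in whatever order the loop chooses:
for callback in trainer.callbacks:
    trainer.call_hook(
        callback.on_train_batch_start, "on_train_batch_start", trainer, pl_module, batch, batch_idx
    )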

Alternatives

Option (b):

def call_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    _set_current_fx_name()

    with self.profiler.profile(hook_name):
        _call_callback_hook(hook_name, *args, **kwargs)
        output = _call_model_hook(hook_name, *args, **kwargs)

    _unset_current_fx_name()

    return output

This does not resolve drawbacks (1) and (4).

Option (c):

def call_hook(
    self,
    hook_name: str,
    *args: Any,
    should_call_accelerator: bool = False,
    should_call_trainer: bool = False,
    **kwargs: Any,
) -> Any:
    _set_current_fx_name()

    with self.profiler.profile(hook_name):
        _call_callback_hook(hook_name, *args, **kwargs)
        if should_call_trainer:
            # note: in the master impl this is silently done with the call above
            _call_trainer_hook(hook_name, *args, **kwargs)
        output = _call_model_hook(hook_name, *args, **kwargs)
        if should_call_accelerator:
            # note: in the master impl this can also return something
            _call_accelerator_hook(hook_name, *args, **kwargs)

    _unset_current_fx_name()

    return output

This does not resolve drawbacks (1) and (4).

Questions

Do the profilers merge profile timings with the same name?


cc: @tchaton @awaelchli @ananthsub

cc @Borda @justusschock @awaelchli @akihironitta @tchaton

@carmocca carmocca added feature Is an improvement or enhancement discussion In a discussion stage refactor design Includes a design discussion labels Jul 21, 2021
@carmocca carmocca added this to the v1.5 milestone Jul 21, 2021
@carmocca carmocca changed the title Re-design call_hook interface [RFC] Re-design call_hook interface Aug 5, 2021
@ananthsub
Contributor

I'd also add to the drawbacks:

  • It assumes the arguments are the same for each of the hooks it calls. There's no guarantee of this because we don't share the same interfaces across callbacks, the model, and even the accelerator; they're each defined independently.
  • It shouldn't be a public API on the Trainer (even renaming it to _call_hook would be preferable).
  • call_hook profiles the entire chunk of execution as one unit. This means callback, module, and accelerator execution for the same hook is attributed together, which can be very misleading when debugging performance bottlenecks. We should wrap the profiling around the most granular component instead: https://github.com/PyTorchLightning/pytorch-lightning/blob/e1442d247e0e4967dd2772bdcf5166226c974f89/pytorch_lightning/trainer/trainer.py#L1290-L1291 (e.g. Callback.{state_key}.hook, LightningModule.hook, and Accelerator.hook would all be separate entries in the profiler).

In an ideal state, do we really need call_hook? Right now, its sole purpose seems to be setting the current_fx name on the LightningModule, which is used for logging. I believe that if we redid logging so that the LightningModule owned the results, then call_hook would be purely a convenience for calling hooks; we could use #8509 to discuss that further. If we didn't need to set current_fx, is there any additional benefit to having call_hook? It also reduces traceability, so I'd prefer to call the model hooks directly from wherever we are in the trainer/loop.

@carmocca
Contributor Author

It shouldn't be a public API on the Trainer

It should be public, as users should be able to call custom hooks or change the order/location with loop customization. This is why I feel this is an important issue to resolve - users are limited because it currently does too much.

If we didn't need to set current_fx, then is there any additional benefit to having call_hook?

The only benefit then would be not having to deal with the profiler.

it also reduces traceability

If we go with option (a), since the callable itself is passed to call_hook, it would improve traceability as we wouldn't be relying on getattr anymore.

if we redid logging such that the LightningModule owned the results, then call_hook is purely a convenience for calling hooks. we could use #8509 to discuss it further

Sure. Although I think we could do what this issue proposes which would already improve the current state - if #8509 is resolved, then potentially we would just need to remove its use.

@ananthsub
Contributor

Sure. Although I think we could do what this issue proposes which would already improve the current state - if #8509 is resolved, then potentially we would just need to remove its use.

I agree, I don't think this and #8509 are at odds.

It should be public as the users should be able to call custom hooks or change the order/location with loop customization. This is why I feel like this is an important issue to resolve - the users are limited because it currently does too much.

I think we could keep it private for now and expose this as and when the definition of custom loops is clearer.

@carmocca I wonder what we really want call_hook to do today. Initially, this was a convenience introduced for calling callback hooks. It has since morphed into something with more responsibilities (callback hooks, model hooks, profiling, setting the hook name for logging, accelerators).

My view:

What do you think?

@ananthsub
Contributor

@carmocca I have a stupid question: when would someone pass a Callable vs. the hook name in option (a)?

@carmocca
Contributor Author

when would someone pass a Callable vs the hook name in option A?

I actually wasn't sure if we would need hook_name; it doesn't seem like it after starting the implementation in #9029.

Although, if we want to do the following:

e.g. Callback.{state_key}.hook, LightningModule.hook, and Accelerator.hook would all be separate entries in the profiler

We might want to call it like trainer.call_hook(pl_module.on_train_batch_start, hook_name="LightningModule.on_train_batch_start"). Is this too verbose? Do you have any other idea to pass the profiler name?

@carmocca
Contributor Author

carmocca commented Aug 22, 2021

We could also do the following:

    def _profile_and_call(self, obj: object, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
        # look the hook up on the given object and profile it under a per-class name
        fn = getattr(obj, hook_name)
        with self.profiler.profile(f"{obj.__class__.__name__}.{hook_name}"):
            return fn(*args, **kwargs)

    def call_hook(
        self,
        obj: object,
        hook_name: str,
        *args: Any,
        pl_module: Optional["pl.LightningModule"] = None,
        **kwargs: Any,
    ) -> Optional[Any]:
        pl_module = pl_module or self.lightning_module

        if pl_module is not None:
            prev_fx_name = pl_module._current_fx_name
            pl_module._current_fx_name = hook_name

        output = None
        if isinstance(obj, Trainer):
            # passing the Trainer means "call this hook on every callback"
            for callback in self.callbacks:
                if hook_name in ("on_init_start", "on_init_end"):
                    # these `Callback` hooks are the only ones that do not take a lightning module
                    output = self._profile_and_call(callback, hook_name, self, *args, **kwargs)
                else:
                    output = self._profile_and_call(callback, hook_name, self, pl_module, *args, **kwargs)
        else:
            output = self._profile_and_call(obj, hook_name, *args, **kwargs)

        if pl_module is not None:
            # restore `_current_fx_name` so nested `call_hook` contexts behave correctly
            pl_module._current_fx_name = prev_fx_name

        return output

Usage:

self.trainer.call_hook(self.trainer, "on_before_zero_grad", optimizer)
self.trainer.call_hook(self.trainer.lightning_module, "on_before_zero_grad", optimizer)

This would allow profiling each hook per class in the case of callbacks. It would also remove the need for TrainerCallbackHookMixin, with the disadvantage of losing traceability.

@ananthsub
Contributor

I think the latter proposal is risky: what guarantees are placed on obj? Is it really any object with any name and any args? Is it for the trainer? The lightning module? The accelerator?

Does having separate functions resolve this? trainer._call_lightning_module_hook, trainer._call_callback_hook, trainer._call_accelerator_hook?

I think this is a slightly restricted view of option (a), but it still addresses the drawbacks. A rough sketch is below.
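
For illustration, a rough sketch of what such split helpers might look like; the exact signatures and the profiler entry names here are assumptions, not an agreed design:

def _call_lightning_module_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    pl_module = self.lightning_module
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name
    fn = getattr(pl_module, hook_name)
    # per-component profiler entry, e.g. "[LightningModule]MyModel.on_train_batch_start"
    with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
        output = fn(*args, **kwargs)
    pl_module._current_fx_name = prev_fx_name
    return output

def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
    for callback in self.callbacks:
        fn = getattr(callback, hook_name, None)
        if callable(fn):
            # per-callback profiler entry, e.g. "[Callback]ModelCheckpoint.on_train_batch_start"
            with self.profiler.profile(f"[Callback]{callback.__class__.__name__}.{hook_name}"):
                fn(self, self.lightning_module, *args, **kwargs)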

@carmocca
Contributor Author

what guarantees on obj are placed? is it really any object with any name with any args? is it for the trainer? the lightning module? accelerator?

call_hook wouldn't guarantee anything. With this proposal it becomes a utility to reduce boilerplate (setting/unsetting state, managing the profiler, looping over callbacks). Any misuse would be on the caller.

does having separate functions resolve this? trainer._call_lightning_module_hook, trainer._call_callback_hook, trainer._call_accelerator_hook ?

I don't think it would, significantly. Would that be the same code split over 3 functions with an isinstance check in each?
I'd say that would be a lot of code for very little functionality - especially when we consider this "beta" and it's on us to call things properly.

I'm a bit conflicted on this because I wanted to reduce the responsibility of this method; however, managing the profiler and the profiled names complicates that.
I guess my previous snippet is an okay tradeoff with:

Pros:

  • Flexibility in call order
  • Per-class profiler sections
  • Removing the simple loops over callbacks in TrainerCallbackHookMixin. on_save_checkpoint and on_load_checkpoint would stay though.

Cons:

  • Lost traceability
  • Holds some assumptions on the inputs, most notably in the callbacks taking a Trainer and LightningModule

@tchaton
Contributor

tchaton commented Aug 25, 2021

Hey @carmocca,

I believe your proposal is sound, and I agree any misuse would be on the caller, but at the same time the behaviour is quite clear.

Best,
T.C

@tchaton tchaton added the let's do it! approved to implement label Aug 25, 2021
@daniellepintz
Contributor

daniellepintz commented Oct 28, 2021

@carmocca a few questions

  1. Is there a reason why you didn't finish [WIP] Trainer.call_hook re-design #9029?
  2. What exactly is meant by traceability, and why does having call_hook reduce it?
  3. Not every function goes through call_hook, right? So why are these the only ones the profiler and logger care about?

@daniellepintz
Contributor

I like your proposal above in #8506 (comment) (let's call it option (d)), since it allows us to deprecate TrainerCallbackHookMixin. If the concern there is about guarantees on obj, can't we remedy this by checking that the object is a Trainer/LM/Accelerator at the start of the function?

@carmocca
Contributor Author

is there a reason why you didn't finish

Got dragged into other things. I plan to finish it for 1.6.

what exactly is meant by traceability

When I say "loses traceability", I mean that tools cannot track its calls and usage anymore. Compare:

# no traceability
fn = getattr(obj, "method_name")
fn()
# vs
# Can be understood by IDEs
obj.method_name()

Not every function goes through call_hook right?

Every hook should; if it doesn't, then it's an oversight (hook == user-overridable method with no default implementation).

can't we remedy this by checking if the object is either Trainer/LM/Accelerator at the start of the function?

Sure.

@carmocca carmocca modified the milestones: v1.5, v1.6 Oct 28, 2021
@carmocca carmocca removed the discussion In a discussion stage label Oct 28, 2021
@carmocca carmocca added the priority: 0 High priority task label Oct 28, 2021
@daniellepintz
Contributor

@ananthsub are you okay with option (d) above in #8506 (comment)? We can add a check to make sure it is only called on a Callback, LM, or Accelerator.

@daniellepintz
Contributor

daniellepintz commented Nov 15, 2021

Another thought I had: do we expect users to sometimes want to call a hook on a callback directly? I.e., instead of calling self.trainer.call_hook(self.trainer, "on_before_zero_grad", optimizer), call self.trainer.call_hook(self.trainer.progress_bar_callback, "on_before_zero_grad", optimizer)? (They could do that with option (d); just curious whether we think this will be used.)

@daniellepintz
Contributor

Also @carmocca why do we need to pass pl_module - can't we always use self.lightning_module?

@carmocca
Contributor Author

self.trainer.progress_bar_callback.call_hook()

call_hook is a method of the Trainer, not a Callback

why do we need to pass pl_module - can't we always use self.lightning_module?

Mostly for legacy reasons. Some methods of the Trainer are marked as public (perhaps accidentally), which means they can be called before a LightningModule reference has been set, so they take an optional model. These methods can call hooks, which is why we need to pass it. Removing support for that would mean going through a deprecation.

https://github.com/PyTorchLightning/pytorch-lightning/blob/dcafc95f2b0fd3f176d425139ca99676ce943a12/pytorch_lightning/trainer/data_loading.py#L566

https://github.com/PyTorchLightning/pytorch-lightning/blob/dcafc95f2b0fd3f176d425139ca99676ce943a12/pytorch_lightning/trainer/data_loading.py#L332

https://github.com/PyTorchLightning/pytorch-lightning/blob/dcafc95f2b0fd3f176d425139ca99676ce943a12/pytorch_lightning/trainer/data_loading.py#L511-L552

@daniellepintz
Contributor

call_hook is a method of the Trainer, not a Callback

Ah, sorry I meant self.trainer.call_hook(self.trainer.progress_bar_callback, "on_before_zero_grad", optimizer). Updating my comment

@daniellepintz
Contributor

daniellepintz commented Nov 16, 2021

Thanks for explaining why we need to pass pl_module!

Removing support for that would mean going through a deprecation.

Got it. Do you think it would be feasible to remove support for it? I think it would be nice if we could, separately from this issue, since it can be confusing that some functions are passed a LM and others just use self.lightning_module, resulting in lots of lines like pl_module = pl_module or self.lightning_module throughout the codebase.

@daniellepintz
Contributor

@carmocca @ananthsub I have a first draft implementation in #10575; I would appreciate it if you could take a look. I created a new PR since #9029 is quite old and has a lot of merge conflicts.

@awaelchli
Contributor

Haven't followed the full discussion; came here from the PR review. I find it irritating that we are choosing call_hook to be protected yet accessing it everywhere as if it were public. A protected member is meant to be accessed only within the class itself. This pattern is recurring in Lightning and becoming a standard, at which point people are probably annoyed that I bring it up again and again. I get that you don't want to make the methods public, but have you considered making it a function instead? The function would be public and take the trainer instance as input.

@daniellepintz
Contributor

daniellepintz commented Nov 26, 2021

@awaelchli I agree, we should not be accessing protected methods from outside the class. Why would making them public functions be better than public methods on the Trainer class?

@awaelchli
Contributor

Since we want to call hooks from components outside the Trainer, and a trainer instance is required to do it, there are naturally two options:

  1. a method on the trainer
  2. a function that takes the trainer as input

Option 1) does not go over well these days, because #7740 has put a lot of pressure on not exposing any Trainer attributes to the user unless they are highly stable.
IMO this functionality needs to be publicly accessible because of its heavy use in loops: custom loop implementations (by users) will most likely contain such calls. So the only option I see left is 2), the function approach (a rough sketch is below). It will be hard to convince @carmocca and @ananthsub, but these are just the concerns I wanted to raise.
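
For illustration, a minimal sketch of the function approach; the function name, location, and signature are assumptions, not a concrete proposal:

from typing import Any

import pytorch_lightning as pl

def call_hook(trainer: "pl.Trainer", obj: object, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    """Free-function variant of option (d): takes the trainer explicitly instead of living on it."""
    with trainer.profiler.profile(f"{obj.__class__.__name__}.{hook_name}"):
        return getattr(obj, hook_name)(*args, **kwargs)

# usage from a custom loop, which already holds a trainer reference:
# output = call_hook(self.trainer, self.trainer.lightning_module, "on_before_zero_grad", optimizer)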

@carmocca
Contributor Author

I do prefer having a method on the trainer as only the trainer itself and the loops (which have a self.trainer reference) should call hooks. This imposes the limitation of requiring a trainer instance to call a hook, which avoids potential errors from users who might call this from anywhere if it's a function.

needs to be publicly accessible because of the heavy use in loops

That's true; however, the loops currently include many methods important for customization that are "protected", yet we advertise loop customization.

Just as we should review these protected Loop methods and make them public at some point in the future, we can do the same for call_hook right now: start protected and reevaluate later.

Not a strong opinion anyway.

@tchaton tchaton added priority: 1 Medium priority task and removed priority: 0 High priority task labels Nov 29, 2021