Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concrete_args to feature extraction tracing. #8393

Merged
merged 3 commits into from
Apr 29, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) ->
def get_graph_node_names(
model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None,
concrete_args: Optional[Dict[str, Any]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this parameter must be added as the last parameter, otherwise there's a (low) risk of breaking code that is using non-keyword parameters. This API should have been using keyword-only parameters to avoid that, but it didn't.

suppress_diff_warning: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Expand Down Expand Up @@ -232,7 +233,9 @@ def get_graph_node_names(
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.

concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. This parameter is experimental and
its backwards-compatibility is *NOT* guaranteed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original pytorch docs are confusing and also state https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace

Backwards-compatibility for this API is guaranteed.

🤔

suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Expand All @@ -249,9 +252,9 @@ def get_graph_node_names(
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
train_tracer.trace(model.train(), concrete_args=concrete_args)
eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval())
eval_tracer.trace(model.eval(), concrete_args=concrete_args)
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if not suppress_diff_warning:
Expand Down Expand Up @@ -333,6 +336,7 @@ def create_feature_extractor(
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Optional[Dict[str, Any]] = None,
concrete_args: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False,
) -> fx.GraphModule:
"""
Expand Down Expand Up @@ -395,6 +399,9 @@ def create_feature_extractor(
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. This parameter is experimental and
its backwards-compatibility is *NOT* guaranteed.
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Expand Down Expand Up @@ -482,7 +489,7 @@ def to_strdict(n) -> Dict[str, str]:

# Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph = tracer.trace(model, concrete_args=concrete_args)

name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name)
Expand Down
Loading