diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index d8c2dca4afe..f42bc124c7b 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -204,6 +204,7 @@ def get_graph_node_names( model: nn.Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, ) -> Tuple[List[str], List[str]]: """ Dev utility to return node names in order of execution. See note on node @@ -232,10 +233,13 @@ def get_graph_node_names( {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user provided dictionary. - suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. According to the `Pytorch docs + `_, + this parameter's API may not be guaranteed. Returns: tuple(list, list): a list of node names from tracing the model in @@ -249,9 +253,9 @@ def get_graph_node_names( tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) is_training = model.training train_tracer = NodePathTracer(**tracer_kwargs) - train_tracer.trace(model.train()) + train_tracer.trace(model.train(), concrete_args=concrete_args) eval_tracer = NodePathTracer(**tracer_kwargs) - eval_tracer.trace(model.eval()) + eval_tracer.trace(model.eval(), concrete_args=concrete_args) train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) if not suppress_diff_warning: @@ -334,6 +338,7 @@ def create_feature_extractor( eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, ) -> fx.GraphModule: """ Creates a new graph module that returns intermediate nodes from a given @@ -398,6 +403,10 @@ def create_feature_extractor( suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. According to the `Pytorch docs + `_, + this parameter's API may not be guaranteed. Examples:: @@ -482,7 +491,7 @@ def to_strdict(n) -> Dict[str, str]: # Instantiate our NodePathTracer and use that to trace the model tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) + graph = tracer.trace(model, concrete_args=concrete_args) name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ graph_module = fx.GraphModule(tracer.root, graph, name)