diff --git a/docs/developer_guides/config.md b/docs/developer_guides/config.md index bf216278f3..29d093f9d0 100644 --- a/docs/developer_guides/config.md +++ b/docs/developer_guides/config.md @@ -118,3 +118,33 @@ Often times, you just want to play with the parameters of an existing model with # NOTE: the dataparser and associated configurations go at the end of the command ns-train {METHOD_NAME} --vis viewer {DATA_PARSER} --scale-factor 0.5 ``` + +### Extending Nerfstudio with custom methods +In order to extend the Nerfstudio and register your own methods, you can package your code as a python package +and register it with Nerfstudio as a `nerfstudio.method_configs` entrypoint in the `pyproject.toml` file. +The Nerfstudio will automatically look for all registered methods and will register them to be used +by method such as `ns-train`. + +Here is an example: +```python +"""my_package/my_config.py""" + +from nerfstudio.engine.trainer import TrainerConfig +from nerfstudio.plugins.types import MethodSpecification + +MyMethod = MethodSpecification( + config=TrainerConfig( + method_name="my-method", + ... + ), + description="Custom description" +) + +"""pyproject.toml""" +[project] +name = "my_package" +... + +[project.entry-points.'nerfstudio.method_configs'] +my-method = 'my_package.my_config:MyMethod' +``` \ No newline at end of file diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 947dad880d..4f41ba4af6 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -57,6 +57,7 @@ from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig +from nerfstudio.plugins.registry import discover_methods method_configs: Dict[str, TrainerConfig] = {} descriptions = { @@ -387,6 +388,10 @@ vis="viewer", ) +external_methods, external_descriptions = discover_methods() +method_configs.update(external_methods) +descriptions.update(external_descriptions) + AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext. tyro.conf.FlagConversionOff[ tyro.extras.subcommand_type_from_defaults(defaults=method_configs, descriptions=descriptions) diff --git a/nerfstudio/plugins/__init__.py b/nerfstudio/plugins/__init__.py new file mode 100644 index 0000000000..d926cf95c3 --- /dev/null +++ b/nerfstudio/plugins/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Nerfstudio Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nerfstudio/plugins/registry.py b/nerfstudio/plugins/registry.py new file mode 100644 index 0000000000..365f05a126 --- /dev/null +++ b/nerfstudio/plugins/registry.py @@ -0,0 +1,50 @@ +# Copyright 2022 The Nerfstudio Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module that keeps all registered plugins and allows for plugin discovery. +""" + +import sys +import typing as t + +from rich.progress import Console + +from nerfstudio.plugins.types import MethodSpecification + +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points +CONSOLE = Console(width=120) + + +def discover_methods(): + """ + Discovers all methods registered using the `nerfstudio.method_configs` entrypoint. + """ + methods = {} + descriptions = {} + discovered_entry_points = entry_points(group="nerfstudio.method_configs") + for name in discovered_entry_points.names: + specification = discovered_entry_points[name].load() + if not isinstance(specification, MethodSpecification): + CONSOLE.print( + "[bold yellow]Warning: Could not entry point {n} as it is not an instance of MethodSpecification" + ) + continue + specification = t.cast(MethodSpecification, specification) + methods[specification.config.method_name] = specification.config + descriptions[specification.config.method_name] = specification.description + return methods, descriptions diff --git a/nerfstudio/plugins/types.py b/nerfstudio/plugins/types.py new file mode 100644 index 0000000000..ee7e8e4706 --- /dev/null +++ b/nerfstudio/plugins/types.py @@ -0,0 +1,33 @@ +# Copyright 2022 The Nerfstudio Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This package contains specifications used to register plugins. +""" +from dataclasses import dataclass + +from nerfstudio.engine.trainer import TrainerConfig + + +@dataclass +class MethodSpecification: + """ + Method specification class used to register custom methods with Nerfstudio. + The registered methods will be available in commands such as `ns-train` + """ + + config: TrainerConfig + """Trainer configuration""" + description: str + """Method description shown in `ns-train` help""" diff --git a/pyproject.toml b/pyproject.toml index 21bc694562..ba57c01d49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "functorch==0.2.1", "h5py>=2.9.0", "imageio>=2.21.1", + 'importlib-metadata>=6.0.0; python_version < "3.10"', "ipywidgets>=7.6", "jupyterlab>=3.3.4", "matplotlib>=3.5.3", diff --git a/tests/plugins/test_registry.py b/tests/plugins/test_registry.py new file mode 100644 index 0000000000..fe8361a183 --- /dev/null +++ b/tests/plugins/test_registry.py @@ -0,0 +1,51 @@ +""" +Tests for the nerfstudio.plugins.registry module. +""" +import sys + +from nerfstudio.engine.trainer import TrainerConfig +from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig +from nerfstudio.plugins import registry +from nerfstudio.plugins.types import MethodSpecification + +if sys.version_info < (3, 10): + import importlib_metadata +else: + from importlib import metadata as importlib_metadata + + +TestConfig = MethodSpecification( + config=TrainerConfig( + method_name="test-method", + pipeline=VanillaPipelineConfig(), + optimizers={}, + ), + description="Test description", +) + + +def test_discover_methods(): + """Tests if a custom method gets properly registered using the discover_methods method""" + entry_points_backup = registry.entry_points + + def entry_points(group=None): + return importlib_metadata.EntryPoints( + [ + importlib_metadata.EntryPoint( + name="test", value="test_registry:TestConfig", group="nerfstudio.method_configs" + ) + ] + ).select(group=group) + + try: + # Mock importlib entry_points + registry.entry_points = entry_points + + # Discover plugins + methods, _ = registry.discover_methods() + assert "test-method" in methods + config = methods["test-method"] + assert isinstance(config, TrainerConfig) + finally: + # Revert mock + registry.entry_points = entry_points_backup