diff --git a/.pylintrc b/.pylintrc index 2176ad9..49da7d1 100644 --- a/.pylintrc +++ b/.pylintrc @@ -33,7 +33,8 @@ disable= too-few-public-methods, cyclic-import, import-error, - no-name-in-module + no-name-in-module, + redefined-outer-name, [REPORTS] diff --git a/src/crllm/cli.py b/src/crllm/cli.py index b2e3762..57f059d 100644 --- a/src/crllm/cli.py +++ b/src/crllm/cli.py @@ -13,15 +13,29 @@ def cli(): description="Get Code Reviews from large language models." ) - parser.add_argument("input_file", help="Path to the source code file") + parser.add_argument("input", help="Path to the sources") parser.add_argument( - "-c", help="Path to the config file", default="crllm_config.toml" + "-c", "--config", help="Path to the config file", default="crllm_config.toml" + ) + parser.add_argument( + "-l", + "--loader", + help="Loader to use", + required=False, + choices=["file", "git", "git_compare"], ) args = parser.parse_args() - config_service.get_config(args.c) - app(args.input_file) + config = {"crllm": {}} + + if args.loader: + config["crllm"]["loader"] = args.loader + + config_service.set_config_path(args.config) + config_service.override_config(config) + + app(args.input) if __name__ == "__main__": diff --git a/src/crllm/config/config_service.py b/src/crllm/config/config_service.py index cf9beed..610df7b 100644 --- a/src/crllm/config/config_service.py +++ b/src/crllm/config/config_service.py @@ -7,24 +7,34 @@ class ConfigService: - configPath = os.path.join(CONFIG_DIR, "config.toml") + default_config_path = os.path.join(CONFIG_DIR, "config.toml") + config_path = "./crllm_config.toml" config = None - def get_config(self, path="./crllm_config.toml"): + def set_config_path(self, path): + self.config_path = path + + def get_config(self): if self.config: return self.config - self.config = toml.load(self.configPath) + self.config = toml.load(self.default_config_path) - if not os.path.isfile(path): + if not os.path.isfile(self.config_path): return self.config - project_config = toml.load(path) - self.config = always_merger.merge(self.config, project_config) + project_config = toml.load(self.config_path) + self.override_config(project_config) logging.debug(self.config) return self.config + def override_config(self, config: dict): + if not self.config: + self.get_config() + + self.config = always_merger.merge(self.config, config) + config_service = ConfigService() diff --git a/src/crllm/test/unit/config/config_service_test.py b/src/crllm/test/unit/config/config_service_test.py new file mode 100644 index 0000000..12d5257 --- /dev/null +++ b/src/crllm/test/unit/config/config_service_test.py @@ -0,0 +1,36 @@ +import pytest +from crllm.config.config_service import ConfigService + + +@pytest.fixture() +def mock_config_file(tmpdir): + tmpdir.join("config.toml").write( + """ + [crllm] + loader = "foo" + provider = "bar" + """ + ) + + return str(tmpdir.join("config.toml")) + + +def test_get_config(mock_config_file): + config_service = ConfigService() + config_service.set_config_path(mock_config_file) + + result = config_service.get_config() + + assert result["crllm"]["loader"] == "foo" + assert result["crllm"]["provider"] == "bar" + + +def test_override(mock_config_file): + config_service = ConfigService() + config_service.set_config_path(mock_config_file) + + config_service.override_config({"crllm": {"loader": "test"}}) + + result = config_service.get_config() + + assert result["crllm"]["loader"] == "test"