Skip to content

Commit

Permalink
Refactor the CLI checking of config arguments (#667)
Browse files Browse the repository at this point in the history
This change moves the validation of the config file earlier in the setup
of the arg parser. The argparse library itself is validating whether the
file exists, so seems appropriate to also validate whether the file is a
proper toml config file early as well.

Signed-off-by: Eric Brown <eric.brown@securesauce.dev>
  • Loading branch information
ericwb authored Oct 29, 2024
1 parent a188b74 commit c5d6d46
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions precli/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,35 +160,34 @@ def setup_arg_parser():
)
args = parser.parse_args()

if args.config:
try:
args.config = tomllib.load(args.config)
except tomllib.TOMLDecodeError as err:
parser.error(
f"argument -c/--config: can't load '{args.config.name}': {err}"
)

if not args.targets:
parser.print_usage()
sys.exit(2)

return parser, args
return args, args.config


def load_config(
parser: ArgumentParser, config: dict, targets: list[str]
) -> dict:
if config:
try:
return tomllib.load(config)
except tomllib.TOMLDecodeError as err:
parser.error(
f"argument -c/--config: can't load '{config.name}': {err}"
)
else:
default_confs = (".precli.toml", "precli.toml", "pyproject.toml")
for target in filter(os.path.isdir, targets):
for conf in default_confs:
path = pathlib.Path(target) / conf
try:
if path.exists():
with open(path, "rb") as f:
return tomllib.load(f)
except tomllib.TOMLDecodeError:
# TODO: Log but don't exit
pass
def find_config(targets: list[str]) -> dict:
default_confs = (".precli.toml", "precli.toml", "pyproject.toml")

for target in filter(os.path.isdir, targets):
for conf in default_confs:
path = pathlib.Path(target) / conf
try:
if path.exists():
with open(path, "rb") as f:
return tomllib.load(f)
except tomllib.TOMLDecodeError:
# TODO: Log but don't exit
pass

return {}

Expand Down Expand Up @@ -377,10 +376,11 @@ def main():
logging.getLogger("urllib3").setLevel(debug)

# Setup the command line arguments
parser, args = setup_arg_parser()
args, config = setup_arg_parser()

# Load optional configuration file
config = load_config(parser, args.config, args.targets)
# Attempt to find config files if one not provided
if not config:
config = find_config(args.targets)

# CLI enabled/disabled override any config in files
config["enabled"] = (
Expand Down

0 comments on commit c5d6d46

Please sign in to comment.