diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 360ac1bfbad93..f7dc167fea6e4 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg): nullable_kvs(arg) -@pytest.mark.parametrize(("arg", "expected"), [ - (None, None), - ("{}", {}), - ('{"num_crops": 4}', { - "num_crops": 4 - }), - ('{"foo": {"bar": "baz"}}', { - "foo": { - "bar": "baz" - } - }), +# yapf: disable +@pytest.mark.parametrize(("arg", "expected", "option"), [ + (None, None, "mm-processor-kwargs"), + ("{}", {}, "mm-processor-kwargs"), + ( + '{"num_crops": 4}', + { + "num_crops": 4 + }, + "mm-processor-kwargs" + ), + ( + '{"foo": {"bar": "baz"}}', + { + "foo": + { + "bar": "baz" + } + }, + "mm-processor-kwargs" + ), + ( + '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', + { + "cast_logits_dtype": "bfloat16", + "sequence_parallel_norm": True, + "sequence_parallel_norm_threshold": 2048, + }, + "override-neuron-config" + ), ]) -def test_mm_processor_kwargs_prompt_parser(arg, expected): +# yapf: enable +def test_composite_arg_parser(arg, expected, option): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) else: - args = parser.parse_args(["--mm-processor-kwargs", arg]) - assert args.mm_processor_kwargs == expected + args = parser.parse_args([f"--{option}", arg]) + assert getattr(args, option.replace("-", "_")) == expected diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d4559e377427..152470430e606 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -789,13 +789,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "lower performance.") parser.add_argument( '--override-neuron-config', - type=lambda configs: { - str(key): value - for key, value in - (config.split(':') for config in configs.split(',')) - }, + type=json.loads, default=None, - help="override or set neuron device configuration.") + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") return parser