Skip to content

Commit

Permalink
[Frontend] [Neuron] Parse literals out of override-neuron-config (#8959)
Browse files Browse the repository at this point in the history
Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
  • Loading branch information
xendo and Jerzy Zagorski authored Oct 3, 2024
1 parent f5d72b2 commit 63e3993
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
48 changes: 34 additions & 14 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg):
nullable_kvs(arg)


@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("{}", {}),
('{"num_crops": 4}', {
"num_crops": 4
}),
('{"foo": {"bar": "baz"}}', {
"foo": {
"bar": "baz"
}
}),
# yapf: disable
@pytest.mark.parametrize(("arg", "expected", "option"), [
(None, None, "mm-processor-kwargs"),
("{}", {}, "mm-processor-kwargs"),
(
'{"num_crops": 4}',
{
"num_crops": 4
},
"mm-processor-kwargs"
),
(
'{"foo": {"bar": "baz"}}',
{
"foo":
{
"bar": "baz"
}
},
"mm-processor-kwargs"
),
(
'{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
{
"cast_logits_dtype": "bfloat16",
"sequence_parallel_norm": True,
"sequence_parallel_norm_threshold": 2048,
},
"override-neuron-config"
),
])
def test_mm_processor_kwargs_prompt_parser(arg, expected):
# yapf: enable
def test_composite_arg_parser(arg, expected, option):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--mm-processor-kwargs", arg])
assert args.mm_processor_kwargs == expected
args = parser.parse_args([f"--{option}", arg])
assert getattr(args, option.replace("-", "_")) == expected
9 changes: 3 additions & 6 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,13 +800,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"lower performance.")
parser.add_argument(
'--override-neuron-config',
type=lambda configs: {
str(key): value
for key, value in
(config.split(':') for config in configs.split(','))
},
type=json.loads,
default=None,
help="override or set neuron device configuration.")
help="Override or set neuron device configuration. "
"e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")

parser.add_argument(
'--scheduling-policy',
Expand Down

0 comments on commit 63e3993

Please sign in to comment.