Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass extra argument in arguments validator #1094

Merged
merged 5 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use ahash::AHashSet;

use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
use crate::build_tools::{schema_or_config_same, ExtraBehavior};
use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input, ValidationMatch};
use crate::lookup_key::LookupKey;
Expand All @@ -31,6 +31,7 @@ pub struct ArgumentsValidator {
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
}

impl BuildValidator for ArgumentsValidator {
Expand Down Expand Up @@ -119,6 +120,7 @@ impl BuildValidator for ArgumentsValidator {
None => None,
},
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
.into())
}
Expand Down Expand Up @@ -307,15 +309,16 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
},
None => {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
}
}
}
}
}
}}
}
}
}};
Expand Down
114 changes: 60 additions & 54 deletions tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,57 +775,57 @@ def test_alias_populate_by_name(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


def validate(function):
"""
a demo validation decorator to test arguments
"""
parameters = signature(function).parameters

type_hints = get_type_hints(function)
mode_lookup = {
Parameter.POSITIONAL_ONLY: 'positional_only',
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
Parameter.KEYWORD_ONLY: 'keyword_only',
}

arguments_schema = []
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]

assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
if annotation in (bool, int, float, str):
arg_schema = {'type': annotation.__name__}
else:
assert annotation is Any
arg_schema = {'type': 'any'}

if p.kind in mode_lookup:
if p.default is not p.empty:
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
arguments_schema.append(s)
elif p.kind == Parameter.VAR_POSITIONAL:
schema['var_args_schema'] = arg_schema
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
schema['var_kwargs_schema'] = arg_schema

validator = SchemaValidator(schema)

@wraps(function)
def wrapper(*args, **kwargs):
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
return function(*validated_args, **validated_kwargs)

return wrapper
def validate(config=None):
def decorator(function):
parameters = signature(function).parameters
type_hints = get_type_hints(function)
mode_lookup = {
Parameter.POSITIONAL_ONLY: 'positional_only',
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
Parameter.KEYWORD_ONLY: 'keyword_only',
}

arguments_schema = []
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]

assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
if annotation in (bool, int, float, str):
arg_schema = {'type': annotation.__name__}
else:
assert annotation is Any
arg_schema = {'type': 'any'}

if p.kind in mode_lookup:
if p.default is not p.empty:
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
arguments_schema.append(s)
elif p.kind == Parameter.VAR_POSITIONAL:
schema['var_args_schema'] = arg_schema
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
schema['var_kwargs_schema'] = arg_schema

validator = SchemaValidator(schema, config=config)

@wraps(function)
def wrapper(*args, **kwargs):
# Validate arguments using the original schema
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
return function(*validated_args, **validated_kwargs)

return wrapper

return decorator


def test_function_any():
@validate
@validate()
def foobar(a, b, c):
return a, b, c

Expand All @@ -842,7 +842,7 @@ def foobar(a, b, c):


def test_function_types():
@validate
@validate()
def foobar(a: int, b: int, *, c: int):
return a, b, c

Expand Down Expand Up @@ -894,8 +894,8 @@ def test_function_positional_only(import_execute):
# language=Python
m = import_execute(
"""
def create_function(validate):
@validate
def create_function(validate, config = None):
@validate(config = config)
def foobar(a: int, b: int, /, c: int):
return a, b, c
return foobar
Expand All @@ -915,6 +915,12 @@ def foobar(a: int, b: int, /, c: int):
},
{'type': 'unexpected_keyword_argument', 'loc': ('b',), 'msg': 'Unexpected keyword argument', 'input': 2},
]
# Allowing extras using the config
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'allow'})
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)
# Ignore works similar than allow
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'ignore'})
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)


@pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher')
Expand All @@ -923,7 +929,7 @@ def test_function_positional_only_default(import_execute):
m = import_execute(
"""
def create_function(validate):
@validate
@validate()
def foobar(a: int, b: int = 42, /):
return a, b
return foobar
Expand All @@ -940,7 +946,7 @@ def test_function_positional_kwargs(import_execute):
m = import_execute(
"""
def create_function(validate):
@validate
@validate()
def foobar(a: int, b: int, /, **kwargs: bool):
return a, b, kwargs
return foobar
Expand All @@ -953,7 +959,7 @@ def foobar(a: int, b: int, /, **kwargs: bool):


def test_function_args_kwargs():
@validate
@validate()
def foobar(*args, **kwargs):
return args, kwargs

Expand Down