Skip to content

Commit

Permalink
pass extra argument in arguments validator (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
andresliszt authored Nov 30, 2023
1 parent f323e74 commit 7fa450d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 62 deletions.
19 changes: 11 additions & 8 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use ahash::AHashSet;

use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
use crate::build_tools::{schema_or_config_same, ExtraBehavior};
use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input, ValidationMatch};
use crate::lookup_key::LookupKey;
Expand All @@ -31,6 +31,7 @@ pub struct ArgumentsValidator {
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
}

impl BuildValidator for ArgumentsValidator {
Expand Down Expand Up @@ -119,6 +120,7 @@ impl BuildValidator for ArgumentsValidator {
None => None,
},
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
.into())
}
Expand Down Expand Up @@ -307,15 +309,16 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
},
None => {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
}
}
}
}
}
}}
}
}
}};
Expand Down
114 changes: 60 additions & 54 deletions tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,57 +775,57 @@ def test_alias_populate_by_name(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


def validate(function):
"""
a demo validation decorator to test arguments
"""
parameters = signature(function).parameters

type_hints = get_type_hints(function)
mode_lookup = {
Parameter.POSITIONAL_ONLY: 'positional_only',
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
Parameter.KEYWORD_ONLY: 'keyword_only',
}

arguments_schema = []
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]

assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
if annotation in (bool, int, float, str):
arg_schema = {'type': annotation.__name__}
else:
assert annotation is Any
arg_schema = {'type': 'any'}

if p.kind in mode_lookup:
if p.default is not p.empty:
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
arguments_schema.append(s)
elif p.kind == Parameter.VAR_POSITIONAL:
schema['var_args_schema'] = arg_schema
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
schema['var_kwargs_schema'] = arg_schema

validator = SchemaValidator(schema)

@wraps(function)
def wrapper(*args, **kwargs):
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
return function(*validated_args, **validated_kwargs)

return wrapper
def validate(config=None):
def decorator(function):
parameters = signature(function).parameters
type_hints = get_type_hints(function)
mode_lookup = {
Parameter.POSITIONAL_ONLY: 'positional_only',
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
Parameter.KEYWORD_ONLY: 'keyword_only',
}

arguments_schema = []
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]

assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
if annotation in (bool, int, float, str):
arg_schema = {'type': annotation.__name__}
else:
assert annotation is Any
arg_schema = {'type': 'any'}

if p.kind in mode_lookup:
if p.default is not p.empty:
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
arguments_schema.append(s)
elif p.kind == Parameter.VAR_POSITIONAL:
schema['var_args_schema'] = arg_schema
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
schema['var_kwargs_schema'] = arg_schema

validator = SchemaValidator(schema, config=config)

@wraps(function)
def wrapper(*args, **kwargs):
# Validate arguments using the original schema
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
return function(*validated_args, **validated_kwargs)

return wrapper

return decorator


def test_function_any():
@validate
@validate()
def foobar(a, b, c):
return a, b, c

Expand All @@ -842,7 +842,7 @@ def foobar(a, b, c):


def test_function_types():
@validate
@validate()
def foobar(a: int, b: int, *, c: int):
return a, b, c

Expand Down Expand Up @@ -894,8 +894,8 @@ def test_function_positional_only(import_execute):
# language=Python
m = import_execute(
"""
def create_function(validate):
@validate
def create_function(validate, config = None):
@validate(config = config)
def foobar(a: int, b: int, /, c: int):
return a, b, c
return foobar
Expand All @@ -915,6 +915,12 @@ def foobar(a: int, b: int, /, c: int):
},
{'type': 'unexpected_keyword_argument', 'loc': ('b',), 'msg': 'Unexpected keyword argument', 'input': 2},
]
# Allowing extras using the config
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'allow'})
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)
# Ignore works similar than allow
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'ignore'})
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)


@pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher')
Expand All @@ -923,7 +929,7 @@ def test_function_positional_only_default(import_execute):
m = import_execute(
"""
def create_function(validate):
@validate
@validate()
def foobar(a: int, b: int = 42, /):
return a, b
return foobar
Expand All @@ -940,7 +946,7 @@ def test_function_positional_kwargs(import_execute):
m = import_execute(
"""
def create_function(validate):
@validate
@validate()
def foobar(a: int, b: int, /, **kwargs: bool):
return a, b, kwargs
return foobar
Expand All @@ -953,7 +959,7 @@ def foobar(a: int, b: int, /, **kwargs: bool):


def test_function_args_kwargs():
@validate
@validate()
def foobar(*args, **kwargs):
return args, kwargs

Expand Down

0 comments on commit 7fa450d

Please sign in to comment.