-
Notifications
You must be signed in to change notification settings - Fork 131
/
typeconverter.py
637 lines (511 loc) · 20.7 KB
/
typeconverter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
# Copyright (c) 2009-2024 The Regents of the University of Michigan.
# Part of HOOMD-blue, released under the BSD 3-Clause License.
"""Implement type conversion helpers."""
import numpy as np
from abc import ABC, abstractmethod
from collections.abc import Mapping, MutableMapping
from inspect import isclass
from hoomd.error import TypeConversionError
from hoomd.util import _is_iterable
from hoomd.variant import Variant, Constant
from hoomd.trigger import Trigger, Periodic
from hoomd.filter import ParticleFilter, CustomFilter
import hoomd
class RequiredArg:
    """Sentinel marking a parameter that must be supplied by the user.

    The class object itself is used as the marker value. Validators in this
    module return it unchanged instead of attempting to validate it.
    """
def trigger_preprocessing(trigger):
    """Process triggers.

    Convert integers to periodic triggers.

    Args:
        trigger: A `hoomd.trigger.Trigger`, or any value accepted by ``int``.

    Returns:
        hoomd.trigger.Trigger: ``trigger`` itself when it is already a
        `Trigger`, otherwise ``Periodic(period=int(trigger), phase=0)``.
        Note that ``int`` truncates non-integral numbers (``2.5`` -> ``2``).

    Raises:
        ValueError: If ``trigger`` is not a `Trigger` and constructing the
            `Periodic` trigger fails. The underlying error is chained as the
            cause so the real reason is not lost.
    """
    if isinstance(trigger, Trigger):
        return trigger
    try:
        return Periodic(period=int(trigger), phase=0)
    except Exception as err:
        raise ValueError(
            "Expected a hoomd.trigger.trigger_like object.") from err
def variant_preprocessing(variant):
    """Process variants.

    Convert floats to constant variants.

    Args:
        variant: A `hoomd.variant.Variant`, or any value accepted by
            ``float``.

    Returns:
        hoomd.variant.Variant: ``variant`` itself when it is already a
        `Variant`, otherwise ``Constant(float(variant))``.

    Raises:
        ValueError: If ``variant`` is not a `Variant` and constructing the
            `Constant` variant fails. The underlying error is chained as the
            cause.
    """
    if isinstance(variant, Variant):
        return variant
    try:
        return Constant(float(variant))
    except Exception as err:
        raise ValueError(
            "Expected a hoomd.variant.variant_like object.") from err
def box_preprocessing(box):
    """Process boxes.

    Convert values that `Box.from_box` handles.

    Args:
        box: A `hoomd.Box`, or any object accepted by `hoomd.Box.from_box`.

    Returns:
        hoomd.Box: ``box`` itself when it is already a `hoomd.Box`, otherwise
        the result of ``hoomd.Box.from_box(box)``.

    Raises:
        ValueError: If ``hoomd.Box.from_box`` fails. The underlying error is
            chained as the cause.
    """
    if isinstance(box, hoomd.Box):
        return box
    try:
        return hoomd.Box.from_box(box)
    except Exception as err:
        raise ValueError(f"{box} is not convertible into a hoomd.Box object "
                         "using hoomd.Box.from_box.") from err
def box_variant_preprocessing(input):
    """Process box variants.

    Convert boxes and length-6 array-like objects to
    `hoomd.variant.box.Constant`.

    Args:
        input: A `hoomd.variant.box.BoxVariant`, or anything
            `box_preprocessing` accepts. (The parameter name shadows the
            ``input`` builtin; it is kept for keyword-argument compatibility.)

    Returns:
        hoomd.variant.box.BoxVariant: ``input`` itself when it is already a
        `BoxVariant`, otherwise a `hoomd.variant.box.Constant` wrapping the
        processed box.

    Raises:
        ValueError: If the value cannot be turned into a box. The underlying
            error is chained as the cause.
    """
    if isinstance(input, hoomd.variant.box.BoxVariant):
        return input
    try:
        return hoomd.variant.box.Constant(box_preprocessing(input))
    except Exception as err:
        raise ValueError(f"{input} is not convertible into a "
                         "hoomd.variant.box.BoxVariant object.") from err
def positive_real(number):
    """Coerce ``number`` to ``float`` and require it to exceed zero.

    Returns:
        float: The converted value.

    Raises:
        TypeConversionError: If ``float(number)`` fails, or the converted
            value is less than or equal to zero.
    """
    try:
        result = float(number)
    except Exception as err:
        raise TypeConversionError(
            f"{number} not convertible to float.") from err
    else:
        # NOTE(review): float("nan") <= 0 is False, so NaN passes this check.
        if result <= 0:
            raise TypeConversionError("Expected a number greater than zero.")
        return result
def nonnegative_real(number):
    """Coerce ``number`` to ``float`` and require it to be zero or more.

    Returns:
        float: The converted value.

    Raises:
        TypeConversionError: If ``float(number)`` fails, or the converted
            value is below zero.
    """
    try:
        result = float(number)
    except Exception as err:
        raise TypeConversionError(
            f"{number} not convertible to float.") from err
    else:
        # NOTE(review): NaN compares False against 0 and is accepted here.
        if result < 0:
            raise TypeConversionError("Expected a nonnegative real number.")
        return result
def identity(value):
    """Return ``value`` unchanged.

    Used as the default (no-op) pre/post-processing hook of validators.
    """
    result = value
    return result
class _HelpValidate(ABC):
    """Base class for classes that perform validation on an input value.

    Supports arbitrary pre- and post-processing as well as optionally
    allowing ``None`` values. Subclasses implement `_validate`, which should
    raise a `ValueError` or `TypeConversionError` if validation fails and
    otherwise return the validated/transformed value.

    Calling an instance evaluates
    ``postprocess(_validate(preprocess(value)))`` with these special cases:

    * `RequiredArg` is returned unchanged.
    * ``None`` is returned unchanged when ``allow_none`` is true, otherwise a
      `ValueError` is raised.
    * A `TypeConversionError` raised during processing propagates as-is; any
      other exception is re-raised as a `TypeConversionError` chained to the
      original error.

    Args:
        preprocess (callable, optional): Applied to the value before
            validation. Defaults to ``None`` (no preprocessing).
        postprocess (callable, optional): Applied to the validated value.
            Defaults to ``None`` (no postprocessing).
        allow_none (bool, optional): Whether ``None`` is accepted. Defaults
            to ``False``.
    """

    def __init__(self, preprocess=None, postprocess=None, allow_none=False):
        self._preprocess = identity if preprocess is None else preprocess
        self._postprocess = identity if postprocess is None else postprocess
        self._allow_none = allow_none

    def __call__(self, value):
        """Validate ``value`` and return the converted result."""
        if value is RequiredArg:
            return value
        if value is None:
            if not self._allow_none:
                raise ValueError("None is not allowed.")
            return None
        try:
            return self._postprocess(self._validate(self._preprocess(value)))
        except TypeConversionError:
            # Already a conversion error with a specific message; keep it.
            raise
        except Exception as err:
            raise TypeConversionError(
                f"Error raised in conversion: {err}") from err

    @abstractmethod
    def _validate(self, value):
        """Validate ``value`` and return it (possibly transformed)."""
class Any(_HelpValidate):
    """Accept any input.

    Every value other than ``None`` is passed through unchanged, apart from
    the optional pre/post-processing. ``None`` is rejected unless
    ``allow_none`` is true, matching the other validators in this module.

    Args:
        preprocess (callable, optional): Applied before validation.
        postprocess (callable, optional): Applied after validation.
        allow_none (bool, optional): Whether ``None`` is accepted. Defaults
            to ``False``, preserving the historical behavior.
    """

    def __init__(self, preprocess=None, postprocess=None, allow_none=False):
        super().__init__(preprocess, postprocess, allow_none)

    def _validate(self, value):
        return value

    def __str__(self):
        """str: String representation of the validator."""
        return "Any()"
class Either(_HelpValidate):
    """Class that has multiple equally valid validation methods for an input.

    The specs are tried in order, and the result of the first one that does
    not raise is returned. If every spec raises, a `ValueError` listing the
    specs is raised.

    For instance, a parameter that may be either a length-6 tuple of floats
    or a single float can be validated with::

        e = Either([to_type_converter((float,) * 6), to_type_converter(float)])

    Note that the specs are passed as a single sequence.

    Args:
        specs (Sequence[callable]): Validators to try, in order.
        preprocess (callable, optional): Applied before validation.
        postprocess (callable, optional): Applied after validation.
        allow_none (bool, optional): Whether ``None`` is accepted. Defaults
            to ``False``.
    """

    def __init__(self,
                 specs,
                 preprocess=None,
                 postprocess=None,
                 allow_none=False):
        super().__init__(preprocess, postprocess, allow_none)
        self.specs = specs

    def _validate(self, value):
        for spec in self.specs:
            try:
                return spec(value)
            except Exception:
                # This spec rejected the value; fall through to the next one.
                continue
        raise ValueError(f"value {value} not convertible using "
                         f"{[str(spec) for spec in self.specs]}")

    def __str__(self):
        """str: String representation of the validator."""
        return f"Either({[str(spec) for spec in self.specs]})"
class OnlyIf(_HelpValidate):
    """Validate a value by delegating to an arbitrary callable.

    ``cond`` receives the (preprocessed) value and its return value becomes
    the validated value; it signals failure by raising. The class exists to
    match the interface of the other validators, providing
    pre/post-processing and optional acceptance of ``None``.
    """

    def __init__(self,
                 cond,
                 preprocess=None,
                 postprocess=None,
                 allow_none=False):
        super().__init__(preprocess=preprocess,
                         postprocess=postprocess,
                         allow_none=allow_none)
        self.cond = cond

    def _validate(self, value):
        check = self.cond
        return check(value)

    def __str__(self):
        """str: String representation of the validator."""
        return "OnlyIf(" + str(self.cond) + ")"
class OnlyTypes(_HelpValidate):
    """Only allow values that are instances of type.

    Developers should consider the `collections.abc` module in using this type.
    In general `OnlyTypes(Sequence)` is more readable than the similar
    `OnlyIf(lambda x: hasattr(x, '__iter__'))`. If a sequence of types is
    provided and ``strict`` is ``False``, conversions will be attempted in the
    order of the ``types`` sequence.

    Args:
        *types (type): The accepted types.
        disallow_types (type or tuple[type, ...], optional): Types that are
            always rejected. They are checked first, so they take precedence
            even over an accepted base class. Defaults to no disallowed types.
        strict (bool, optional): When ``True``, values that are not already
            instances of ``types`` are rejected. When ``False``, each type is
            called on the value in order and the first successful result is
            returned. Defaults to ``False``.
        preprocess (callable, optional): Applied before validation.
        postprocess (callable, optional): Applied after validation.
        allow_none (bool, optional): Whether ``None`` is accepted. Defaults
            to ``False``.
    """

    def __init__(self,
                 *types,
                 disallow_types=None,
                 strict=False,
                 preprocess=None,
                 postprocess=None,
                 allow_none=False):
        super().__init__(preprocess, postprocess, allow_none)
        # ``types`` is always a tuple because it is collected from *types.
        self.types = types
        if disallow_types is None:
            self.disallow_types = ()
        else:
            self.disallow_types = disallow_types
        self.strict = strict

    def _validate(self, value):
        if isinstance(value, self.disallow_types):
            raise TypeConversionError(
                f"Value {value} cannot be of type {type(value)}")
        if isinstance(value, self.types):
            return value
        if self.strict:
            raise ValueError(
                f"Value {value} is not an instance of any of {self.types}.")
        for type_ in self.types:
            try:
                return type_(value)
            except Exception:
                # Conversion to this type failed; try the next one.
                pass
        raise ValueError(
            f"Value {value} is not convertible into any of these types "
            f"{self.types}")

    def __str__(self):
        """str: String representation of the validator."""
        return f"OnlyTypes({str(self.types)})"
class OnlyFrom(_HelpValidate):
    """Validates a value against a given set of options.

    For example, ``OnlyFrom(range(10))`` allows the integers 0 through 9. Any
    iterable of hashable options works, including generator expressions,
    because the options are collected into a `set` at construction.

    Args:
        options (Iterable): The allowed values (must be hashable).
        preprocess (callable, optional): Applied before validation.
        postprocess (callable, optional): Applied after validation.
        allow_none (bool, optional): Whether ``None`` is accepted. Defaults
            to ``False``.
    """

    def __init__(self,
                 options,
                 preprocess=None,
                 postprocess=None,
                 allow_none=False):
        super().__init__(preprocess, postprocess, allow_none)
        self.options = set(options)

    def _validate(self, value):
        if value in self:
            return value
        raise ValueError(f"Value {value} not in options: {self.options}")

    def __contains__(self, value):
        """bool: True when value is in the options."""
        return value in self.options

    def __str__(self):
        """str: String representation of the validator."""
        # Previously missing the f-prefix, which printed "{self.options}"
        # literally; now formatted like the sibling validators.
        return f"OnlyFrom({self.options})"
class SetOnce:
    """Validate the first assigned value, then make the attribute read-only.

    Args:
        validation: Either a class, which is wrapped in ``OnlyTypes``, or a
            callable used directly to validate the first value.

    A validation failure on the first call leaves the validator in place, so
    a later call may still succeed. After one successful call, every further
    call raises ``ValueError``.
    """

    def __init__(self, validation):
        self._validation = (OnlyTypes(validation)
                            if isclass(validation) else validation)

    def __call__(self, value):
        """Handle setting values."""
        validator = self._validation
        if validator is None:
            raise ValueError("Attribute is read-only.")
        # Validate before discarding the validator so a failed attempt can
        # be retried.
        result = validator(value)
        self._validation = None
        return result
class TypeConverter(ABC):
    """Abstract base for converters that encode structure and validation.

    Each subclass validates a different kind of data structure. Calling an
    instance validates and transforms the input according to the
    specification given at construction. `RequiredArg` is returned untouched.

    Note:
        Subclasses should not be instantiated directly. Instead use
        `to_type_converter`.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Store the specification; concrete subclasses must override."""

    def __call__(self, value):
        """Called when values are set."""
        return value if value is RequiredArg else self._validate(value)

    @abstractmethod
    def _validate(self, value):
        """Validate ``value`` against the stored specification."""
class NDArrayValidator(_HelpValidate):
    """Validates array and array-like structures.

    Args:
        dtype (numpy.dtype): The type of individual items in the array.
        shape (`tuple` [`int`, ...], optional): The shape of the array. The
            number of dimensions is specified by the length of the tuple and
            the length of a dimension is specified by the value. A value of
            ``None`` in an index indicates any length is acceptable. Defaults
            to ``(None,)``.
        order (`str`, optional): The kind of ordering needed for the array.
            Options are ``["C", "F", "K", "A"]``. See `numpy.array`
            documentation for information about the orderings. Defaults to
            ``"K"``.
        preprocess (callable, optional): An optional function like argument to
            use to preprocess arrays before general validation. Defaults to
            ``None`` which means no preprocessing.
        postprocess (callable, optional): An optional function like argument
            to use to postprocess arrays after general validation. Defaults to
            ``None`` which means no postprocessing.
        allow_none (`bool`, optional): Whether to allow ``None`` as a valid
            value. Defaults to ``False``.

    The validation will attempt to convert array-like objects to arrays. We
    will change the dtype and ordering if necessary, but do not reshape the
    given arrays since this is non-trivial depending on the shape
    specification passed in.
    """

    def __init__(self,
                 dtype,
                 shape=(None,),
                 order="K",
                 preprocess=None,
                 postprocess=None,
                 allow_none=False):
        """Create a NDArrayValidator object."""
        super().__init__(preprocess, postprocess, allow_none)
        self._dtype = dtype
        self._shape = shape
        self._order = order

    def _validate(self, arr):
        """Validate an array or array-like object.

        Returns:
            numpy.ndarray: ``np.array(arr, dtype=dtype, order=order)``.

        Raises:
            ValueError: If the number of dimensions differs from ``shape``,
                or a dimension fixed in ``shape`` has a different size.
        """
        typed_and_ordered = np.array(arr, dtype=self._dtype, order=self._order)
        actual_shape = typed_and_ordered.shape
        if len(actual_shape) != len(self._shape):
            raise ValueError(
                f"Expected array of {len(self._shape)} dimensions, but "
                f"received array of {len(actual_shape)} dimensions.")
        for i, (expected, actual) in enumerate(zip(self._shape, actual_shape)):
            # ``None`` in the shape specification means "any length".
            if expected is not None and actual != expected:
                raise ValueError(
                    f"In dimension {i}, expected size {expected}, but got "
                    f"size {actual}")
        return typed_and_ordered
class _BaseConverter:
    """Get the base level (i.e. no deeper level exists) validator."""

    # Maps a "special" class to the shared validator used for it. Lookups
    # scan this dict in insertion order and the first matching class wins.
    _conversion_func_dict = {
        Variant:
            OnlyTypes(Variant, preprocess=variant_preprocessing),
        ParticleFilter:
            OnlyTypes(ParticleFilter, CustomFilter, strict=True),
        str:
            OnlyTypes(str, strict=True),
        Trigger:
            OnlyTypes(Trigger, preprocess=trigger_preprocessing),
        hoomd.Box:
            OnlyTypes(hoomd.Box, preprocess=box_preprocessing),
        hoomd.variant.box.BoxVariant:
            OnlyTypes(hoomd.variant.box.BoxVariant,
                      preprocess=box_variant_preprocessing),
        # arrays default to float of one dimension of arbitrary length and
        # ordering
        np.ndarray:
            NDArrayValidator(float),
    }

    @classmethod
    def to_base_converter(cls, schema):
        """Return the validator for a leaf (non-container) schema.

        * A class: the validator of the first registered class it subclasses,
          else ``OnlyTypes(schema)``.
        * An instance of a registered class: that class's validator.
        * Any other callable: ``schema`` itself, used as the validator.
        * Anything else: ``OnlyTypes(type(schema))``.
        """
        registry = cls._conversion_func_dict
        if isclass(schema):
            for special_class, converter in registry.items():
                if issubclass(schema, special_class):
                    return converter
            return OnlyTypes(schema)
        for special_class, converter in registry.items():
            if isinstance(schema, special_class):
                return converter
        return schema if callable(schema) else OnlyTypes(type(schema))
class TypeConverterSequence(TypeConverter):
    """Validation for a generic any length sequence.

    Uses `to_type_converter` to build a single element converter, which is
    applied to every item of the inputted sequence.

    Args:
        converter: Any specification accepted by `to_type_converter` (e.g. a
            type, a callable validator, or a nested structure).

    Specification:
        When validating, the element converter is applied to every element of
        the inputted sequence, whatever its length. This class is therefore
        unsuited for fixed-length sequences (`TypeConverterFixedLengthSequence`
        exists for those).

    Example::

        # All elements should be floats
        TypeConverterSequence(float)
    """

    def __init__(self, converter):
        self.converter = to_type_converter(converter)

    def _validate(self, sequence):
        """Called when the value is set.

        Returns:
            list: The converted items, in input order.

        Raises:
            TypeConversionError: If ``sequence`` is not iterable, or
                converting an item raises `ValueError`/`TypeError` (the
                message names the failing item's index).
        """
        if not _is_iterable(sequence):
            raise TypeConversionError(
                f"Expected a sequence like instance. Received {sequence} of "
                f"type {type(sequence)}.")
        new_sequence = []
        # Pre-bind the index so the error handler cannot hit an unbound name
        # if iterating ``sequence`` itself fails before the first item.
        index = 0
        try:
            for index, item in enumerate(sequence):
                new_sequence.append(self.converter(item))
        except (ValueError, TypeError) as err:
            raise TypeConversionError(
                f"In list item number {index}: {err}") from err
        return new_sequence
class TypeConverterFixedLengthSequence(TypeConverter):
    """Validation for a fixed length sequence (read tuple).

    Uses `to_type_converter` to build one converter per item of the
    specification sequence.

    Parameters:
        sequence (Sequence[Any]): Any sequence or iterable, anything else
            passed is an error.

    Specification:
        When validating, a sequence of the exact length given on instantiation
        is expected, else an error is raised. Item ``i`` of the input is
        validated by converter ``i``. Iterables without a length (such as
        generators) are first collected into a tuple.

    Example::

        # Three floats
        TypeConverterFixedLengthSequence((float, float, float))

        # a string followed by a float and an int
        TypeConverterFixedLengthSequence((str, float, int))
    """

    def __init__(self, sequence):
        self.converter = tuple([to_type_converter(item) for item in sequence])

    def _validate(self, sequence):
        """Called when the value is set.

        Returns:
            tuple: The converted items, in input order.

        Raises:
            TypeConversionError: If ``sequence`` is not iterable, has the
                wrong length, or converting an item raises
                `ValueError`/`TypeError`.
        """
        if not _is_iterable(sequence):
            raise TypeConversionError(
                f"Expected a tuple like object. Received {sequence} of type "
                f"{type(sequence)}.")
        if not hasattr(sequence, "__len__"):
            # Length-less iterables (e.g. generators) would make ``len``
            # raise a bare TypeError; materialize them so they can be checked.
            sequence = tuple(sequence)
        if len(sequence) != len(self.converter):
            raise TypeConversionError(
                f"Expected exactly {len(self.converter)} items. Received "
                f"{len(sequence)}.")
        new_sequence = []
        # Pre-bind the index so the error handler never sees an unbound name.
        index = 0
        try:
            for index, (item, converter) in enumerate(zip(sequence, self)):
                new_sequence.append(converter(item))
        except (ValueError, TypeError) as err:
            raise TypeConversionError(
                f"In tuple item number {index}: {err}") from err
        return tuple(new_sequence)

    def __iter__(self):
        """Iterate over converters in the sequence."""
        yield from self.converter

    def __getitem__(self, index):
        """Return the index-th converter."""
        return self.converter[index]
class TypeConverterMapping(TypeConverter, MutableMapping):
    """Validation for a mapping of string keys to any type values.

    Uses `to_type_converter` to build one converter per value of the
    specification mapping.

    Parameters:
        mapping (Mapping[str, Any]): Any mapping, anything else passed is an
            error.

    Specification:
        When validating, the input may use any subset of the specified keys.
        No error is raised if not all keys are present. Keys that have no
        converter are passed through unchanged. The validation either raises
        or returns a mapping with exactly the same keys as the inputted
        mapping.

    Example::

        t = TypeConverterMapping({'str': str, 'list_of_floats': [float]})

        # valid
        t({'str': 'hello'})

        # also valid: keys without a converter are passed through unchanged
        t({'new_key': None})

        # invalid: the value for 'str' must be a str
        t({'str': 1})
    """

    def __init__(self, mapping):
        self.converter = {
            key: to_type_converter(value) for key, value in mapping.items()
        }

    def _validate(self, mapping):
        """Called when the value is set.

        Returns:
            dict: A new dict with the same keys as ``mapping``. Values whose
            key has a converter are converted; all others are copied as-is.

        Raises:
            TypeConversionError: If ``mapping`` is not a `Mapping`, or a
                converter raises `ValueError`/`TypeError`.
        """
        if not isinstance(mapping, Mapping):
            raise TypeConversionError(
                f"Expected a dict like value. Received {mapping} of type "
                f"{type(mapping)}.")
        new_mapping = {}
        for key, value in mapping.items():
            if key in self:
                try:
                    new_mapping[key] = self.converter[key](value)
                except (ValueError, TypeError) as err:
                    raise TypeConversionError(
                        f"In key {key}: {err}") from err
            else:
                new_mapping[key] = value
        return new_mapping

    def __iter__(self):
        """Iterate over converters in the mapping."""
        yield from self.converter

    def __getitem__(self, key):
        """Get a converter by key."""
        return self.converter[key]

    def __setitem__(self, key, value):
        """Set a converter by key."""
        self.converter[key] = value

    def __delitem__(self, key):
        """Remove a converter by key."""
        del self.converter[key]

    def __len__(self):
        """int: Number of converters."""
        return len(self.converter)
def to_type_converter(value):
    """The function to use for creating a structure of `TypeConverter` objects.

    This is the function to use when defining validation, rather than
    instantiating any of the `TypeConverter` subclasses directly.

    Dispatch rules, checked in this order:

    * ``tuple``: `TypeConverterFixedLengthSequence`, one converter per item.
    * `Mapping`: `TypeConverterMapping`, one converter per value.
    * Other iterables (as judged by ``_is_iterable``): `TypeConverterSequence`
      built from the first element, or from ``Any()`` when empty.
    * Anything else: ``_BaseConverter.to_base_converter(value)``.

    Example::

        # 'list' takes a list of tuples of 3 floats each
        validation = to_type_converter(
            {'str': str, 'list': [(float, float, float)]})
    """
    if isinstance(value, tuple):
        return TypeConverterFixedLengthSequence(value)
    # Test for mappings before generic iterables. If ``_is_iterable`` accepts
    # a mapping (e.g. a non-dict Mapping), ``value[0]`` below would look up
    # the key 0 instead of building a mapping converter.
    if isinstance(value, Mapping):
        return TypeConverterMapping(value)
    if _is_iterable(value):
        if len(value) == 0:
            return TypeConverterSequence(Any())
        return TypeConverterSequence(value[0])
    return _BaseConverter.to_base_converter(value)