Skip to content

Commit

Permalink
feat: adds merge_job_parameter_definitions() (#32)
Browse files Browse the repository at this point in the history
When creating/running a Job, that Job is defined with a Job Template but can
also include one or more externally-defined Environment Templates. This
commit adds support for including these Environment Templates when creating a
job via create_job() and when checking job parameters via
preprocess_job_parameters(). It also adds a new function that can be used
just for checking compatibility (merge_job_parameter_definitions()).

Signed-off-by: Daniel Neilson <53624638+ddneilson@users.noreply.github.com>
  • Loading branch information
ddneilson committed Jan 17, 2024
1 parent c6d7752 commit ad944eb
Show file tree
Hide file tree
Showing 5 changed files with 652 additions and 47 deletions.
97 changes: 59 additions & 38 deletions src/openjd/model/_create_job.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from typing import TYPE_CHECKING, cast
from typing import Optional, cast

from pydantic import ValidationError

from ._errors import DecodeValidationError
from ._errors import CompatibilityError, DecodeValidationError
from ._symbol_table import SymbolTable
from ._internal import instantiate_model
from ._merge_job_parameter import merge_job_parameter_definitions
from ._types import (
EnvironmentTemplate,
Job,
JobParameterDefinition,
JobParameterInputValues,
JobParameterValues,
JobTemplate,
Expand All @@ -18,11 +21,6 @@
)
from ._convert_pydantic_error import pydantic_validationerrors_to_str, ErrorDict

if TYPE_CHECKING:
# Avoiding a circular import that occurs when trying to import FormatString
from .v2023_09 import JobTemplate as JobTemplate_2023_09


__all__ = ("preprocess_job_parameters",)


Expand All @@ -31,39 +29,36 @@
# =======================================================================


def _collect_available_parameter_names(job_template: JobTemplate) -> set[str]:
# job_template.parameterDefinitions is a list[JobParameterDefinitionList]
return (
set(param.name for param in job_template.parameterDefinitions)
if job_template.parameterDefinitions
else set()
)
def _collect_available_parameter_names(
job_parameter_definitions: list[JobParameterDefinition],
) -> set[str]:
return set(param.name for param in job_parameter_definitions)


def _collect_extra_job_parameter_names(
job_template: JobTemplate, job_parameter_values: JobParameterInputValues
job_parameter_definitions: list[JobParameterDefinition],
job_parameter_values: JobParameterInputValues,
) -> set[str]:
# Verify that job parameters are provided if the template requires them
available_parameters: set[str] = _collect_available_parameter_names(job_template)
available_parameters: set[str] = _collect_available_parameter_names(job_parameter_definitions)
return set(job_parameter_values).difference(available_parameters)


def _collect_missing_job_parameter_names(
job_template: JobTemplate, job_parameter_values: JobParameterValues
job_parameter_definitions: list[JobParameterDefinition],
job_parameter_values: JobParameterValues,
) -> set[str]:
available_parameters: set[str] = _collect_available_parameter_names(job_template)
available_parameters: set[str] = _collect_available_parameter_names(job_parameter_definitions)
return available_parameters.difference(set(job_parameter_values.keys()))


def _collect_defaults_2023_09(
job_template: "JobTemplate_2023_09", job_parameter_values: JobParameterInputValues
job_parameter_definitions: list[JobParameterDefinition],
job_parameter_values: JobParameterInputValues,
) -> JobParameterValues:
# For the type checker
assert job_template.parameterDefinitions is not None

return_value: JobParameterValues = dict[str, ParameterValue]()
# Collect defaults
for param in job_template.parameterDefinitions:
for param in job_parameter_definitions:
if param.name not in job_parameter_values:
if param.default is not None:
return_value[param.name] = ParameterValue(
Expand All @@ -80,14 +75,12 @@ def _collect_defaults_2023_09(


def _check_2023_09(
job_template: "JobTemplate_2023_09", job_parameter_values: JobParameterValues
job_parameter_definitions: list[JobParameterDefinition],
job_parameter_values: JobParameterValues,
) -> None:
# For the type checker
assert job_template.parameterDefinitions is not None

errors = list[str]()
# Check values
for param in job_template.parameterDefinitions:
for param in job_parameter_definitions:
if param.name in job_parameter_values:
param_value = job_parameter_values[param.name]
try:
Expand All @@ -96,11 +89,14 @@ def _check_2023_09(
errors.append(str(err))

if errors:
raise ValueError(", ".join(errors))
raise ValueError("\n".join(errors))


def preprocess_job_parameters(
*, job_template: JobTemplate, job_parameter_values: JobParameterInputValues
*,
job_template: JobTemplate,
job_parameter_values: JobParameterInputValues,
environment_templates: Optional[list[EnvironmentTemplate]] = None,
) -> JobParameterValues:
"""Preprocess a collection of job parameter values. Must be used prior to
instantiating a Job Template into a Job.
Expand All @@ -117,6 +113,8 @@ def preprocess_job_parameters(
job_template (JobTemplate) -- A Job Template to check the job parameter values against.
job_parameter_values (JobParameterValues) -- Mapping of Job Parameter names to values.
e.g. { "Foo": 12 } if you have a Job Parameter named "Foo"
environment_templates (Optional[list[EnvironmentTemplate]]) -- An ordered list of the
externally defined Environment Templates that are applied to the Job.
Returns:
A copy of job_parameter_values, but with added values for any missing job parameters
Expand All @@ -127,34 +125,49 @@ def preprocess_job_parameters(
"""
if job_template.version not in (SchemaVersion.v2023_09,):
raise NotImplementedError(f"Not implemented for schema version {job_template.version}")
if environment_templates and any(
env.version not in (SchemaVersion.v2023_09,) for env in environment_templates
):
raise NotImplementedError(
f"Not implemented for Environment Template schema versions other than {str(SchemaVersion.ENVIRONMENT_v2023_09)}"
)

return_value: JobParameterValues = dict[str, ParameterValue]()
errors = list[str]()

parameterDefinitions: Optional[list[JobParameterDefinition]] = None
try:
parameterDefinitions = merge_job_parameter_definitions(
job_template=job_template, environment_templates=environment_templates
)
except CompatibilityError as e:
# There's no point in continuing if the job parameter definitions are not compatible.
raise ValueError(str(e))

extra_defined_parameters = _collect_extra_job_parameter_names(
job_template, job_parameter_values
parameterDefinitions, job_parameter_values
)
if extra_defined_parameters:
extra_list = ", ".join(extra_defined_parameters)
extra_list = ", ".join(sorted(extra_defined_parameters))
errors.append(
f"Job parameter values provided for parameters that are not defined in the template: {extra_list}"
)
if job_template.parameterDefinitions:
if parameterDefinitions:
# Set of all required, but undefined, job parameter values
try:
if job_template.version == SchemaVersion.v2023_09:
return_value = _collect_defaults_2023_09(job_template, job_parameter_values)
_check_2023_09(job_template, return_value)
return_value = _collect_defaults_2023_09(parameterDefinitions, job_parameter_values)
_check_2023_09(parameterDefinitions, return_value)
else:
raise NotImplementedError(
f"Not implemented for schema version {job_template.version}"
)
except ValueError as err:
errors.append(str(err))
missing = _collect_missing_job_parameter_names(job_template, return_value)
missing = _collect_missing_job_parameter_names(parameterDefinitions, return_value)

if missing:
missing_list = ", ".join(missing)
missing_list = ", ".join(sorted(missing))
errors.append(f"Values missing for required job parameters: {missing_list}")

if errors:
Expand All @@ -168,7 +181,12 @@ def preprocess_job_parameters(
# =======================================================================


def create_job(*, job_template: JobTemplate, job_parameter_values: JobParameterValues) -> Job:
def create_job(
*,
job_template: JobTemplate,
job_parameter_values: JobParameterValues,
environment_templates: Optional[list[EnvironmentTemplate]] = None,
) -> Job:
"""This function will create a job from a given Job Template and set of values for
Job Parameters. Minimally, values must be provided for Job Parameters that do not have
default values defined in the template.
Expand All @@ -179,6 +197,8 @@ def create_job(*, job_template: JobTemplate, job_parameter_values: JobParameterV
Arguments:
job_template (JobTemplate) -- A Job Template to check the job parameter values against.
job_parameter_values (JobParameterValues) -- Mapping of Job Parameter names to values.
environment_templates (Optional[list[EnvironmentTemplate]]) -- An ordered list of the
externally defined Environment Templates that are applied to the Job.
Raises:
DecodeValidationError
Expand All @@ -195,6 +215,7 @@ def create_job(*, job_template: JobTemplate, job_parameter_values: JobParameterV
job_parameter_values={
name: param.value for name, param in job_parameter_values.items()
},
environment_templates=environment_templates,
)
except ValueError as exc:
raise DecodeValidationError(str(exc))
Expand Down
75 changes: 74 additions & 1 deletion src/openjd/model/_merge_job_parameter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from collections import defaultdict
from decimal import Decimal
from typing import Any, NamedTuple, Optional, Union, cast

from ._errors import CompatibilityError
from ._parse import parse_model
from ._types import JobParameterDefinition
from ._types import JobParameterDefinition, JobTemplate, EnvironmentTemplate, SchemaVersion
from .v2023_09 import (
JobParameterType,
JobPathParameterDefinition,
Expand Down Expand Up @@ -40,6 +41,78 @@ class SourcedFloatParameterDefinition(NamedTuple):
definition: JobFloatParameterDefinition


def merge_job_parameter_definitions(
    *,
    job_template: Optional[JobTemplate] = None,
    environment_templates: Optional[list[EnvironmentTemplate]] = None,
) -> list[JobParameterDefinition]:
    """Merge the Job Parameter definitions of Environment Templates and a Job Template.

    Both the environment templates and the job template are optional. While merging,
    this also checks that the definitions of any job parameter that is defined more
    than once are compatible with one another.

    The merge order is to first process all of the given environment templates in the
    order given, and then to process the job template last.

    Args:
        job_template (Optional[JobTemplate], optional): A Job Template whose parameter
            definitions will be merged last. Defaults to None.
        environment_templates (Optional[list[EnvironmentTemplate]], optional): A list of
            Environment Templates whose parameter definitions will be merged in the
            order given. Defaults to None.

    Raises:
        NotImplementedError: Raised if the job template or any environment template has a
            specification version that this function does not support.
        CompatibilityError: Raised if the given templates' job parameter definitions are
            not compatible. The message lists every conflicting parameter, one per line.

    Returns:
        list[JobParameterDefinition]: The result of merging the Job Parameter Definitions
            from all of the given templates; empty when no template defines a parameter.
    """
    if job_template is not None and job_template.specificationVersion not in (
        SchemaVersion.v2023_09,
    ):
        # Report the same field that was checked so the message matches the failure.
        raise NotImplementedError(
            f"Not implemented for schema version {job_template.specificationVersion}"
        )
    if environment_templates and any(
        env.specificationVersion not in (SchemaVersion.ENVIRONMENT_v2023_09,)
        for env in environment_templates
    ):
        # Use .value so that the message shows the version string rather than the
        # Enum member's name.
        raise NotImplementedError(
            "Not implemented for Environment Template schema versions other than "
            f"{SchemaVersion.ENVIRONMENT_v2023_09.value}"
        )

    # param name -> every definition of that parameter, in merge order.
    collected_definitions: defaultdict[str, list[SourcedParamDefinition]] = defaultdict(list)

    # External environments' definitions always come before the job template, so collect
    # them first.
    for env in environment_templates or []:
        for param in env.parameterDefinitions or []:
            collected_definitions[param.name].append(
                SourcedParamDefinition(
                    source=f"EnvironmentTemplate for {env.environment.name}", definition=param
                )
            )

    if job_template is not None:
        for param in job_template.parameterDefinitions or []:
            collected_definitions[param.name].append(
                SourcedParamDefinition(source="JobTemplate", definition=param)
            )

    errors: list[str] = []
    merged: list[JobParameterDefinition] = []

    # Merge every parameter before raising, so that a single error reports all conflicts.
    for name, sourced_definitions in collected_definitions.items():
        try:
            merged.append(merge_job_parameter_definitions_for_one(sourced_definitions))
        except CompatibilityError as e:
            # Indent each line of the nested error beneath the parameter's heading.
            compat_errors = str(e).replace("\n", "\n\t")
            errors.append(
                f"The definitions for job parameter '{name}' are in conflict:\n\t{compat_errors}"
            )

    if errors:
        raise CompatibilityError("\n".join(errors))
    return merged


def merge_job_parameter_definitions_for_one(
params: list[SourcedParamDefinition],
) -> JobParameterDefinition:
Expand Down
17 changes: 11 additions & 6 deletions src/openjd/model/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def decode_job_template(*, template: dict[str, Any]) -> JobTemplate:
)
except ValueError:
# Value of the schema version is not one we know.
values_allowed = ", ".join(str(s) for s in SchemaVersion.job_template_versions())
values_allowed = ", ".join(str(s.value) for s in SchemaVersion.job_template_versions())
raise DecodeValidationError(
(
f"Unknown template version: {document_version}. "
Expand All @@ -153,10 +153,10 @@ def decode_job_template(*, template: dict[str, Any]) -> JobTemplate:
)

if not SchemaVersion.is_job_template(schema_version):
values_allowed = ", ".join(str(s) for s in SchemaVersion.job_template_versions())
values_allowed = ", ".join(str(s.value) for s in SchemaVersion.job_template_versions())
raise DecodeValidationError(
(
f"Specification version '{str(schema_version)}' is not a Job Template version. "
f"Specification version '{document_version}' is not a Job Template version. "
f"Values allowed for 'specificationVersion' in Job Templates are: {values_allowed}"
)
)
Expand Down Expand Up @@ -202,15 +202,20 @@ def decode_environment_template(*, template: dict[str, Any]) -> EnvironmentTempl
)
except ValueError:
# Value of the schema version is not one we know.
values_allowed = ", ".join(str(s) for s in SchemaVersion.environment_template_versions())
values_allowed = ", ".join(
str(s.value) for s in SchemaVersion.environment_template_versions()
)
raise DecodeValidationError(
f"Unknown template version: {document_version}. Allowed values are: {values_allowed}"
)

if not SchemaVersion.is_environment_template(schema_version):
values_allowed = ", ".join(str(s) for s in SchemaVersion.environment_template_versions())
values_allowed = ", ".join(
str(s.value) for s in SchemaVersion.environment_template_versions()
)
raise DecodeValidationError(
f"Unknown template version: {document_version}. Allowed values are: {values_allowed}"
f"Specification version '{document_version}' is not an Environment Template version. "
f"Allowed values for 'specificationVersion' are: {values_allowed}"
)

if schema_version == SchemaVersion.ENVIRONMENT_v2023_09:
Expand Down
Loading

0 comments on commit ad944eb

Please sign in to comment.