diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index c33bae44bd3..f220f6dfa67 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -24,6 +24,7 @@ ) import xarray as xr +from pydantic import ValidationError as PydanticValidationError from typing_extensions import Self from ert.config.gen_data_config import GenDataConfig @@ -105,6 +106,7 @@ class ErtConfig: model_config: ModelConfig = field(default_factory=ModelConfig) user_config_file: str = "no_config" config_path: str = field(init=False) + obs_config_file: Optional[str] = None def __eq__(self, other: object) -> bool: if not isinstance(other, ErtConfig): @@ -118,7 +120,7 @@ def __post_init__(self) -> None: if self.user_config_file else os.getcwd() ) - self.enkf_obs: EnkfObs = self._create_observations() + self.enkf_obs: EnkfObs = self._create_observations(self.obs_config_file) if len(self.summary_keys) != 0: self.ensemble_config.addNode(self._create_summary_config()) @@ -197,6 +199,11 @@ def from_dict(cls, config_dict) -> Self: substitution_list[""] = eclbase except ConfigValidationError as e: errors.append(e) + except PydanticValidationError as err: + # pydantic catches ValueError (which ConfigValidationError inherits from), + # so we need to unpack them again. + for e in err.errors(): + errors.append(e["ctx"]["error"]) try: workflow_jobs, workflows, hooked_workflows = cls._workflows_from_dict( @@ -261,6 +268,7 @@ def from_dict(cls, config_dict) -> Self: ), model_config=model_config, user_config_file=config_file_path, + obs_config_file=config_dict.get(ConfigKeys.OBS_CONFIG), ) @classmethod @@ -884,9 +892,8 @@ def _create_summary_config(self) -> SummaryConfig: refcase=time_map, ) - def _create_observations(self) -> EnkfObs: + def _create_observations(self, obs_config_file: str) -> EnkfObs: obs_vectors: Dict[str, ObsVector] = {} - obs_config_file = self.model_config.obs_config_file obs_time_list: Sequence[datetime] = [] if self.ensemble_config.refcase is not None: obs_time_list = self.ensemble_config.refcase.all_dates diff --git a/src/ert/config/model_config.py b/src/ert/config/model_config.py index 006fb6722de..2da3b805e4e 100644 --- a/src/ert/config/model_config.py +++ b/src/ert/config/model_config.py @@ -3,12 +3,12 @@ import logging import os.path from datetime import datetime -from typing import TYPE_CHECKING, Optional, no_type_check +from typing import List, Optional, no_type_check -from .parsing import ConfigDict, ConfigKeys, ConfigValidationError, HistorySource +from pydantic import field_validator +from pydantic.dataclasses import dataclass -if TYPE_CHECKING: - from typing import List +from .parsing import ConfigDict, ConfigKeys, ConfigValidationError, HistorySource logger = logging.getLogger(__name__) @@ -38,76 +38,75 @@ def str_to_datetime(date_str: str) -> datetime: DEFAULT_ECLBASE_FORMAT = "ECLBASE" +@dataclass class ModelConfig: - def __init__( - self, - num_realizations: int = 1, - history_source: HistorySource = DEFAULT_HISTORY_SOURCE, - runpath_format_string: str = DEFAULT_RUNPATH, - jobname_format_string: str = DEFAULT_JOBNAME_FORMAT, - eclbase_format_string: str = DEFAULT_ECLBASE_FORMAT, - gen_kw_export_name: str = DEFAULT_GEN_KW_EXPORT_NAME, - obs_config_file: Optional[str] = None, - time_map_file: Optional[str] = None, - ): - self.num_realizations = num_realizations - self.history_source = history_source - self.jobname_format_string = _replace_runpath_format(jobname_format_string) - self.eclbase_format_string = _replace_runpath_format(eclbase_format_string) - - # do not combine styles + num_realizations: int = 1 + history_source: HistorySource = DEFAULT_HISTORY_SOURCE + runpath_format_string: str = DEFAULT_RUNPATH + jobname_format_string: str = DEFAULT_JOBNAME_FORMAT + eclbase_format_string: str = DEFAULT_ECLBASE_FORMAT + gen_kw_export_name: str = DEFAULT_GEN_KW_EXPORT_NAME + time_map: Optional[List[datetime]] = None + + @field_validator("runpath_format_string", mode="before") + @classmethod + def validate_runpath(cls, runpath_format_string: str) -> str: if "%d" in runpath_format_string and any( x in runpath_format_string for x in ["", ""] ): raise ConfigValidationError( f"RUNPATH cannot combine deprecated and new style placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`" ) - # do not allow multiple occurrences for kw in ["", ""]: if runpath_format_string.count(kw) > 1: raise ConfigValidationError( f"RUNPATH cannot contain multiple {kw} placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`" ) - # do not allow too many placeholders if runpath_format_string.count("%d") > 2: raise ConfigValidationError( f"RUNPATH cannot contain more than two value placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`" ) - - if "/" in self.jobname_format_string: - raise ConfigValidationError.with_context( - "JOBNAME cannot contain '/'.", jobname_format_string - ) - - self.runpath_format_string = _replace_runpath_format(runpath_format_string) - - if not any(x in self.runpath_format_string for x in ["", ""]): + result = _replace_runpath_format(runpath_format_string) + if not any(x in result for x in ["", ""]): logger.warning( "RUNPATH keyword contains no value placeholders: " f"`{runpath_format_string}`. Valid example: " f"`{DEFAULT_RUNPATH}` " ) + return result + + @field_validator("jobname_format_string", mode="before") + @classmethod + def validate_jobname(cls, jobname_format_string: str) -> str: + result = _replace_runpath_format(jobname_format_string) + if "/" in jobname_format_string: + raise ConfigValidationError.with_context( + "JOBNAME cannot contain '/'.", jobname_format_string + ) + return result + + @field_validator("eclbase_format_string", mode="before") + @classmethod + def transform(cls, eclbase_format_string: str) -> str: + return _replace_runpath_format(eclbase_format_string) - self.gen_kw_export_name = gen_kw_export_name - self.obs_config_file = obs_config_file - self.time_map = None - self._time_map_file = ( + @no_type_check + @classmethod + def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig": + time_map_file = config_dict.get(ConfigKeys.TIME_MAP) + time_map_file = ( os.path.abspath(time_map_file) if time_map_file is not None else None ) - + time_map = None if time_map_file is not None: try: - self.time_map = _read_time_map(time_map_file) + time_map = _read_time_map(time_map_file) except (ValueError, IOError) as err: raise ConfigValidationError.with_context( f"Could not read timemap file {time_map_file}: {err}", time_map_file ) from err - - @no_type_check - @classmethod - def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig": return cls( num_realizations=config_dict.get(ConfigKeys.NUM_REALIZATIONS, 1), history_source=config_dict.get( @@ -127,42 +126,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig": gen_kw_export_name=config_dict.get( ConfigKeys.GEN_KW_EXPORT_NAME, DEFAULT_GEN_KW_EXPORT_NAME ), - obs_config_file=config_dict.get(ConfigKeys.OBS_CONFIG), - time_map_file=config_dict.get(ConfigKeys.TIME_MAP), - ) - - def __repr__(self) -> str: - return ( - "ModelConfig(" - f"num_realizations={self.num_realizations}, " - f"history_source={self.history_source}, " - f"runpath_format_string={self.runpath_format_string}, " - f"jobname_format_string={self.jobname_format_string}, " - f"eclbase_format_string={self.eclbase_format_string}, " - f"gen_kw_export_name={self.gen_kw_export_name}, " - f"obs_config_file={self.obs_config_file}, " - f"time_map_file={self._time_map_file}" - ")" - ) - - def __str__(self) -> str: - return repr(self) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelConfig): - return False - return all( - [ - self.num_realizations == other.num_realizations, - self.history_source == other.history_source, - self.runpath_format_string == other.runpath_format_string, - self.jobname_format_string == other.jobname_format_string, - self.eclbase_format_string == other.eclbase_format_string, - self.gen_kw_export_name == other.gen_kw_export_name, - self.obs_config_file == other.obs_config_file, - self._time_map_file == other._time_map_file, - self.time_map == other.time_map, - ] + time_map=time_map, )