Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor model config and convert it to pydantic #8501

Merged
merged 5 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

import xarray as xr
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Self

from ert.config.gen_data_config import GenDataConfig
Expand Down Expand Up @@ -105,6 +106,7 @@ class ErtConfig:
model_config: ModelConfig = field(default_factory=ModelConfig)
user_config_file: str = "no_config"
config_path: str = field(init=False)
obs_config_file: Optional[str] = None

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
Expand All @@ -118,7 +120,7 @@ def __post_init__(self) -> None:
if self.user_config_file
else os.getcwd()
)
self.enkf_obs: EnkfObs = self._create_observations()
self.enkf_obs: EnkfObs = self._create_observations(self.obs_config_file)

if len(self.summary_keys) != 0:
self.ensemble_config.addNode(self._create_summary_config())
Expand Down Expand Up @@ -197,6 +199,11 @@ def from_dict(cls, config_dict) -> Self:
substitution_list["<ECLBASE>"] = eclbase
except ConfigValidationError as e:
errors.append(e)
except PydanticValidationError as err:
# pydantic wraps any ValueError raised inside a validator (ConfigValidationError
# is a ValueError subclass) in its own ValidationError, so unpack the originals.
for e in err.errors():
errors.append(e["ctx"]["error"])

try:
workflow_jobs, workflows, hooked_workflows = cls._workflows_from_dict(
Expand Down Expand Up @@ -261,6 +268,7 @@ def from_dict(cls, config_dict) -> Self:
),
model_config=model_config,
user_config_file=config_file_path,
obs_config_file=config_dict.get(ConfigKeys.OBS_CONFIG),
)

@classmethod
Expand Down Expand Up @@ -884,9 +892,8 @@ def _create_summary_config(self) -> SummaryConfig:
refcase=time_map,
)

def _create_observations(self) -> EnkfObs:
def _create_observations(self, obs_config_file: str) -> EnkfObs:
obs_vectors: Dict[str, ObsVector] = {}
obs_config_file = self.model_config.obs_config_file
obs_time_list: Sequence[datetime] = []
if self.ensemble_config.refcase is not None:
obs_time_list = self.ensemble_config.refcase.all_dates
Expand Down
120 changes: 42 additions & 78 deletions src/ert/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import logging
import os.path
from datetime import datetime
from typing import TYPE_CHECKING, Optional, no_type_check
from typing import List, Optional, no_type_check

from .parsing import ConfigDict, ConfigKeys, ConfigValidationError, HistorySource
from pydantic import field_validator
from pydantic.dataclasses import dataclass

if TYPE_CHECKING:
from typing import List
from .parsing import ConfigDict, ConfigKeys, ConfigValidationError, HistorySource

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,76 +38,75 @@ def str_to_datetime(date_str: str) -> datetime:
DEFAULT_ECLBASE_FORMAT = "ECLBASE<IENS>"


@dataclass
class ModelConfig:
def __init__(
self,
num_realizations: int = 1,
history_source: HistorySource = DEFAULT_HISTORY_SOURCE,
runpath_format_string: str = DEFAULT_RUNPATH,
jobname_format_string: str = DEFAULT_JOBNAME_FORMAT,
eclbase_format_string: str = DEFAULT_ECLBASE_FORMAT,
gen_kw_export_name: str = DEFAULT_GEN_KW_EXPORT_NAME,
obs_config_file: Optional[str] = None,
time_map_file: Optional[str] = None,
):
self.num_realizations = num_realizations
self.history_source = history_source
self.jobname_format_string = _replace_runpath_format(jobname_format_string)
self.eclbase_format_string = _replace_runpath_format(eclbase_format_string)

# do not combine styles
num_realizations: int = 1
history_source: HistorySource = DEFAULT_HISTORY_SOURCE
runpath_format_string: str = DEFAULT_RUNPATH
jobname_format_string: str = DEFAULT_JOBNAME_FORMAT
eclbase_format_string: str = DEFAULT_ECLBASE_FORMAT
gen_kw_export_name: str = DEFAULT_GEN_KW_EXPORT_NAME
time_map: Optional[List[datetime]] = None

@field_validator("runpath_format_string", mode="before")
@classmethod
def validate_runpath(cls, runpath_format_string: str) -> str:
if "%d" in runpath_format_string and any(
x in runpath_format_string for x in ["<ITER>", "<IENS>"]
):
raise ConfigValidationError(
f"RUNPATH cannot combine deprecated and new style placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`"
)

# do not allow multiple occurrences
for kw in ["<ITER>", "<IENS>"]:
if runpath_format_string.count(kw) > 1:
raise ConfigValidationError(
f"RUNPATH cannot contain multiple {kw} placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`"
)

# do not allow too many placeholders
if runpath_format_string.count("%d") > 2:
raise ConfigValidationError(
f"RUNPATH cannot contain more than two value placeholders: `{runpath_format_string}`. Valid example `{DEFAULT_RUNPATH}`"
)

if "/" in self.jobname_format_string:
raise ConfigValidationError.with_context(
"JOBNAME cannot contain '/'.", jobname_format_string
)

self.runpath_format_string = _replace_runpath_format(runpath_format_string)

if not any(x in self.runpath_format_string for x in ["<ITER>", "<IENS>"]):
result = _replace_runpath_format(runpath_format_string)
if not any(x in result for x in ["<ITER>", "<IENS>"]):
logger.warning(
"RUNPATH keyword contains no value placeholders: "
f"`{runpath_format_string}`. Valid example: "
f"`{DEFAULT_RUNPATH}` "
)
return result

@field_validator("jobname_format_string", mode="before")
@classmethod
def validate_jobname(cls, jobname_format_string: str) -> str:
    """Normalise and validate the JOBNAME format string.

    The raw value is first passed through ``_replace_runpath_format``
    (presumably rewriting deprecated ``%d`` placeholders -- confirm against
    that helper), and the converted string is what gets stored on the model.

    Raises:
        ConfigValidationError: if the raw value contains a ``/``.
    """
    converted = _replace_runpath_format(jobname_format_string)
    # The slash check is made against the raw input, not the converted value.
    if "/" not in jobname_format_string:
        return converted
    raise ConfigValidationError.with_context(
        "JOBNAME cannot contain '/'.", jobname_format_string
    )

@field_validator("eclbase_format_string", mode="before")
@classmethod
def transform(cls, eclbase_format_string: str) -> str:
    """Return the ECLBASE format string after ``_replace_runpath_format``.

    No validation is performed here; the value is only converted before it
    is stored on the model.
    """
    converted = _replace_runpath_format(eclbase_format_string)
    return converted

self.gen_kw_export_name = gen_kw_export_name
self.obs_config_file = obs_config_file
self.time_map = None
self._time_map_file = (
@no_type_check
@classmethod
def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig":
time_map_file = config_dict.get(ConfigKeys.TIME_MAP)
time_map_file = (
os.path.abspath(time_map_file) if time_map_file is not None else None
)

time_map = None
if time_map_file is not None:
try:
self.time_map = _read_time_map(time_map_file)
time_map = _read_time_map(time_map_file)
except (ValueError, IOError) as err:
raise ConfigValidationError.with_context(
f"Could not read timemap file {time_map_file}: {err}", time_map_file
) from err

@no_type_check
@classmethod
def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig":
return cls(
num_realizations=config_dict.get(ConfigKeys.NUM_REALIZATIONS, 1),
history_source=config_dict.get(
Expand All @@ -127,42 +126,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig":
gen_kw_export_name=config_dict.get(
ConfigKeys.GEN_KW_EXPORT_NAME, DEFAULT_GEN_KW_EXPORT_NAME
),
obs_config_file=config_dict.get(ConfigKeys.OBS_CONFIG),
time_map_file=config_dict.get(ConfigKeys.TIME_MAP),
)

def __repr__(self) -> str:
return (
"ModelConfig("
f"num_realizations={self.num_realizations}, "
f"history_source={self.history_source}, "
f"runpath_format_string={self.runpath_format_string}, "
f"jobname_format_string={self.jobname_format_string}, "
f"eclbase_format_string={self.eclbase_format_string}, "
f"gen_kw_export_name={self.gen_kw_export_name}, "
f"obs_config_file={self.obs_config_file}, "
f"time_map_file={self._time_map_file}"
")"
)

def __str__(self) -> str:
return repr(self)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ModelConfig):
return False
return all(
[
self.num_realizations == other.num_realizations,
self.history_source == other.history_source,
self.runpath_format_string == other.runpath_format_string,
self.jobname_format_string == other.jobname_format_string,
self.eclbase_format_string == other.eclbase_format_string,
self.gen_kw_export_name == other.gen_kw_export_name,
self.obs_config_file == other.obs_config_file,
self._time_map_file == other._time_map_file,
self.time_map == other.time_map,
]
time_map=time_map,
)


Expand Down