diff --git a/ppsci/constraint/boundary_constraint.py b/ppsci/constraint/boundary_constraint.py index 6420a96850..afbad13c16 100644 --- a/ppsci/constraint/boundary_constraint.py +++ b/ppsci/constraint/boundary_constraint.py @@ -85,10 +85,9 @@ def __init__( weight_dict: Optional[Dict[str, Union[float, Callable]]] = None, name: str = "BC", ): - self.output_expr = output_expr self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) self.output_expr = { k: v for k, v in output_expr.items() if k in self.output_keys } diff --git a/ppsci/constraint/initial_constraint.py b/ppsci/constraint/initial_constraint.py index d32d8c00c4..351af60c74 100644 --- a/ppsci/constraint/initial_constraint.py +++ b/ppsci/constraint/initial_constraint.py @@ -88,10 +88,9 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "IC", ): - self.output_expr = output_expr self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) self.output_expr = { k: v for k, v in output_expr.items() if k in self.output_keys } diff --git a/ppsci/constraint/integral_constraint.py b/ppsci/constraint/integral_constraint.py index 511f823738..63d8314fa3 100644 --- a/ppsci/constraint/integral_constraint.py +++ b/ppsci/constraint/integral_constraint.py @@ -85,10 +85,9 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "IgC", ): - self.output_expr = output_expr self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) self.output_expr = { k: v for k, v in output_expr.items() if k in self.output_keys } diff --git a/ppsci/constraint/interior_constraint.py b/ppsci/constraint/interior_constraint.py index d0c77df10a..a333c82db3 100644 --- a/ppsci/constraint/interior_constraint.py +++ b/ppsci/constraint/interior_constraint.py @@ -85,10 +85,9 @@ def __init__( weight_dict: Optional[Dict[str, Union[Callable, float]]] = None, name: str = "EQ", ): - self.output_expr = output_expr self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) self.output_expr = { k: v for k, v in output_expr.items() if k in self.output_keys } diff --git a/ppsci/constraint/periodic_constraint.py b/ppsci/constraint/periodic_constraint.py index 7ad3e2fc1d..6aace571ec 100644 --- a/ppsci/constraint/periodic_constraint.py +++ b/ppsci/constraint/periodic_constraint.py @@ -72,9 +72,8 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "PeriodicBC", ): - self.output_expr = output_expr self.input_keys = geom.dim_keys - self.output_keys = list(output_expr.keys()) + self.output_keys = tuple(output_expr.keys()) self.output_expr = { k: v for k, v in output_expr.items() if k in self.output_keys } diff --git a/ppsci/constraint/supervised_constraint.py b/ppsci/constraint/supervised_constraint.py index a0c34d8bec..84b8816222 100644 --- a/ppsci/constraint/supervised_constraint.py +++ b/ppsci/constraint/supervised_constraint.py @@ -60,19 +60,20 @@ def __init__( output_expr: Optional[Dict[str, Callable]] = None, name: str = "Sup", ): - self.output_expr = output_expr - # build dataset _dataset = dataset.build_dataset(dataloader_cfg["dataset"]) self.input_keys = _dataset.input_keys self.output_keys = ( - list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + tuple(output_expr.keys()) + if output_expr is not None + else _dataset.label_keys ) + self.output_expr = output_expr if self.output_expr is None: self.output_expr = { - key: lambda out, k=key: out[k] for key in self.output_keys + key: (lambda out, k=key: out[k]) for key in self.output_keys } # construct dataloader with dataset and dataloader_cfg diff --git a/ppsci/validate/geo_validator.py b/ppsci/validate/geo_validator.py index b9b781f87a..6741baddc7 100644 --- a/ppsci/validate/geo_validator.py +++ b/ppsci/validate/geo_validator.py @@ -86,7 +86,7 @@ def __init__( self.output_expr = output_expr self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) nx = dataloader_cfg["total_size"] self.num_timestamps = 1 diff --git a/ppsci/validate/sup_validator.py b/ppsci/validate/sup_validator.py index a45a28f983..56f2b9a50a 100644 --- a/ppsci/validate/sup_validator.py +++ b/ppsci/validate/sup_validator.py @@ -75,7 +75,9 @@ def __init__( self.input_keys = _dataset.input_keys self.output_keys = ( - list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + tuple(output_expr.keys()) + if output_expr is not None + else _dataset.label_keys ) if self.output_expr is None: