From bb70db8aa0622870fac1cca96634d9435b4f15ed Mon Sep 17 00:00:00 2001 From: Kobi Felton Date: Fri, 30 Sep 2022 22:26:29 +0100 Subject: [PATCH] Format parity_plot correctly --- summit/benchmarks/experimental_emulator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/summit/benchmarks/experimental_emulator.py b/summit/benchmarks/experimental_emulator.py index e94e52a7..e55804b9 100644 --- a/summit/benchmarks/experimental_emulator.py +++ b/summit/benchmarks/experimental_emulator.py @@ -445,7 +445,7 @@ def _caclulate_input_dimensions(domain: Domain, descriptors_features): @staticmethod def _create_input_preprocessor(domain, **kwargs): - """Create feature preprocessors """ + """Create feature preprocessors""" transformers = [] # Numeric transforms numeric_features = [ @@ -503,7 +503,7 @@ def _create_input_preprocessor(domain, **kwargs): @staticmethod def _create_output_preprocessor(output_variable_names): - """"Create target preprocessors""" + """ "Create target preprocessors""" transformers = [ ("scale", StandardScaler(), output_variable_names), ("dst", FunctionTransformer(numpy_to_tensor), output_variable_names), @@ -865,13 +865,13 @@ def make_parity_plot( handles = [] r2_train = r2_score(y_train, y_train_pred) r2_train_patch = mpatches.Patch( - label=f"Train R2 = {r2_train:.2f}", color=train_color + label=r"Train $R^2$ =" + f"{r2_train:.2f}", color=train_color ) handles.append(r2_train_patch) if y_test is not None: r2_test = r2_score(y_test, y_test_pred) r2_test_patch = mpatches.Patch( - label=f"Test R2 = {r2_test:.2f}", color=test_color + label=r"Test $R^2$ =" + f"{r2_test:.2f}", color=test_color ) handles.append(r2_test_patch) @@ -888,7 +888,7 @@ def make_parity_plot( def numpy_to_tensor(X): - """Convert datasets into """ + """Convert datasets into""" if issparse(X): X = X.todense() return torch.tensor(X).float()