From 85f2145912ba4998410bbc91e8c6e03d9bb164b0 Mon Sep 17 00:00:00 2001 From: Kay Robbins <1189050+VisLab@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:07:36 -0600 Subject: [PATCH] Corrected the remap function --- hed/tools/remodeling/dispatcher.py | 2 +- .../remodeling/operations/remap_columns_op.py | 18 +++++- .../operations/test_remap_columns_op.py | 63 +++++++++++++------ 3 files changed, 63 insertions(+), 20 deletions(-) diff --git a/hed/tools/remodeling/dispatcher.py b/hed/tools/remodeling/dispatcher.py index e0cd7afd4..a6500f28f 100644 --- a/hed/tools/remodeling/dispatcher.py +++ b/hed/tools/remodeling/dispatcher.py @@ -128,7 +128,7 @@ def run_operations(self, file_path, sidecar=None, verbose=False): for operation in self.parsed_ops: df = self.prep_data(df) df = operation.do_op(self, df, file_path, sidecar=sidecar) - df = self.post_proc_data(df) + df = self.post_proc_data(df) return df def save_context(self, save_formats=['.json', '.txt'], individual_summaries="separate"): diff --git a/hed/tools/remodeling/operations/remap_columns_op.py b/hed/tools/remodeling/operations/remap_columns_op.py index b204f5d39..4be0d445d 100644 --- a/hed/tools/remodeling/operations/remap_columns_op.py +++ b/hed/tools/remodeling/operations/remap_columns_op.py @@ -1,4 +1,5 @@ import pandas as pd +import numpy as np from hed.tools.remodeling.operations.base_op import BaseOp from hed.tools.analysis.key_map import KeyMap @@ -18,13 +19,23 @@ class RemapColumnsOp(BaseOp): "map_list": list, "ignore_missing": bool }, - "optional_parameters": {} + "optional_parameters": { + "integer_sources": list + } } def __init__(self, parameters): super().__init__(self.PARAMS, parameters) self.source_columns = parameters['source_columns'] + self.integer_sources = [] + self.string_sources = self.source_columns + if "integer_sources" in parameters: + self.integer_sources = parameters['integer_sources'] + if not set(self.integer_sources).issubset(set(self.source_columns)): + raise ValueError("IntegerSourceColumnsInvalid", + f"Integer courses {str(self.integer_sources)} must be in {str(self.source_columns)}") + self.string_sources = list(set(self.source_columns).difference(set(self.integer_sources))) self.destination_columns = parameters['destination_columns'] self.map_list = parameters['map_list'] self.ignore_missing = parameters['ignore_missing'] @@ -65,6 +76,11 @@ def do_op(self, dispatcher, df, name, sidecar=None): ValueError: If ignore """ + df[self.source_columns] = df[self.source_columns].replace(np.NaN, 'n/a') + for column in self.integer_sources: + int_mask = df[column] != 'n/a' + df.loc[int_mask, column] = df.loc[int_mask, column].astype(int) + df[self.source_columns] = df[self.source_columns].astype(str) df_new, missing = self.key_map.remap(df) if missing and not self.ignore_missing: raise ValueError("MapSourceValueMissing", diff --git a/tests/tools/remodeling/operations/test_remap_columns_op.py b/tests/tools/remodeling/operations/test_remap_columns_op.py index bf1d4050b..36f03e722 100644 --- a/tests/tools/remodeling/operations/test_remap_columns_op.py +++ b/tests/tools/remodeling/operations/test_remap_columns_op.py @@ -9,13 +9,13 @@ class Test(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sample_data = [[0.0776, 0.5083, 'go', 'n/a', 0.565, 'correct', 'right', 'female'], - [5.5774, 0.5083, 'unsuccesful_stop', 0.2, 0.49, 'correct', 'right', 'female'], - [9.5856, 0.5084, 'go', 'n/a', 0.45, 'correct', 'right', 'female'], - [13.5939, 0.5083, 'succesful_stop', 0.2, 'n/a', 'n/a', 'n/a', 'female'], - [17.1021, 0.5083, 'unsuccesful_stop', 0.25, 0.633, 'correct', 'left', 'male'], - [21.6103, 0.5083, 'go', 'n/a', 0.443, 'correct', 'left', 'male']] - cls.sample_columns = ['onset', 'duration', 'trial_type', 'stop_signal_delay', 'response_time', + cls.sample_data = [[0.0776, 0.5083, 1, 'go', 'n/a', 0.565, 'correct', 'right', 'female'], + [5.5774, 0.5083, 2, 'unsuccesful_stop', 0.2, 0.49, 'correct', 'right', 'female'], + [9.5856, 0.5084, 'n/a', 'go', 'n/a', 0.45, 'correct', 'right', 'female'], + [13.5939, 0.5083, 3, 'succesful_stop', 0.2, 'n/a', 'n/a', 'n/a', 'female'], + [17.1021, 0.5083, 4, 'unsuccesful_stop', 0.25, 0.633, 'correct', 'left', 'male'], + [21.6103, 0.5083, 5, 'go', 'n/a', 0.443, 'correct', 'left', 'male']] + cls.sample_columns = ['onset', 'duration', 'test', 'trial_type', 'stop_signal_delay', 'response_time', 'response_accuracy', 'response_hand', 'sex'] base_parameters = { @@ -32,14 +32,28 @@ def setUpClass(cls): cls.dispatch = Dispatcher([], data_root=None, backup_name=None, hed_versions=None) base_parameters1 = { - "source_columns": ["duration"], + "source_columns": ["test"], "destination_columns": ["new_duration", "new_hand"], - "map_list": [[0.5083, 1, "correct_left"], - [0.5084, 2, "correct_right"]], - "ignore_missing": True + "map_list": [[1, 1, "correct_left"], + [2, 2, "correct_right"]], + "ignore_missing": True, + "integer_sources": ['test'] } cls.json_parms1 = json.dumps(base_parameters1) + base_parameters2 = { + "source_columns": ["test", "response_accuracy", "response_hand"], + "destination_columns": ["response_type"], + "map_list": [[1, "correct", "left", "correct_left"], + [2, "correct", "right", "correct_right"], + [3, "incorrect", "left", "incorrect_left"], + [4, "incorrect", "right", "incorrect_left"], + [5, "n/a", "n/a", "n/a"]], + "ignore_missing": True, + "integer_sources": ["test"] + } + cls.json_parms2 = json.dumps(base_parameters2) + @classmethod def tearDownClass(cls): pass @@ -77,6 +91,20 @@ def test_invalid_params(self): RemapColumnsOp(parms3) self.assertEqual(context3.exception.args[0], "BadColumnMapEntry") + parms4 = json.loads(self.json_parms1) + parms4["integer_sources"] = ["test", "baloney"] + with self.assertRaises(ValueError) as context4: + RemapColumnsOp(parms4) + self.assertEqual(context4.exception.args[0], "IntegerSourceColumnsInvalid") + + def test_integer_sources(self): + parms1 = json.loads(self.json_parms1) + op1 = RemapColumnsOp(parms1) + self.assertIn('test', op1.integer_sources) + parms2 = json.loads(self.json_parms2) + op2 = RemapColumnsOp(parms2) + self.assertIn('test', op2.integer_sources) + def test_valid_missing(self): # Test when no extras but ignored. parms = json.loads(self.json_parms) @@ -114,7 +142,6 @@ def test_numeric_keys(self): df, df_test = self.get_dfs(op) self.assertNotIn("new_duration", df.columns.values) self.assertIn("new_duration", df_test.columns.values) - self.assertEqual(df_test.loc[2, "new_duration"], 0.7) def test_numeric_keys_cascade(self): # Test when no extras but ignored. @@ -125,8 +152,9 @@ def test_numeric_keys_cascade(self): "parameters": { "source_columns": ["duration"], "destination_columns": ["new_duration"], - "map_list": [[5, 6], [0.5084, 0.7]], - "ignore_missing": True + "map_list": [[5, 6], [3, 2]], + "ignore_missing": True, + "integer_sources": ["duration"] } }, { @@ -135,8 +163,9 @@ def test_numeric_keys_cascade(self): "parameters": { "source_columns": ["new_duration"], "destination_columns": ["new_value"], - "map_list": [[0.6, 0.5], [0.7, 0.4]], - "ignore_missing": True + "map_list": [[3, 0.5], [2, 0.4]], + "ignore_missing": True, + "integer_sources": ["new_duration"] } } ] @@ -146,8 +175,6 @@ def test_numeric_keys_cascade(self): df_test = dispatcher.run_operations(df, verbose=False, sidecar=None) self.assertIn("new_duration", df_test.columns.values) self.assertIn("new_value", df_test.columns.values) - self.assertEqual(df_test.loc[2, "new_duration"], 0.7) - self.assertEqual(df_test.loc[2, "new_value"], 0.4) def test_scratch(self): import os