Merge pull request #627 from IanCa/dev_refactor
Add more unit tests. Better NaN and empty column handling.
VisLab authored Mar 19, 2023
2 parents 67ab89b + 84cf4e0 commit a4d3931
Showing 5 changed files with 357 additions and 126 deletions.
58 changes: 37 additions & 21 deletions hed/models/base_input.py
@@ -7,6 +7,7 @@
 from hed.models.column_mapper import ColumnMapper
 from hed.errors.exceptions import HedFileError, HedExceptions
 from hed.errors.error_reporter import ErrorHandler
+import pandas as pd
 
 
 class BaseInput:
@@ -66,10 +67,7 @@ def __init__(self, file, file_type=None, worksheet_name=None, has_column_names=T
         elif not file:
             raise HedFileError(HedExceptions.FILE_NOT_FOUND, "Empty file passed to BaseInput.", file)
         elif input_type in self.TEXT_EXTENSION:
-            self._dataframe = pandas.read_csv(file, delimiter='\t', header=pandas_header,
-                                              dtype=str, keep_default_na=True, na_values=None)
-            # Convert nan values to a known value
-            self._dataframe = self._dataframe.fillna("n/a")
+            self._dataframe = pandas.read_csv(file, delimiter='\t', header=pandas_header, dtype=str)
         elif input_type in self.EXCEL_EXTENSION:
             self._loaded_workbook = openpyxl.load_workbook(file)
             loaded_worksheet = self.get_worksheet(self._worksheet_name)
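
Context for the hunk above (not part of the commit): with `dtype=str` and pandas' default NA handling, empty spreadsheet cells load as `NaN` rather than the `"n/a"` strings produced by the `fillna` call on the removed side, so downstream assembly code now has to tolerate missing values itself. A minimal sketch of the difference, using a made-up two-column events fragment:

```python
import io
import pandas as pd

# Made-up TSV fragment with one empty HED cell.
tsv = "onset\tHED\n1.5\tLabel/A\n2.0\t\n"

# New-style load: empty cells stay NaN.
df = pd.read_csv(io.StringIO(tsv), delimiter="\t", dtype=str)
print(df["HED"].isna().tolist())      # [False, True]

# Old-style load: NaN normalized to the literal string "n/a".
df_filled = df.fillna("n/a")
print(df_filled["HED"].tolist())      # ['Label/A', 'n/a']
```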
@@ -364,7 +362,7 @@ def assemble(self, mapper=None):
         """
         if mapper is None:
             mapper = self._mapper
-        import pandas as pd
+
         transformers, need_categorical = mapper.get_transformers()
         if not transformers:
             return None
@@ -374,44 +372,62 @@ def assemble(self, mapper=None):
 
         all_columns = all_columns.transform(transformers)
 
-        possible_column_references = [f"{column_name}" for column_name in self.columns if
-                                      column_name.lower() != "hed"]
+        return self._insert_columns(all_columns, list(transformers.keys()))
+
+    @staticmethod
+    def _find_column_refs(df):
         found_column_references = []
-        for column_name in all_columns:
-            df = all_columns[column_name].str.findall("\[([a-z_\-0-9]+)\]", re.IGNORECASE)
-            u_vals = pd.Series([j for i in df for j in i], dtype=str)
+        for column_name in df:
+            df_temp = df[column_name].str.findall("\[([a-z_\-0-9]+)\]", re.IGNORECASE)
+            u_vals = pd.Series([j for i in df_temp for j in i], dtype=str)
             u_vals = u_vals.unique()
             for val in u_vals:
                 if val not in found_column_references:
                     found_column_references.append(val)
 
+        return found_column_references
+
+    @staticmethod
+    def _insert_columns(df, known_columns=None):
+        if known_columns is None:
+            known_columns = list(df.columns)
+        possible_column_references = [f"{column_name}" for column_name in df.columns if
+                                      column_name.lower() != "hed"]
+        found_column_references = BaseInput._find_column_refs(df)
+
         invalid_replacements = [col for col in found_column_references if col not in possible_column_references]
         if invalid_replacements:
             # todo: This check may be moved to validation
             raise ValueError(f"Bad column references found(columns do not exist): {invalid_replacements}")
         valid_replacements = [col for col in found_column_references if col in possible_column_references]
 
-        column_names = list(transformers.keys())
+        # todo: break this into a sub function(probably)
+        column_names = known_columns
         for column_name in valid_replacements:
             column_names.remove(column_name)
-        saved_columns = all_columns[valid_replacements]
+        saved_columns = df[valid_replacements]
        for column_name in column_names:
             for replacing_name in valid_replacements:
                 column_name_brackets = f"[{replacing_name}]"
-                all_columns[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y
-                                                     in zip(all_columns[column_name], saved_columns[replacing_name]))
-        all_columns = all_columns[column_names]
+                df[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y
+                                            in zip(df[column_name], saved_columns[replacing_name]))
+        df = df[column_names]
 
-        return all_columns
+        return df
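
Aside (not part of the diff): `_find_column_refs` and `_insert_columns` implement the `[column_name]` reference mechanism, where a HED column can embed another column's name in square brackets and the reference is replaced row by row with that column's value. A simplified standalone sketch of the same idea, not the library API, with made-up column names and values:

```python
import re
import pandas as pd

# Toy frame: the HED column references the "response" column by name.
df = pd.DataFrame({
    "response": ["Left", "Right"],
    "HED": ["Agent-action, [response]", "[response], Label/B"],
})

# Collect bracketed references, mirroring the findall pattern above.
refs = set()
for col in df:
    for matches in df[col].str.findall(r"\[([a-z_\-0-9]+)\]", re.IGNORECASE):
        refs.update(matches)

# Replace each reference with the referenced column's per-row value.
for ref in refs:
    df["HED"] = [row.replace(f"[{ref}]", val) for row, val in zip(df["HED"], df[ref])]

print(df["HED"].tolist())   # ['Agent-action, Left', 'Right, Label/B']
```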

     @staticmethod
     def combine_dataframe(dataframe):
-        """ Combines all columns in the given dataframe into a single hed string series.
+        """ Combines all columns in the given dataframe into a single HED string series,
+            skipping empty columns and columns with empty strings.
         Parameters:
             dataframe(Dataframe): The dataframe to combine
         Returns:
             Series: the assembled series
         """
-        dataframe = dataframe.agg(', '.join, axis=1)
+        dataframe = dataframe.agg(
+            lambda x: ', '.join(filter(lambda e: pd.notna(e) and e != "", x)), axis=1
+        )
 
-        # Potentially better ways to handle removing n/a by never inserting them to begin with.
-        dataframe = dataframe.replace("(, n/a|n/a,)", "", regex=True)
-        return dataframe
+        return dataframe
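
The reworked `combine_dataframe` above drops the regex-based `n/a` cleanup: instead of joining every cell and then stripping `, n/a` afterwards, it skips NaN and empty-string cells while joining. A quick illustration of that behavior, with made-up column values:

```python
import pandas as pd

df = pd.DataFrame({
    "col1": ["Label/A", None, "Label/C"],
    "col2": ["", "Label/B", "Label/D"],
})

# Join each row, keeping only cells that are neither NaN nor empty strings.
combined = df.agg(
    lambda x: ', '.join(filter(lambda e: pd.notna(e) and e != "", x)), axis=1
)
print(combined.tolist())   # ['Label/A', 'Label/B', 'Label/C, Label/D']
```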
1 change: 1 addition & 0 deletions hed/validator/spreadsheet_validator.py
@@ -41,6 +41,7 @@ def validate(self, data, def_dicts=None, name=None, error_handler=None):
         # Check the structure of the input data, if it's a BaseInput
         if isinstance(data, BaseInput):
             issues += self._validate_column_structure(data, error_handler)
+            # todo ian: Add more checks here for column inserters
             data = data.dataframe_a
 
         # Check the rows of the input data
103 changes: 0 additions & 103 deletions tests/models/test_base_file_input.py

This file was deleted.

