Skip to content

Commit

Permalink
Remove unnecessary Categories type
Browse files (browse the repository at this point in the history)
  • Loading branch information
jcrist committed Mar 21, 2024
1 parent 4b793fb commit 99fc65d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 30 deletions.
39 changes: 10 additions & 29 deletions ibisml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _ibis_table_to_numpy(table: ir.Table) -> np.ndarray:
raise ValueError(
"Not all output columns are numeric, cannot convert to a numpy array"
)
return table.to_pandas().values
return table.to_pandas().to_numpy()


def _y_as_dataframe(y: Any) -> pd.DataFrame:
Expand Down Expand Up @@ -224,46 +224,34 @@ def _(X, y=None, maintain_order=False):
return ibis.memtable(table), targets, index


class Categories:
def __init__(self, values: pa.Array):
self.values = values

def __repr__(self):
items = [repr(v) for v in self.values[:3].to_pylist()]
if len(self.values) > 3:
items.append("...")
values = ", ".join(items)
return f"Categories<{values}>"


class Metadata:
def __init__(
self,
categories: dict[str, Categories] | None = None,
categories: dict[str, pa.Array] | None = None,
targets: tuple[str, ...] = (),
):
self.categories = categories or {}
self.targets = targets

def get_categories(self, column: str) -> Categories | None:
def get_categories(self, column: str) -> pa.Array | None:
return self.categories.get(column)

def set_categories(self, column: str, values: pa.Array | list[Any]) -> None:
self.categories[column] = Categories(pa.array(values))
self.categories[column] = pa.array(values)

def drop_categories(self, column: str) -> None:
self.categories.pop(column, None)


def _categorize_wrap_reader(
reader: pa.RecordBatchReader, categories: dict[str, Categories]
reader: pa.RecordBatchReader, categories: dict[str, pa.Array]
) -> Iterable[pa.RecordBatch]:
for batch in reader:
out = {}
for name, col in zip(batch.schema.names, batch.columns):
cats = categories.get(name)
if cats is not None:
col = pa.DictionaryArray.from_arrays(col, cats.values)
col = pa.DictionaryArray.from_arrays(col, cats)
out[name] = col
yield pa.RecordBatch.from_pydict(out)

Expand Down Expand Up @@ -486,7 +474,7 @@ def _categorize_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
codes = df[col].fillna(-1)
if not pd.api.types.is_integer_dtype(codes):
codes = codes.astype("int64")
df[col] = pd.Categorical.from_codes(cast(Sequence[int], codes), cats.values)
df[col] = pd.Categorical.from_codes(cast(Sequence[int], codes), cats)
return df

def _categorize_pyarrow(self, table: pa.Table) -> pa.Table:
Expand All @@ -499,7 +487,7 @@ def _categorize_pyarrow(self, table: pa.Table) -> pa.Table:
if cats is not None:
col = pa.chunked_array(
[
pa.DictionaryArray.from_arrays(chunk, cats.values)
pa.DictionaryArray.from_arrays(chunk, cats)
for chunk in col.chunks
]
)
Expand All @@ -509,13 +497,8 @@ def _categorize_pyarrow(self, table: pa.Table) -> pa.Table:
def _categorize_dask_dataframe(self, ddf: dd.DataFrame) -> dd.DataFrame:
if not self.metadata_.categories:
return ddf

categorize = _get_categorize_chunk()

categories = {
col: cats.values for col, cats in self.metadata_.categories.items()
}
return ddf.map_partitions(categorize, categories)
return ddf.map_partitions(categorize, self.metadata_.categories)

def _categorize_pyarrow_batches(
self, reader: pa.RecordBatchReader
Expand All @@ -528,9 +511,7 @@ def _categorize_pyarrow_batches(
cats = self.metadata_.categories.get(field.name)
if cats is not None:
field = pa.field(
field.name,
pa.dictionary(field.type, cats.values.type),
field.nullable,
field.name, pa.dictionary(field.type, cats.type), field.nullable
)
fields.append(field)

Expand Down
2 changes: 1 addition & 1 deletion ibisml/steps/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
to_compute = []
for column in columns:
if cats := metadata.get_categories(column):
categories[column] = list(range(len(cats.values)))
categories[column] = list(range(len(cats)))
else:
to_compute.append(column)

Expand Down

0 comments on commit 99fc65d

Please sign in to comment.