Skip to content

Commit

Permalink
working ignore on predict no tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ardunn committed Oct 14, 2019
1 parent b82849f commit 1e77201
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions automatminer/automl/adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ class SinglePipelineAdaptor(DFMLAdaptor, LoggableMixin):
(classification)
_regressor (BaseEstimator): The single pipeline to be used for
regression
_classifier (BaseEstimator)L The single pipeline to be used for
_classifier (BaseEstimator): The single pipeline to be used for
classification
"""
Expand Down Expand Up @@ -326,7 +326,7 @@ def fit(self, df, target, **fit_kwargs):
@property
@check_fitted
def backend(self):
return None
return self.best_pipeline

@property
@check_fitted
Expand Down
17 changes: 10 additions & 7 deletions automatminer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ class MatPipe(DFTransformer, LoggableMixin):
logger (Logger, bool): A custom logger object to use for logging.
Alternatively, if set to True, the default automatminer logger will
be used. If set to False, then no logging will occur.
ignore ([str]): String names of columns in all dataframes to ignore.
This will not stop samples from being dropped, but will preserve
columns.
Attributes:
The following attributes are set during fitting. Each has their own set
Expand All @@ -86,7 +83,7 @@ class MatPipe(DFTransformer, LoggableMixin):
"""

def __init__(self, autofeaturizer=None, cleaner=None, reducer=None,
learner=None, logger=AMM_DEFAULT_LOGGER, ignore=None):
learner=None, logger=AMM_DEFAULT_LOGGER):
transformers = [autofeaturizer, cleaner, reducer, learner]
if not all(transformers):
if any(transformers):
Expand All @@ -106,7 +103,6 @@ def __init__(self, autofeaturizer=None, cleaner=None, reducer=None,
self.reducer = reducer
self.learner = learner
self.logger = logger
self.ignore = ignore
self.pre_fit_df = None
self.post_fit_df = None
self.ml_type = None
Expand Down Expand Up @@ -183,7 +179,7 @@ def transform(self, df, **transform_kwargs):
return self.predict(df, **transform_kwargs)

@check_fitted
def predict(self, df):
def predict(self, df, ignore=None):
"""
Predict a target property of a set of materials.
Expand All @@ -195,11 +191,18 @@ def predict(self, df):
Args:
df (pandas.DataFrame): Pipe will be fit to this dataframe.
ignore ([str]): String names of columns in all dataframes to ignore.
This will not stop samples from being dropped.
Returns:
(pandas.DataFrame): The dataframe with target property predictions.
"""
ignored_df = df[self.ignore] if self.ignore else None
if ignore:
ignored_df = df[ignore]
df = df.drop(columns=ignored_df)
else:
ignored_df = pd.DataFrame()

self.logger.info("Beginning MatPipe prediction using fitted pipeline.")
df = self.autofeaturizer.transform(df, self.target)
df = self.cleaner.transform(df, self.target)
Expand Down
4 changes: 1 addition & 3 deletions automatminer/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from automatminer.featurization import AutoFeaturizer
from automatminer.preprocessing import FeatureReducer, DataCleaner
from automatminer.automl import TPOTAdaptor, SinglePipelineAdaptor
from automatminer.utils.log import AMM_DEFAULT_LOGLVL, AMM_DEFAULT_LOGGER
from automatminer.utils.log import AMM_DEFAULT_LOGGER


def get_preset_config(preset: str = 'express', **powerups) -> dict:
Expand All @@ -39,7 +39,6 @@ def get_preset_config(preset: str = 'express', **powerups) -> dict:
Args:
preset (str): The name of the preset config you'd like to use.
**powerups: Various modifications as kwargs.
ignore ([str]): Column names for the pipeline to ignore.
cache_src (str): A file path. If specified, Autofeaturizer will use
feature caching with a file stored at this location. See
Autofeaturizer's cache_src argument for more information.
Expand Down Expand Up @@ -111,7 +110,6 @@ def get_preset_config(preset: str = 'express', **powerups) -> dict:

logger = powerups.get("logger", AMM_DEFAULT_LOGGER)
config["logger"] = logger
config["ignore"] = powerups.get("ignore", None)
return config


Expand Down

0 comments on commit 1e77201

Please sign in to comment.