Skip to content

Commit

Permalink
refactor file names and pickle-->jsongz
Browse files Browse the repository at this point in the history
  • Loading branch information
ardunn committed Sep 6, 2019
1 parent 350941f commit e716dd7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 56 deletions.
95 changes: 49 additions & 46 deletions benchdev/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
# Local testing configuration...
LOCAL_DEBUG_REG = {
"name": "debug_local_reg",
"data_pickle": "debug_jdft2d.pickle.gz",
"data_file": "debug_jdft2d.json.gz",
"target": "exfoliation_en",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

LOCAL_DEBUG_CLF = {
"name": "debug_local_clf",
"data_pickle": "debug_expt_is_metal.pickle.gz",
"data_file": "debug_expt_is_metal.json.gz",
"target": "is_metal",
"problem_type": AMM_CLF_NAME,
"clf_pos_label": True
Expand All @@ -42,127 +42,130 @@
LOCAL_DEBUG_SET = [LOCAL_DEBUG_CLF, LOCAL_DEBUG_REG]

# Real benchmark sets
BULK = {
"name": "mp_bulk",
"data_pickle": "elasticity_K_VRH.pickle.gz",
"target": "K_VRH",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

SHEAR = {
"name": "mp_shear",
"data_pickle": "elasticity_G_VRH.pickle.gz",
"target": "G_VRH",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

LOG_BULK = {
"name": "mp_log_bulk",
"data_pickle": "elasticity_log10(K_VRH).pickle.gz",
LOG_KVRH = {
"name": "log_kvrh",
"data_file": "log_kvrh.json.gz",
"target": "log10(K_VRH)",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

LOG_SHEAR = {
"name": "mp_log_shear",
"data_pickle": "elasticity_log10(G_VRH).pickle.gz",
LOG_GVRH = {
"name": "log_gvrh",
"data_file": "log_gvrh.json.gz",
"target": "log10(G_VRH)",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

REFRACTIVE = {
"name": "refractive_index",
"data_pickle": "dielectric.pickle.gz",
DIELECTRIC = {
"name": "dielectric",
"data_file": "dielectric.json.gz",
"target": "n",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

JDFT2D = {
"name": "jdft2d",
"data_pickle": "jdft2d.pickle.gz",
"data_file": "jdft2d.json.gz",
"target": "exfoliation_en",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

MP_GAP = {
"name": "mp_gap",
"data_pickle": "mp_gap.pickle.gz",
"data_file": "mp_gap.json.gz",
"target": "gap pbe",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

MP_IS_METAL = {
"name": "mp_is_metal",
"data_pickle": "mp_is_metal.pickle.gz",
"data_file": "mp_is_metal.json.gz",
"target": "is_metal",
"problem_type": AMM_CLF_NAME,
"clf_pos_label": True
}

MP_E_FORM = {
"name": "mp_e_form",
"data_pickle": "mp_e_form.pickle.gz",
"data_file": "mp_e_form.json.gz",
"target": "e_form",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

CASTELLI_E_FORM = {
"name": "castelli",
"data_pickle": "castelli.pickle.gz",
PEROVSKITES = {
"name": "perovskites",
"data_file": "perovskites.json.gz",
"target": "e_form",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

GFA = {
"name": "glass_formation",
"data_pickle": "glass.pickle.gz",
GLASS = {
"name": "glass",
"data_file": "glass.json.gz",
"target": "gfa",
"problem_type": AMM_CLF_NAME,
"clf_pos_label": True
}

EXPT_IS_METAL = {
"name": "expt_is_metal",
"data_pickle": "expt_is_metal.pickle.gz",
"data_file": "expt_is_metal.json.gz",
"target": "is_metal",
"problem_type": AMM_CLF_NAME,
"clf_pos_label": True
}

EXPT_GAP = {
"name": "expt_gap",
"data_pickle": "expt_gap.pickle.gz",
"data_file": "expt_gap.json.gz",
"target": "gap expt",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

PHONONS = {
"name": "phonons",
"data_pickle": "phonons.pickle.gz",
"data_file": "phonons.json.gz",
"target": "last phdos peak",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

STEELS_YIELD = {
"name": "steels_yield",
"data_pickle": "steels_yield.pickle.gz",
STEELS = {
"name": "steels",
"data_file": "steels.json.gz",
"target": "yield strength",
"problem_type": AMM_REG_NAME,
"clf_pos_label": None
}

BENCHMARK_DEBUG_SET = [JDFT2D, PHONONS, EXPT_IS_METAL, STEELS_YIELD]
BENCHMARK_FULL_SET = [BULK, SHEAR, LOG_BULK, LOG_SHEAR, REFRACTIVE, JDFT2D,
MP_GAP, MP_IS_METAL, MP_E_FORM, CASTELLI_E_FORM, GFA,
EXPT_IS_METAL, EXPT_GAP, STEELS_YIELD, PHONONS]
BENCHMARK_DEBUG_SET = [JDFT2D, PHONONS, EXPT_IS_METAL, STEELS]
BENCHMARK_FULL_SET = [LOG_KVRH, LOG_GVRH, DIELECTRIC, JDFT2D,
MP_GAP, MP_IS_METAL, MP_E_FORM, PEROVSKITES, GLASS,
EXPT_IS_METAL, EXPT_GAP, STEELS, PHONONS]

# Extra datasets, probably not present
# BULK = {
# "name": "mp_bulk",
# "data_file": "elasticity_K_VRH.json.gz",
# "target": "K_VRH",
# "problem_type": AMM_REG_NAME,
# "clf_pos_label": None
# }
#
# SHEAR = {
# "name": "mp_shear",
# "data_file": "elasticity_G_VRH.json.gz",
# "target": "G_VRH",
# "problem_type": AMM_REG_NAME,
# "clf_pos_label": None
# }
# EXTRAS = [SHEAR, BULK]
6 changes: 3 additions & 3 deletions benchdev/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,16 @@ def run_task(self, fw_spec):
# Read data from fw_spec
pipe_config_dict = fw_spec["pipe_config"]
target = fw_spec["target"]
data_pickle = fw_spec["data_pickle"]
data_file = fw_spec["data_file"]
learner_name = pipe_config_dict["learner_name"]
learner_kwargs = pipe_config_dict["learner_kwargs"]
reducer_kwargs = pipe_config_dict["reducer_kwargs"]
cleaner_kwargs = pipe_config_dict["cleaner_kwargs"]
autofeaturizer_kwargs = pipe_config_dict["autofeaturizer_kwargs"]

# Modify data_pickle based on computing resource
# Modify data_file based on computing resource
data_dir = os.environ['AMM_DATASET_DIR']
data_file = os.path.join(data_dir, data_pickle)
data_file = os.path.join(data_dir, data_file)

# Modify save_dir based on computing resource
bench_dir = os.environ['AMM_MODEL_DIR']
Expand Down
14 changes: 7 additions & 7 deletions benchdev/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def check_pipe_config(pipe_config):
raise ValueError("Logger is set internally by tasks.")


def wf_single_fit(fworker, fit_name, pipe_config, name, data_pickle, target, *args,
def wf_single_fit(fworker, fit_name, pipe_config, name, data_file, target, *args,
tags=None, **kwargs):
"""
Submit a dataset to be fit for a single pipeline (i.e., to train on a
Expand All @@ -56,7 +56,7 @@ def wf_single_fit(fworker, fit_name, pipe_config, name, data_pickle, target, *ar
spec = {
"pipe_config": pipe_config,
"base_save_dir": base_save_dir,
"data_pickle": data_pickle,
"data_file": data_file,
"target": target,
"automatminer_commit": get_last_commit(),
"tags": tags if tags else [],
Expand Down Expand Up @@ -140,7 +140,7 @@ def wf_evaluate_build(fworker, build_name, dataset_set, pipe_config,
return wf


def wf_benchmark(fworker, pipe_config, name, data_pickle, target, problem_type,
def wf_benchmark(fworker, pipe_config, name, data_file, target, problem_type,
clf_pos_label,
cache=True, kfold_config=KFOLD_DEFAULT, tags=None,
return_fireworks=False, add_dataset_to_names=True,
Expand All @@ -157,11 +157,11 @@ def wf_benchmark(fworker, pipe_config, name, data_pickle, target, problem_type,
# pipe_config["autofeaturizer_kwargs"]["n_jobs"] = n_cori_jobs

# Single (run) hash is the combination of pipe configuration + last commit
# + data_pickle
# + data_file
last_commit = get_last_commit()
benchmark_config_for_hash = copy.deepcopy(pipe_config)
benchmark_config_for_hash["last_commit"] = last_commit
benchmark_config_for_hash["data_pickle"] = data_pickle
benchmark_config_for_hash["data_file"] = data_file
benchmark_config_for_hash["worker"] = fworker
benchmark_config_for_hash = str(benchmark_config_for_hash).encode("UTF-8")
benchmark_hash = hashlib.sha1(benchmark_config_for_hash).hexdigest()[:10]
Expand All @@ -171,7 +171,7 @@ def wf_benchmark(fworker, pipe_config, name, data_pickle, target, problem_type,
"pipe_config": pipe_config,
"base_save_dir": base_save_dir,
"kfold_config": kfold_config,
"data_pickle": data_pickle,
"data_file": data_file,
"target": target,
"clf_pos_label": clf_pos_label,
"problem_type": problem_type,
Expand Down Expand Up @@ -287,7 +287,7 @@ def wf_benchmark(fworker, pipe_config, name, data_pickle, target, problem_type,
# "debug"
]

from benchdev.config import MP_E_FORM, JDFT2D, BULK, GFA
from benchdev.config import MP_E_FORM, JDFT2D, BULK, GLASS
# wf = wf_evaluate_build("lrc", "set generation size", BENCHMARK_FULL_SET, pipe_config,
# include_tests=False, cache=True, tags=tags)

Expand Down

0 comments on commit e716dd7

Please sign in to comment.