From 51ba2f1bf9a531040f270c6fe9694108bb38f607 Mon Sep 17 00:00:00 2001 From: Ayush Kanodia Date: Mon, 4 Dec 2023 21:25:36 -0800 Subject: [PATCH] update main files --- bemb/utils/run_helper.py | 8 ++-- tutorials/supermarket/main.py | 79 +++++++++++++++++++++-------------- 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/bemb/utils/run_helper.py b/bemb/utils/run_helper.py index 410745d..0d5637a 100644 --- a/bemb/utils/run_helper.py +++ b/bemb/utils/run_helper.py @@ -17,7 +17,7 @@ def section_print(input_text): print('=' * 20, input_text, '=' * 20) -def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, **kwargs) -> LitBEMBFlex: +def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, run_test=True, **kwargs) -> LitBEMBFlex: """A standard pipeline of model training and evaluation. Args: @@ -25,6 +25,7 @@ def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=- dataset_list (List[ChoiceDataset]): train_dataset, validation_test, and test_dataset in a list of length 3. batch_size (int, optional): batch_size for training and evaluation. Defaults to -1, which indicates full-batch training. num_epochs (int, optional): number of epochs for training. Defaults to 10. + run_test (bool, optional): whether to run evaluation on test set. Defaults to True. **kwargs: additional keyword argument for the pytorch-lightning Trainer. Returns: @@ -57,6 +58,7 @@ def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=- trainer.fit(model, train_dataloaders=train, val_dataloaders=validation) print(f'time taken: {time.time() - start_time}') - section_print('test performance') - trainer.test(model, dataloaders=test) + if run_test: + section_print('test performance') + trainer.test(model, dataloaders=test) return model diff --git a/tutorials/supermarket/main.py b/tutorials/supermarket/main.py index 5cd5479..4982014 100644 --- a/tutorials/supermarket/main.py +++ b/tutorials/supermarket/main.py @@ -10,7 +10,8 @@ from termcolor import cprint from example_customized_module import ExampleCustomizedModule from torch_choice.data import ChoiceDataset -from bemb.model import LitBEMBFlex +# from bemb.model import LitBEMBFlex +from bemb.model.bemb_supermarket_lightning import LitBEMBFlex from bemb.utils.run_helper import run @@ -70,28 +71,30 @@ def load_tsv(file_name, data_dir): # ============================================================================================== # user observables # ============================================================================================== - user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'), - sep='\t', - index_col=0, - header=None) - # TODO(Tianyu): there could be duplicate information for each user. - # do we need to catch it in some check process? - user_obs = user_obs.groupby(user_obs.index).first().sort_index() - user_obs = torch.Tensor(user_obs.values) - configs.num_user_obs = user_obs.shape[1] - configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs + if configs.obs_user: + user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'), + sep='\t', + index_col=0, + header=None) + # TODO(Tianyu): there could be duplicate information for each user. + # do we need to catch it in some check process? + user_obs = user_obs.groupby(user_obs.index).first().sort_index() + user_obs = torch.Tensor(user_obs.values) + configs.num_user_obs = user_obs.shape[1] + configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs # ============================================================================================== # item observables # ============================================================================================== - item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'), - sep='\t', - index_col=0, - header=None) - item_obs = item_obs.groupby(item_obs.index).first().sort_index() - item_obs = torch.Tensor(item_obs.values) - configs.num_item_obs = item_obs.shape[1] - configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs + if configs.obs_item: + item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'), + sep='\t', + index_col=0, + header=None) + item_obs = item_obs.groupby(item_obs.index).first().sort_index() + item_obs = torch.Tensor(item_obs.values) + configs.num_item_obs = item_obs.shape[1] + configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs # ============================================================================================== # item availability @@ -152,14 +155,22 @@ def load_tsv(file_name, data_dir): # example day of week, random example. session_day_of_week = torch.LongTensor(np.random.randint(0, 7, configs.num_sessions)) - choice_dataset = ChoiceDataset(item_index=label, - user_index=user_index, - session_index=session_index, - item_availability=item_availability, - user_obs=user_obs, - item_obs=item_obs, - price_obs=price_obs, - session_day_of_week=session_day_of_week) + choice_dataset_args = { + "item_index": label, + "user_index": user_index, + "session_index": session_index, + "item_availability": item_availability, + "price_obs": price_obs, + "session_day_of_week": session_day_of_week + } + + if configs.obs_user: + choice_dataset_args["user_obs"] = user_obs + + if configs.obs_item: + choice_dataset_args["item_obs"] = item_obs + + choice_dataset = ChoiceDataset(**choice_dataset_args) dataset_list.append(choice_dataset) @@ -209,20 +220,24 @@ def load_tsv(file_name, data_dir): coef_dim_dict=configs.coef_dim_dict, trace_log_q=configs.trace_log_q, category_to_item=category_to_item, - num_user_obs=configs.num_user_obs, - num_item_obs=configs.num_item_obs, - # num_price_obs=configs.num_price_obs, + num_user_obs=configs.num_user_obs if configs.obs_user else None, + num_item_obs=configs.num_item_obs if configs.obs_item else None, + prior_variance=configs.prior_variance, + num_price_obs=configs.num_price_obs, + preprocess=False, # additional_modules=[ExampleCustomizedModule()] ) bemb = bemb.to(configs.device) - bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs) + bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs, run_test=False) + # ''' coeffs = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy() - coeffs = coeffs**2 + # coeffs = coeffs**2 # give distribution statistics print('Coefficients statistics:') print(pd.DataFrame(coeffs).describe()) + # ''' # ============================================================================================== # inference example