Skip to content

Commit

Permalink
update main files
Browse files Browse the repository at this point in the history
  • Loading branch information
kanodiaayush committed Dec 5, 2023
1 parent f6e7570 commit 51ba2f1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
8 changes: 5 additions & 3 deletions bemb/utils/run_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ def section_print(input_text):
print('=' * 20, input_text, '=' * 20)


def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, **kwargs) -> LitBEMBFlex:
def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, run_test=True, **kwargs) -> LitBEMBFlex:
"""A standard pipeline of model training and evaluation.
Args:
model (LitBEMBFlex): the initialized pytorch-lightning wrapper of bemb.
        dataset_list (List[ChoiceDataset]): train_dataset, validation_dataset, and test_dataset in a list of length 3.
batch_size (int, optional): batch_size for training and evaluation. Defaults to -1, which indicates full-batch training.
num_epochs (int, optional): number of epochs for training. Defaults to 10.
run_test (bool, optional): whether to run evaluation on test set. Defaults to True.
**kwargs: additional keyword argument for the pytorch-lightning Trainer.
Returns:
Expand Down Expand Up @@ -57,6 +58,7 @@ def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-
trainer.fit(model, train_dataloaders=train, val_dataloaders=validation)
print(f'time taken: {time.time() - start_time}')

section_print('test performance')
trainer.test(model, dataloaders=test)
if run_test:
section_print('test performance')
trainer.test(model, dataloaders=test)
return model
79 changes: 47 additions & 32 deletions tutorials/supermarket/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from termcolor import cprint
from example_customized_module import ExampleCustomizedModule
from torch_choice.data import ChoiceDataset
from bemb.model import LitBEMBFlex
# from bemb.model import LitBEMBFlex
from bemb.model.bemb_supermarket_lightning import LitBEMBFlex
from bemb.utils.run_helper import run


Expand Down Expand Up @@ -70,28 +71,30 @@ def load_tsv(file_name, data_dir):
# ==============================================================================================
# user observables
# ==============================================================================================
user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'),
sep='\t',
index_col=0,
header=None)
# TODO(Tianyu): there could be duplicate information for each user.
# do we need to catch it in some check process?
user_obs = user_obs.groupby(user_obs.index).first().sort_index()
user_obs = torch.Tensor(user_obs.values)
configs.num_user_obs = user_obs.shape[1]
configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs
if configs.obs_user:
user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'),
sep='\t',
index_col=0,
header=None)
# TODO(Tianyu): there could be duplicate information for each user.
# do we need to catch it in some check process?
user_obs = user_obs.groupby(user_obs.index).first().sort_index()
user_obs = torch.Tensor(user_obs.values)
configs.num_user_obs = user_obs.shape[1]
configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs

# ==============================================================================================
# item observables
# ==============================================================================================
item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'),
sep='\t',
index_col=0,
header=None)
item_obs = item_obs.groupby(item_obs.index).first().sort_index()
item_obs = torch.Tensor(item_obs.values)
configs.num_item_obs = item_obs.shape[1]
configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs
if configs.obs_item:
item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'),
sep='\t',
index_col=0,
header=None)
item_obs = item_obs.groupby(item_obs.index).first().sort_index()
item_obs = torch.Tensor(item_obs.values)
configs.num_item_obs = item_obs.shape[1]
configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs

# ==============================================================================================
# item availability
Expand Down Expand Up @@ -152,14 +155,22 @@ def load_tsv(file_name, data_dir):
# example day of week, random example.
session_day_of_week = torch.LongTensor(np.random.randint(0, 7, configs.num_sessions))

choice_dataset = ChoiceDataset(item_index=label,
user_index=user_index,
session_index=session_index,
item_availability=item_availability,
user_obs=user_obs,
item_obs=item_obs,
price_obs=price_obs,
session_day_of_week=session_day_of_week)
choice_dataset_args = {
"item_index": label,
"user_index": user_index,
"session_index": session_index,
"item_availability": item_availability,
"price_obs": price_obs,
"session_day_of_week": session_day_of_week
}

if configs.obs_user:
choice_dataset_args["user_obs"] = user_obs

if configs.obs_item:
choice_dataset_args["item_obs"] = item_obs

choice_dataset = ChoiceDataset(**choice_dataset_args)

dataset_list.append(choice_dataset)

Expand Down Expand Up @@ -209,20 +220,24 @@ def load_tsv(file_name, data_dir):
coef_dim_dict=configs.coef_dim_dict,
trace_log_q=configs.trace_log_q,
category_to_item=category_to_item,
num_user_obs=configs.num_user_obs,
num_item_obs=configs.num_item_obs,
# num_price_obs=configs.num_price_obs,
num_user_obs=configs.num_user_obs if configs.obs_user else None,
num_item_obs=configs.num_item_obs if configs.obs_item else None,
prior_variance=configs.prior_variance,
num_price_obs=configs.num_price_obs,
preprocess=False,
# additional_modules=[ExampleCustomizedModule()]
)

bemb = bemb.to(configs.device)
bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs)
bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs, run_test=False)

# '''
coeffs = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy()
coeffs = coeffs**2
# coeffs = coeffs**2
# give distribution statistics
print('Coefficients statistics:')
print(pd.DataFrame(coeffs).describe())
# '''

# ==============================================================================================
# inference example
Expand Down

0 comments on commit 51ba2f1

Please sign in to comment.