Commit

remove redundant line.
TianyuDu committed Jun 26, 2023
1 parent a4e228c commit 42b9200
Showing 2 changed files with 227 additions and 1 deletion.
1 change: 0 additions & 1 deletion tutorials/supermarket/main.py
@@ -21,7 +21,6 @@ def load_configs(yaml_file: str):
     defaults = {
         'num_verify_val': 10,
         'early_stopping': {'validation_llh_flat': -1},
-        'write_best_model': True
         'write_best_model': True,
         'pred_item' : True,
     }
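
(The removed line duplicated the 'write_best_model': True entry below it and, lacking a trailing comma, made the dict literal a syntax error.)
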
227 changes: 227 additions & 0 deletions tutorials/supermarket/main_shopper.py
@@ -0,0 +1,227 @@
import argparse
import os
import sys

import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.preprocessing import LabelEncoder
from termcolor import cprint
from example_customized_module import ExampleCustomizedModule
from torch_choice.data import ChoiceDataset
from bemb.model import LitBEMBFlex
from bemb.utils.run_helper import run


def load_configs(yaml_file: str):
    with open(yaml_file, 'r') as file:
        data_loaded = yaml.safe_load(file)
    # fill in defaults for keys the YAML file may omit.
    defaults = {
        'num_verify_val': 10,
        'early_stopping': {'validation_llh_flat': -1},
        'write_best_model': True,
        'pred_item': True,
    }
    defaults.update(data_loaded)
    configs = argparse.Namespace(**defaults)
    return configs
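
# For reference, a config YAML for this script might look like the sketch below.
# Field names are inferred from the `configs` attributes accessed later in this
# script; the utility formula and all values are illustrative placeholders, not
# taken from the repository:
#
#     data_dir: ./data/supermarket
#     device: cuda
#     batch_size: 100000
#     num_epochs: 10
#     learning_rate: 0.01
#     num_mc_seeds: 1
#     utility: lambda_item + theta_user * alpha_item
#     obs2prior_dict: {lambda_item: false, theta_user: false, alpha_item: false}
#     coef_dim_dict: {lambda_item: 1, theta_user: 10, alpha_item: 10}
#     trace_log_q: false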


def is_sorted(x):
    # True iff the array is already in ascending order.
    return all(x == np.sort(x))


def load_tsv(file_name, data_dir):
    return pd.read_csv(os.path.join(data_dir, file_name),
                       sep='\t',
                       index_col=None,
                       names=['user_id', 'item_id', 'session_id', 'quantity'])


if __name__ == '__main__':
    cprint('You are running an example script.', 'green')
    # sys.argv[1] should be the path to the yaml config file.
    configs = load_configs(sys.argv[1])

    # ==============================================================================================
    # Load standard BEMB inputs.
    # ==============================================================================================
    # read the standard BEMB input files (user_id, item_id, session_id, quantity).
    train = load_tsv('train.tsv', configs.data_dir)
    validation = load_tsv('validation.tsv', configs.data_dir)
    test = load_tsv('test.tsv', configs.data_dir)

    # ==============================================================================================
    # Encode users and items to {0, 1, ..., num-1}.
    # ==============================================================================================
    # combine all splits so the encoders see every user and item.
    data_all = pd.concat([train, validation, test], axis=0)
    # encode users.
    user_encoder = LabelEncoder().fit(data_all['user_id'].values)
    configs.num_users = len(user_encoder.classes_)
    assert is_sorted(user_encoder.classes_)
    # encode items.
    item_encoder = LabelEncoder().fit(data_all['item_id'].values)
    configs.num_items = len(item_encoder.classes_)
    assert is_sorted(item_encoder.classes_)
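
    # For intuition: LabelEncoder maps raw IDs to consecutive integers and back, e.g.
    #   enc = LabelEncoder().fit(['u3', 'u1', 'u1'])
    #   enc.transform(['u1', 'u3'])    # -> array([0, 1])
    #   enc.inverse_transform([0, 1])  # -> array(['u1', 'u3'], ...)
    # enc.classes_ is always sorted, which is what the is_sorted assertions rely on.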

    # ==============================================================================================
    # user observables
    # ==============================================================================================
    user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'),
                           sep='\t',
                           index_col=0,
                           header=None)
    # TODO(Tianyu): there could be duplicate rows for a user; should a validation
    # step catch this? for now, keep the first row per user.
    user_obs = user_obs.groupby(user_obs.index).first().sort_index()
    user_obs = torch.Tensor(user_obs.values)
    configs.num_user_obs = user_obs.shape[1]
    configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs

    # ==============================================================================================
    # item observables
    # ==============================================================================================
    item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'),
                           sep='\t',
                           index_col=0,
                           header=None)
    item_obs = item_obs.groupby(item_obs.index).first().sort_index()
    item_obs = torch.Tensor(item_obs.values)
    configs.num_item_obs = item_obs.shape[1]
    configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs

    # ==============================================================================================
    # item availability
    # ==============================================================================================
    # parse item availability; each row says item `item_id` was available in session `session_id`.
    # TODO: wrap in try/except? optionally allow specifying full availability?
    a_tsv = pd.read_csv(os.path.join(configs.data_dir, 'availabilityList.tsv'),
                        sep='\t',
                        index_col=None,
                        header=None,
                        names=['session_id', 'item_id'])

    # the availability file also defines the set of sessions.
    session_encoder = LabelEncoder().fit(a_tsv['session_id'].values)
    configs.num_sessions = len(session_encoder.classes_)
    assert is_sorted(session_encoder.classes_)
    # (num_sessions, num_items) boolean mask; entry (s, i) is True iff item i is available in session s.
    item_availability = torch.zeros(configs.num_sessions, configs.num_items).bool()

    a_tsv['item_id'] = item_encoder.transform(a_tsv['item_id'].values)
    a_tsv['session_id'] = session_encoder.transform(a_tsv['session_id'].values)

    # this loop could be slow, depending on the number of sessions.
    for session_id, df_group in a_tsv.groupby('session_id'):
        # get IDs of items available in this session.
        a_item_ids = df_group['item_id'].unique()  # unique() is unnecessary if the dataset is well-prepared.
        item_availability[session_id, a_item_ids] = True
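
    # A vectorized alternative to the loop above (a sketch, same result; assumes
    # a_tsv already holds the encoded integer IDs, as it does at this point):
    # item_availability[
    #     torch.as_tensor(a_tsv['session_id'].values),
    #     torch.as_tensor(a_tsv['item_id'].values),
    # ] = True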

    # ==============================================================================================
    # price observables
    # ==============================================================================================
    df_price = pd.read_csv(os.path.join(configs.data_dir, 'item_sess_price.tsv'),
                           sep='\t',
                           names=['item_id', 'session_id', 'price'])

    # only keep prices of relevant (i.e., ever-purchased) items.
    mask = df_price['item_id'].isin(item_encoder.classes_)
    df_price = df_price[mask]

    df_price['item_id'] = item_encoder.transform(df_price['item_id'].values)
    df_price['session_id'] = session_encoder.transform(df_price['session_id'].values)
    # pivot to a (num_sessions, num_items) price table.
    df_price = df_price.pivot(index='session_id', columns='item_id')
    # fill NaN prices (item not priced in that session) with 0.
    df_price.fillna(0.0, inplace=True)
    price_obs = torch.Tensor(df_price.values).view(configs.num_sessions, configs.num_items, 1)
    configs.num_price_obs = 1
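
    # For intuition, pivot reshapes the long (item, session, price) table into a
    # wide (num_sessions, num_items) one; e.g. rows (0, 0, 1.5), (1, 0, 2.0),
    # (1, 1, 2.5) become [[1.5, 2.0], [NaN, 2.5]] before fillna.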

    # ==============================================================================================
    # create datasets
    # ==============================================================================================
    dataset_list = list()
    for d in (train, validation, test):
        user_index = torch.LongTensor(user_encoder.transform(d['user_id'].values))
        label = torch.LongTensor(item_encoder.transform(d['item_id'].values))
        # the session (the date in the raw dataset) of each row; item availability
        # and prices are retrieved for that session.
        session_index = torch.LongTensor(session_encoder.transform(d['session_id'].values))

        # example session-level observable: day of week, randomly generated purely for illustration.
        session_day_of_week = torch.LongTensor(np.random.randint(0, 7, configs.num_sessions))

        choice_dataset = ChoiceDataset(item_index=label,
                                       user_index=user_index,
                                       session_index=session_index,
                                       item_availability=item_availability,
                                       user_obs=user_obs,
                                       item_obs=item_obs,
                                       price_obs=price_obs,
                                       session_day_of_week=session_day_of_week)
        dataset_list.append(choice_dataset)
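
    # Note: torch_choice's ChoiceDataset categorizes extra keyword tensors by
    # their name prefix (user_*, item_*, session_*, price_*), so the
    # session_day_of_week tensor above is picked up as a session-level
    # observable (see the torch_choice documentation for the exact convention).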

    # ==============================================================================================
    # category information
    # ==============================================================================================
    item_groups = pd.read_csv(os.path.join(configs.data_dir, 'itemGroup.tsv'),
                              sep='\t',
                              index_col=None,
                              names=['item_id', 'category_id'])

    # TODO(Tianyu): handle duplicate group information; for now, keep the first entry per item.
    item_groups = item_groups.groupby('item_id').first().reset_index()
    # filter out items that were never purchased.
    mask = item_groups['item_id'].isin(item_encoder.classes_)
    item_groups = item_groups[mask].reset_index(drop=True)
    item_groups = item_groups.sort_values(by='item_id')

    category_encoder = LabelEncoder().fit(item_groups['category_id'])
    configs.num_categories = len(category_encoder.classes_)

    # encode item and category IDs to consecutive integers.
    item_groups['item_id'] = item_encoder.transform(item_groups['item_id'].values)
    item_groups['category_id'] = category_encoder.transform(item_groups['category_id'].values)

    print('Category sizes:')
    print(item_groups.groupby('category_id').size().describe())
    item_groups = item_groups.groupby('category_id')['item_id'].apply(list)
    category_to_item = dict(zip(item_groups.index, item_groups.values))
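
    # For intuition, category_to_item maps each encoded category to the list of
    # encoded items it contains, e.g. {0: [0, 4, 7], 1: [1, 2], 2: [3, 5, 6]}.
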
    # ==============================================================================================
    # pytorch-lightning training
    # ==============================================================================================
    bemb = LitBEMBFlex(
        # training args.
        pred_item=configs.pred_item,
        learning_rate=configs.learning_rate,
        num_seeds=configs.num_mc_seeds,
        # model args, passed through to the BEMB constructor.
        utility_formula=configs.utility,
        num_users=configs.num_users,
        num_items=configs.num_items,
        num_sessions=configs.num_sessions,
        obs2prior_dict=configs.obs2prior_dict,
        coef_dim_dict=configs.coef_dim_dict,
        trace_log_q=configs.trace_log_q,
        category_to_item=category_to_item,
        num_user_obs=configs.num_user_obs,
        num_item_obs=configs.num_item_obs,
        # num_price_obs=configs.num_price_obs,
        # additional_modules=[ExampleCustomizedModule()],
    )

    bemb = bemb.to(configs.device)
    # run() trains the model with pytorch-lightning and returns the fitted model.
    bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs)

    # ==============================================================================================
    # inference example
    # ==============================================================================================
    with torch.no_grad():
        # disable gradient tracking to save memory and compute.
        # utility of the item actually chosen in each row of the test set.
        utility_chosen = bemb.model(dataset_list[2], return_type='utility', return_scope='item_index')
        # computing utilities of all items uses much more memory:
        # utility_all = bemb.model(dataset_list[2], return_type='utility', return_scope='all_items')
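        # To get log-likelihoods of the chosen items instead of raw utilities,
        # the same forward call should accept return_type='log_prob' (a sketch
        # mirroring the call above; check the bemb docs to confirm):
        # log_prob_chosen = bemb.model(dataset_list[2], return_type='log_prob', return_scope='item_index')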
