fix code style manually
wangzhen38 committed Dec 4, 2021
1 parent 7859916 commit 353c5f7
Showing 3 changed files with 152 additions and 82 deletions.
164 changes: 104 additions & 60 deletions models/rank/bert4rec/data_augment_candi_gen.py
@@ -55,8 +55,7 @@ class FreqVocab(object):
"""Runs end-to-end tokenziation."""

def __init__(self, user_to_list):
self.counter = Counter(
)
self.counter = Counter()
self.user_set = set()
for u, item_list in user_to_list.items():
self.counter.update(item_list)
@@ -103,12 +102,14 @@ def get_special_token(self):
return self.special_tokens

def get_vocab_size(self):
return self.get_item_count() + self.get_special_token_count() + 1 #self.get_user_count()
return self.get_item_count() + self.get_special_token_count(
) + 1 #self.get_user_count()


random_seed = 12345
short_seq_prob = 0 # Probability of creating sequences which are shorter than the maximum length。


def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""

@@ -192,18 +193,22 @@ def create_training_instances(output_filenamem,
if len(item_seq) <= max_num_tokens:
all_documents[user] = [item_seq]
else:
beg_idx = list(range(len(item_seq) - max_num_tokens, 0, -sliding_step))
beg_idx = list(
range(len(item_seq) - max_num_tokens, 0, -sliding_step))
beg_idx.append(0)
all_documents[user] = [item_seq[i:i + max_num_tokens] for i in beg_idx[::-1]]
all_documents[user] = [
item_seq[i:i + max_num_tokens] for i in beg_idx[::-1]
]

instances = []
if force_last:
for user in all_documents:
instances.extend(
create_instances_from_document_test(
all_documents, user, max_seq_length))
create_instances_from_document_test(all_documents, user,
max_seq_length))
print("num of instance:{}".format(len(instances)))
write_sample_data(vocab, instances, max_seq_length, max_predictions_per_seq,
write_sample_data(vocab, instances, max_seq_length,
max_predictions_per_seq,
output_filenamem + "-test" + ".txt")
else:
start_time = time.clock()
@@ -212,54 +217,63 @@
print("document num: {}".format(len(all_documents)))

def log_result(result):
print("callback function result type: {}, size: {} ".format(type(result), len(result)))
print("callback function result type: {}, size: {} ".format(
type(result), len(result)))
# instances.extend(result)

for step in range(dupe_factor):
pool.apply_async(
create_instances_threading, args=(
all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, random.Random(random.randint(1, 10000)),
mask_prob, step, output_filenamem), callback=log_result)
create_instances_threading,
args=(all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab,
random.Random(random.randint(1, 10000)), mask_prob, step,
output_filenamem),
callback=log_result)
pool.close()
pool.join()

for user in all_documents:
instances.extend(
mask_last(
all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng, output_filenamem))
mask_last(all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng,
output_filenamem))

print("num of instance:{}; time:{}".format(len(instances), time.clock() - start_time))
print("num of instance:{}; time:{}".format(
len(instances), time.clock() - start_time))

rng.shuffle(instances)
return instances


def create_instances_threading(all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng,
mask_prob, step, output_filenamem):
def create_instances_threading(all_documents, user, max_seq_length,
short_seq_prob, masked_lm_prob,
max_predictions_per_seq, vocab, rng, mask_prob,
step, output_filenamem):
cnt = 0
start_time = time.clock()
instances = []
for user in all_documents:
cnt += 1
if cnt % 1000 == 0:
print("step: {}, name: {}, step: {}, time: {}".format(step, multiprocessing.current_process().name, cnt,
time.clock() - start_time))
print("step: {}, name: {}, step: {}, time: {}".format(
step,
multiprocessing.current_process().name, cnt,
time.clock() - start_time))
start_time = time.clock()
instances.extend(create_instances_from_document_train(
all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng,
mask_prob))
write_sample_data(vocab, instances, max_seq_length, max_predictions_per_seq,
instances.extend(
create_instances_from_document_train(
all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng,
mask_prob))
write_sample_data(vocab, instances, max_seq_length,
max_predictions_per_seq,
output_filenamem + "_train_" + str(step) + ".txt")
return instances


def mask_last(
all_documents, user, max_seq_length, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, vocab, rng, output_filename):
def mask_last(all_documents, user, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab, rng,
output_filename):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[user]
max_num_tokens = max_seq_length
@@ -279,7 +293,9 @@ def mask_last(
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
write_sample_data(vocab, instances, max_seq_length, max_predictions_per_seq, output_filename + "_train_" +str(args.dupe_factor) + ".txt")
write_sample_data(
vocab, instances, max_seq_length, max_predictions_per_seq,
output_filename + "_train_" + str(args.dupe_factor) + ".txt")
return instances


@@ -323,8 +339,8 @@ def create_instances_from_document_train(

(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq,
vocab_items, rng, mask_prob)
tokens, masked_lm_prob, max_predictions_per_seq, vocab_items, rng,
mask_prob)
instance = TrainingInstance(
info=info,
tokens=tokens,
@@ -427,16 +443,18 @@ def gen_samples(data,
pool_size,
force_last=False):
# create train
instances = create_training_instances(output_filename,
data, max_seq_length, dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng, vocab, mask_prob, prop_sliding_window,
pool_size, force_last)
instances = create_training_instances(
output_filename, data, max_seq_length, dupe_factor, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, rng, vocab, mask_prob,
prop_sliding_window, pool_size, force_last)

print("*** Writing to output files ***")
print(" %s", output_filename)


def write_sample_data(vocab, instances, max_seq_length, max_predictions_per_seq, output_file):
def write_sample_data(vocab, instances, max_seq_length,
max_predictions_per_seq, output_file):

fw = open(output_file, "w")
for (inst_index, instance) in enumerate(instances):
try:
@@ -463,8 +481,10 @@ def write_sample_data(vocab, instances, max_seq_length, max_predictions_per_seq,
fw.writelines(str(input_ids).replace('[', ' ').replace(']', '') + ";")
fw.writelines(str(input_mask).replace('[', ' ').replace(']', '') + ";")
fw.writelines(str(input_pos).replace('[', ' ').replace(']', '') + ";")
fw.writelines(str(masked_lm_positions).replace('[', ' ').replace(']', '') + ";")
fw.writelines(str(masked_lm_ids).replace('[', ' ').replace(']', '') + "\n")
fw.writelines(
str(masked_lm_positions).replace('[', ' ').replace(']', '') + ";")
fw.writelines(
str(masked_lm_ids).replace('[', ' ').replace(']', '') + "\n")

fw.close()

@@ -502,9 +522,9 @@ def main(args):
print('max:{}, min:{}'.format(max_len, min_len))

print('len_train:{}, len_valid:{}, len_test:{}, usernum:{}, itemnum:{}'.
format(
len(user_train),
len(user_valid), len(user_test), usernum, itemnum))
format(
len(user_train),
len(user_valid), len(user_test), usernum, itemnum))

for idx, u in enumerate(user_train):
if idx <= 1:
@@ -524,7 +544,7 @@ def main(args):
}
user_test_data = {
'user_' + str(u):
['item_' + str(item) for item in (user_train[u] + user_test[u])]
['item_' + str(item) for item in (user_train[u] + user_test[u])]
for u in user_train if len(user_train[u]) > 0 and len(user_test[u]) > 0
}
rng = random.Random(random_seed)
@@ -572,14 +592,18 @@ def main(args):
print('test:{}'.format(output_filename))
# notice that the final train data contain the 10 cloze forms of data and one mask_last form of data
# while the form of mask_last aims to narrow the gap of training and test
train_total = open(output_dir + 'train/' + dataset_name + "-train.txt", 'a')
train_total = open(output_dir + 'train/' + dataset_name + "-train.txt",
'a')
for i in range(dupe_factor):
f = open(output_dir + 'train/' + dataset_name + "_train_" + str(i) + ".txt")
f = open(output_dir + 'train/' + dataset_name + "_train_" + str(i) +
".txt")
buf = f.read()
f.close()
os.remove(output_dir + 'train/' + dataset_name + "_train_" + str(i) + ".txt")
os.remove(output_dir + 'train/' + dataset_name + "_train_" + str(i) +
".txt")
train_total.write(buf)
os.remove(output_dir + 'train/' + dataset_name + "_train_" + str(dupe_factor) + ".txt")
os.remove(output_dir + 'train/' + dataset_name + "_train_" + str(
dupe_factor) + ".txt")
train_total.close()

print('vocab_size:{}, user_size:{}, item_size:{}, item_with_other_size:{}'.
@@ -638,8 +662,12 @@ def main(args):
item_idx = [labels[idx][0]]
if vocab is not None:
while len(item_idx) < 101:
sampled_ids = np.random.choice(ids, 101, replace=False, p=probability)
sampled_ids = [x for x in sampled_ids if x not in rated and x not in item_idx]
sampled_ids = np.random.choice(
ids, 101, replace=False, p=probability)
sampled_ids = [
x for x in sampled_ids
if x not in rated and x not in item_idx
]
item_idx.extend(sampled_ids[:])
item_idx = item_idx[:101]
candidates.append(item_idx)
@@ -654,18 +682,34 @@

if __name__ == "__main__":
# Commandline arguments
parser = argparse.ArgumentParser(description="Parameter of data augmentation and paths")
parser = argparse.ArgumentParser(
description="Parameter of data augmentation and paths")
parser.add_argument('-pool_size', dest='pool_size', type=int, default=10)
parser.add_argument('-max_seq_length', dest='max_seq_length', type=int, default=50)
parser.add_argument('-max_predictions_per_seq', dest='max_predictions_per_seq', type=int, default=30)
parser.add_argument('-dupe_factor', dest='dupe_factor', type=int, default=10)
parser.add_argument('-masked_lm_prob', dest='masked_lm_prob', type=float, default=0.6)
parser.add_argument('-mask_prob', dest='mask_prob', type=float, default=1.0)
parser.add_argument('-prop_sliding_window', dest='prop_sliding_window', type=float, default=0.1)
parser.add_argument(
'-max_seq_length', dest='max_seq_length', type=int, default=50)
parser.add_argument(
'-max_predictions_per_seq',
dest='max_predictions_per_seq',
type=int,
default=30)
parser.add_argument(
'-dupe_factor', dest='dupe_factor', type=int, default=10)
parser.add_argument(
'-masked_lm_prob', dest='masked_lm_prob', type=float, default=0.6)
parser.add_argument(
'-mask_prob', dest='mask_prob', type=float, default=1.0)
parser.add_argument(
'-prop_sliding_window',
dest='prop_sliding_window',
type=float,
default=0.1)
parser.add_argument('-dataset_name', dest='dataset_name', default='beauty')
parser.add_argument('-test_set_dir', dest='test_set_dir', default="data/test/beauty-test.txt")
parser.add_argument('-vocab_path', dest='vocab_path', default="data/beauty.vocab")
parser.add_argument(
'-test_set_dir',
dest='test_set_dir',
default="data/test/beauty-test.txt")
parser.add_argument(
'-vocab_path', dest='vocab_path', default="data/beauty.vocab")
parser.add_argument('-data_dir', dest='data_dir', default='data/')
args = parser.parse_args()
main(args)

16 changes: 11 additions & 5 deletions models/rank/bert4rec/data_reader.py
@@ -24,7 +24,8 @@ def __init__(self, data_path, config):
self.data_dir = data_path
self.config = config
self.batch_size = self.config.get("runner.data_batch_size")
self.max_len = self.config.get("hyper_parameters._max_position_seq_len")
self.max_len = self.config.get(
"hyper_parameters._max_position_seq_len")

def __iter__(self):
cnt = 0
@@ -51,7 +52,10 @@ def __iter__(self):
tmp_mask = split_samples[2].split(',')
input_mask.append([[int(x)] for x in tmp_mask])
tmp_mask_pos = split_samples[4].split(',')
mask_pos = mask_pos + [[int(x) + (sample_count % self.batch_size) * self.max_len] for x in tmp_mask_pos]
mask_pos = mask_pos + [[
int(x) +
(sample_count % self.batch_size) * self.max_len
] for x in tmp_mask_pos]
tmp_label = split_samples[5].split(',')
mask_label = mask_label + [[int(x)] for x in tmp_label]
sample_count += 1
@@ -99,8 +103,10 @@ def __iter__(self):
tmp_mask = split_samples[2].split(',')
input_mask.append([[int(x)] for x in tmp_mask])
tmp_mask_pos = split_samples[4].split(',')
mask_pos = mask_pos + [[int(x) + (sample_count % self.batch_size) * self.max_len] for x in
tmp_mask_pos]
mask_pos = mask_pos + [[
int(x) +
(sample_count % self.batch_size) * self.max_len
] for x in tmp_mask_pos]
tmp_label = split_samples[5].split(',')
mask_label = mask_label + [[int(x)] for x in tmp_label]

@@ -123,4 +129,4 @@ def __iter__(self):
output_list.append(candidate_)

cand_list = cand_list[self.batch_size:]
yield output_list
yield output_list