-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathanalysis.py
117 lines (96 loc) · 2.94 KB
/
analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# -*- coding: utf-8 -*-
# @Author: Aastha Gupta
# @Date: 2017-04-21 03:19:48
# @Last Modified by: Aastha Gupta
# @Last Modified time: 2017-05-19 15:51:13
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, TimeDistributed, Activation
import config
import os
import glob
import pickle
import string
def generate(seed_char, length=None):
    """Greedily generate text starting from a single seed character.

    Starting from a one-hot encoding of *seed_char*, the growing prefix
    ``X[:, :i + 1, :]`` is fed to the module-level ``model`` and the
    arg-max character predicted for the last time step is appended, once
    per step.

    Args:
        seed_char: first character of the text; must be a key of the
            module-level ``char_to_int`` mapping.
        length: number of characters to generate. Defaults to
            ``config.LEN_TO_GEN_2``.

    Returns:
        ``(text, n)``: ``text`` is *seed_char* followed by the generated
        characters, and ``n`` is the number of characters generated.

    NOTE(review): the model built in this file declares
    ``input_shape=(config.SEQ_LENGTH, config.VOCAB_SIZE)``. Feeding it
    prefixes of length ``i + 1`` may be rejected by Keras unless the model
    accepts variable-length input. Confirm before relying on this helper.
    """
    if length is None:
        length = config.LEN_TO_GEN_2
    text = seed_char
    seed = char_to_int[seed_char]
    X = np.zeros((1, length, config.VOCAB_SIZE))
    for i in range(length):
        X[0, i, seed] = 1.
        Y = model.predict(X[:, :i + 1, :])[0]
        # Y is (time, vocab) when the final layer returns sequences and
        # (vocab,) otherwise. In both cases, use the last step's prediction.
        seed = int(np.argmax(np.atleast_2d(Y)[-1]))
        text += int_to_char[seed]
    return text, length
# Load the lookup tables used to one-hot encode characters (char_to_int)
# and to decode predicted indices back to characters (int_to_char).
def _load_pickle_from_data_dir(name):
    """Unpickle and return the object stored in ``config.PATH/<name>``.

    ``pickle.load`` can execute arbitrary code, so only point
    ``config.PATH`` at files you trust.
    """
    path = os.path.join(config.PATH, name)
    with open(path, "rb") as handle:
        return pickle.load(handle)


char_to_int = _load_pickle_from_data_dir("dict.pkl")
int_to_char = _load_pickle_from_data_dir("rev_dict.pkl")
# Build the stacked-LSTM network. The weights are restored from checkpoint
# files later with model.load_weights, so this layer stack must match the
# one the checkpoints were saved from.
model = Sequential()
# First LSTM layer: consumes the one-hot (SEQ_LENGTH, VOCAB_SIZE) window and
# always emits the full sequence so further LSTMs can be stacked on it.
model.add(LSTM(config.HIDDEN_DIM,
               input_shape=(config.SEQ_LENGTH, config.VOCAB_SIZE),
               return_sequences=True))
model.add(Dropout(0.2))
# Remaining LSTM layers: all but the last emit sequences. The last emits
# only its final state, which feeds the softmax classifier.
stacked_layers = config.LAYER_NUM - 1
for layer_idx in range(stacked_layers):
    is_last = layer_idx == stacked_layers - 1
    model.add(LSTM(config.HIDDEN_DIM, return_sequences=not is_last))
# Probability distribution over the vocabulary for the next character.
model.add(Dense(config.VOCAB_SIZE, activation="softmax"))
# Generate text from the same seed with every saved checkpoint. Each
# checkpoint's name goes to the analysis file, and its generated lyrics go to
# the output file.
seed = "life"

# glob returns files in arbitrary order; sort so checkpoints are processed
# (and logged) deterministically.
checkpoint_paths = sorted(glob.glob(os.path.join(config.CHKPT_PATH, "*.hdf5")))

# Compile once: only the weights change between checkpoints. Keras 2.0's
# Sequential.compile requires an optimizer argument; it has no effect on
# predict().
model.compile(optimizer="rmsprop", loss="categorical_crossentropy",
              metrics=["accuracy"])


def _generate_text(seed_text, length):
    """Return *seed_text* extended greedily by *length* characters.

    The model is fed a sliding window of ``config.SEQ_LENGTH`` one-hot
    encoded characters. *seed_text* is left-padded with spaces, and cut down
    to its last ``SEQ_LENGTH`` characters, to fill the initial window.
    """
    window_text = seed_text.rjust(config.SEQ_LENGTH)[-config.SEQ_LENGTH:]
    window = [char_to_int[c] for c in window_text]
    positions = np.arange(config.SEQ_LENGTH)
    generated = []
    for _ in range(length):
        input_sequence = np.zeros((1, config.SEQ_LENGTH, config.VOCAB_SIZE))
        input_sequence[0, positions, window] = 1.
        Y = model.predict(input_sequence)
        # Distribution for the last predicted step, whether the model
        # outputs (1, vocab) or (1, time, vocab).
        next_int = int(np.argmax(Y.reshape(-1, Y.shape[-1])[-1]))
        generated.append(int_to_char[next_int])
        # Slide the window: drop the oldest character, append the newest.
        window.pop(0)
        window.append(next_int)
    return seed_text + "".join(generated)


output_filename = config.OUTPUT_FILE
# Separate handle names: the output-file handle must not replace the
# analysis-file handle, which is written on every iteration.
with open(config.ANALYSIS_FILE, "w") as analysis_file:
    for checkpoint_path in checkpoint_paths:
        checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
        print(checkpoint_name)
        analysis_file.write(checkpoint_name + "\n")
        model.load_weights(checkpoint_path)
        # Start from the bare seed for every checkpoint so that generated
        # texts do not accumulate across checkpoints.
        output = _generate_text(seed, config.LEN_TO_GEN)
        # Capitalize the first letter of every line (lowercasing the rest).
        output = "\n".join(s.capitalize() for s in output.split("\n"))
        print(output)
        # Save generated lyrics in the output file.
        with open(output_filename, "w") as output_file:
            output_file.write(output)
        print("Saved in {} file! :)".format(output_filename))
print("Saved in {} file!".format(config.ANALYSIS_FILE))