forked from sherjilozair/char-rnn-tensorflow

barry.py — 130 lines (117 loc), 6.11 KB
import discord
import asyncio
import random
import re
import argparse
from six import text_type
import json
import os
import train as tr
import sample as sp
import importlib
import sys
import tensorflow as tf
import re
# load SECRET DATA from JSON file
# client_info.json supplies the bot token ('secret'), the id of the channel
# the bot posts to ('channel'), and the admin user's id ('admin').
with open('client_info.json') as config_file:
    _client_info = json.load(config_file)

client_secret = _client_info['secret']
client_channel = _client_info['channel']
admin_id = _client_info['admin']
# taken from sample.py and train.py to pass arguments to train files
#
# Both parsers read the same process argv. The original called parse_args()
# on each, so passing any flag known to only one parser (e.g. --data_dir)
# made the other parser abort the whole program with "unrecognized
# arguments". parse_known_args() keeps the same flags and defaults but
# tolerates the other parser's options (and any extra arguments).
sampleParser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
sampleParser.add_argument('--save_dir', type=str, default='save',
                          help='model directory to store checkpointed models')
sampleParser.add_argument('-n', type=int, default=500,
                          help='number of characters to sample')
sampleParser.add_argument('--sample', type=int, default=1,
                          help='0 to use max at each timestep, 1 to sample at '
                               'each timestep, 2 to sample on spaces')
# six.text_type is just str on Python 3, and this file can only run on
# Python 3 (it uses async/await), so use str directly.
sampleParser.add_argument('--prime', type=str, default='',
                          help='prime text')
sampleArgs, _ = sampleParser.parse_known_args()

trainParser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Data and model checkpoints directories
trainParser.add_argument('--data_dir', type=str, default='data',
                         help='data directory containing input.txt with training examples')
trainParser.add_argument('--save_dir', type=str, default='save',
                         help='directory to store checkpointed models')
trainParser.add_argument('--log_dir', type=str, default='logs',
                         help='directory to store tensorboard logs')
trainParser.add_argument('--save_every', type=int, default=1000,
                         help='Save frequency. Number of passes between checkpoints of the model.')
# Model params
trainParser.add_argument('--model', type=str, default='lstm',
                         help='lstm, rnn, gru, or nas')
trainParser.add_argument('--rnn_size', type=int, default=128,
                         help='size of RNN hidden state')
trainParser.add_argument('--num_layers', type=int, default=2,
                         help='number of layers in the RNN')
# Optimization
trainParser.add_argument('--seq_length', type=int, default=50,
                         help='RNN sequence length. Number of timesteps to unroll for.')
trainParser.add_argument('--batch_size', type=int, default=50,
                         help="""minibatch size. Number of sequences propagated through the network in parallel.
                         Pick batch-sizes to fully leverage the GPU (e.g. until the memory is filled up)
                         commonly in the range 10-500.""")
trainParser.add_argument('--num_epochs', type=int, default=50,
                         help='number of epochs. Number of full passes through the training examples.')
trainParser.add_argument('--grad_clip', type=float, default=5.,
                         help='clip gradients at this value')
trainParser.add_argument('--learning_rate', type=float, default=0.002,
                         help='learning rate')
trainParser.add_argument('--decay_rate', type=float, default=0.97,
                         help='decay rate for rmsprop')
trainParser.add_argument('--output_keep_prob', type=float, default=1.0,
                         help='probability of keeping weights in the hidden layer')
trainParser.add_argument('--input_keep_prob', type=float, default=1.0,
                         help='probability of keeping weights in the input layer')
# Resume from an existing checkpoint directory when one is present.
# NOTE(review): upstream char-rnn-tensorflow saves 'config.pkl' (letter l),
# but this checks 'config.pk1' (digit 1). If this fork kept the upstream
# filename, the check never fires and training never resumes — confirm
# against train.py before changing the string.
if os.path.isfile('save/config.pk1'):
    trainParser.add_argument('--init_from', type=str, default='save', help="")
else:
    trainParser.add_argument('--init_from', type=str, default=None, help="")
trainArgs, _ = trainParser.parse_known_args()
# Discord API client; the handlers below are registered on it via @client.event.
client = discord.Client()
# Module-level flag: True while a '!train' run has been started and not toggled off.
training = False
@client.event
async def on_ready():
    """Print identifying info once the bot has connected to Discord."""
    banner = (
        'Logged in as',
        client.user.name,
        client.user.id,
        '------',
    )
    for line in banner:
        print(line)
@client.event
async def on_message(message):
    """Handle chat commands and mention-triggered replies.

    Behaviors:
      * mention of the bot -> sample the char-rnn model primed with the
        message text and post the result to the configured channel
      * '!record' (admin only) -> dump the channel history to data/input.txt
      * '!train'  (admin only) -> toggle a training run on that data
      * '!leave'  (admin only) -> disconnect the bot
    """
    global training
    # Matches user mentions such as <@1234> and <@!1234>.
    mention_pattern = '(<@|<@!)([0-9])+>'
    if client.user.mentioned_in(message):
        # Prime the sampler with the message text, mentions stripped.
        sp.sample(sampleArgs, re.sub(mention_pattern, '', message.content))
        tf.reset_default_graph()
        with open('output/output.txt', 'r') as the_file:
            lines = the_file.read().split('\\r\\n')
        # The training data produced a lot of double-escaped unicode
        # (e.g. \\xf0...), so the sampled text has to be decoded twice.
        # Remove the re.sub() to allow barry to tag people.
        reply = (lines[1].encode('ascii').decode('unicode_escape')
                         .encode('ascii').decode('unicode_escape'))
        await client.send_message(
            discord.Object(id=client_channel),
            re.sub(mention_pattern, '', reply),
            tts=bool(random.getrandbits(1)))
    elif message.content.startswith('!record') and message.author.id == admin_id:
        print('Recording...')
        with open('data/input.txt', 'w') as the_file:
            # Effectively unbounded limit: pull the entire channel history.
            async for log in client.logs_from(message.channel, limit=1000000000000000):
                messageEncode = str(log.content.encode("utf-8"))[2:-1]
                template = '{message}\n'
                # The original wrapped this write in a bare `try/except` whose
                # handler performed the identical write, which added no error
                # handling — the redundant except was removed.
                the_file.write(template.format(message=messageEncode))
        print('Data Collected from ' + message.channel.name)
    elif message.content.startswith('!train') and message.author.id == admin_id:
        # The original spelled this `if training != True: ... elif training == True:`;
        # the elif is only reachable when the first test fails, i.e. plain else.
        if not training:
            # Status change doesn't work right now: tr.train() is synchronous,
            # so starting it blocks the event loop and takes the bot offline
            # for the duration of the run.
            await client.change_presence(game=None, status='with his brain', afk=False)
            training = True
            tr.train(trainArgs)
        else:
            # Second '!train' while a run is flagged: clear presence and reset.
            await client.change_presence(game=None, status=None, afk=False)
            training = False
    elif message.content.startswith('!leave') and message.author.id == admin_id:
        await client.disconnect()
client.run(client_secret, bot=True)