Add interactive mode #52

Merged
8 commits, merged Aug 5, 2018
18 changes: 16 additions & 2 deletions README.md
@@ -56,7 +56,23 @@ Kubernetes by Google’s Bern
```

You can also train a new model, with support for word level embeddings and bidirectional RNN layers by adding `new_model=True` to any train function.
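
For instance (a minimal sketch; `train_from_file`, `word_level`, `rnn_bidirectional`, and `num_epochs` are assumed keyword arguments from textgenrnn's config options and are not part of this diff):

```python
from textgenrnn import textgenrnn

textgen = textgenrnn(name='my_word_model')

# new_model=True trains a fresh model from scratch; word_level and
# rnn_bidirectional (assumed config options) enable word-level embeddings
# and bidirectional RNN layers.
textgen.train_from_file('my_corpus.txt', new_model=True,
                        word_level=True, rnn_bidirectional=True,
                        num_epochs=10)
```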

### Interactive mode
You can also take part in how the output unfolds, step by step. Interactive mode suggests the *top N* options for the next char/word and lets you pick one.

Just pass `interactive=True` and `top_n=N`. `top_n` defaults to 3.

```python
from textgenrnn import textgenrnn

textgen = textgenrnn()
textgen.generate(interactive=True, top_n=5)
```
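
When run, interactive mode prints a prompt along these lines (word-level model; the candidate words shown here are purely illustrative):

```
Controls:
        s: stop.        x: backspace.   o: write your own.

Options:
        1: the
        2: of
        3: and
        4: to
        5: a

Progress: What is

Your choice?
>
```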

![word_level_demo](/docs/word_level_demo.gif)

This can add a *human touch* to the output; it feels like you're the writer! ([reference](https://fivethirtyeight.com/features/some-like-it-bot/))

## Usage

textgenrnn can be installed [from pypi](https://pypi.python.org/pypi/textgenrnn) via `pip`:
@@ -107,8 +123,6 @@ Additionally, the retraining is done with a momentum-based optimizer and a linea

* A way to visualize the attention-layer outputs to see how the network "learns."

* Supervised text generation mode: allow the model to present the top *n* options and user select the next char/word ([reference](https://fivethirtyeight.com/features/some-like-it-bot/))

* A mode to allow the model architecture to be used for chatbot conversations (may be released as a separate project)

* More depth toward context (positional context + allowing multiple context labels)
Binary file added docs/word_level_demo.gif
7 changes: 5 additions & 2 deletions textgenrnn/textgenrnn.py
@@ -66,7 +66,8 @@ def __init__(self, weights_path=None,
self.indices_char = dict((self.vocab[c], c) for c in self.vocab)

def generate(self, n=1, return_as_list=False, prefix=None,
temperature=0.5, max_gen_length=300):
temperature=0.5, max_gen_length=300, interactive=False,
top_n=3):
gen_texts = []
for _ in range(n):
gen_text = textgenrnn_generate(self.model,
@@ -79,7 +80,9 @@ def generate(self, n=1, return_as_list=False, prefix=None,
self.config['word_level'],
self.config.get(
'single_text', False),
max_gen_length)
max_gen_length,
interactive,
top_n)
if not return_as_list:
print("{}\n".format(gen_text))
gen_texts.append(gen_text)
78 changes: 63 additions & 15 deletions textgenrnn/utils.py
@@ -11,7 +11,7 @@
import re


def textgenrnn_sample(preds, temperature):
def textgenrnn_sample(preds, temperature, interactive=False, top_n=3):
'''
Samples predicted probabilities of the next character to allow
for the network to show "creativity."
@@ -26,12 +26,18 @@ def textgenrnn_sample(preds, temperature):
exp_preds = np.exp(preds)
preds = exp_preds / np.sum(exp_preds)
probas = np.random.multinomial(1, preds, 1)
index = np.argmax(probas)

# prevent function from being able to choose 0 (placeholder)
# choose 2nd best index from preds
if index == 0:
index = np.argsort(preds)[-2]

if not interactive:
index = np.argmax(probas)

# prevent function from being able to choose 0 (placeholder)
# choose 2nd best index from preds
if index == 0:
index = np.argsort(preds)[-2]
else:
# return list of top N chars/words
# descending order, based on probability
index = (-preds).argsort()[:top_n]

return index
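
As a quick illustration of the interactive branch above (a standalone sketch with made-up probabilities):

```python
import numpy as np

# Made-up next-token probabilities for a four-token vocabulary.
preds = np.array([0.05, 0.40, 0.30, 0.25])

# Interactive branch: indices of the top_n most probable tokens, best first.
print((-preds).argsort()[:3])  # -> [1 2 3]
```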

@@ -41,11 +47,15 @@ def textgenrnn_generate(model, vocab,
maxlen=40, meta_token='<s>',
word_level=False,
single_text=False,
max_gen_length=300):
max_gen_length=300,
interactive=False,
top_n=3):
'''
Generates and returns a single text.
'''

collapse_char = ' ' if word_level else ''

# If generating word level, must add spaces around each punctuation.
# https://stackoverflow.com/a/3645946/9314418
if word_level and prefix:
@@ -72,15 +82,53 @@

while next_char != meta_token and len(text) < max_gen_length:
encoded_text = textgenrnn_encode_sequence(text[-maxlen:],
vocab, maxlen)
vocab, maxlen)
next_temperature = temperature[(len(text) - 1) % len(temperature)]
next_index = textgenrnn_sample(
model.predict(encoded_text, batch_size=1)[0],
next_temperature)
next_char = indices_char[next_index]
text += [next_char]

collapse_char = ' ' if word_level else ''
if not interactive:
# auto-generate text without user intervention
next_index = textgenrnn_sample(
model.predict(encoded_text, batch_size=1)[0],
next_temperature)
next_char = indices_char[next_index]
text += [next_char]
else:
# ask user what the next char/word should be
options_index = textgenrnn_sample(
model.predict(encoded_text, batch_size=1)[0],
next_temperature,
interactive=interactive,
top_n=top_n
)
options = [indices_char[idx] for idx in options_index]
print('Controls:\n\ts: stop.\tx: backspace.\to: write your own.')
print('\nOptions:')

for i, option in enumerate(options, 1):
print('\t{}: {}'.format(i, option))

print('\nProgress: {}'.format(collapse_char.join(text)[3:]))
print('\nYour choice?')
user_input = input('> ')

try:
user_input = int(user_input)
next_char = options[user_input-1]
text += [next_char]
except ValueError:
if user_input == 's':
next_char = '<s>'
text += [next_char]
elif user_input == 'o':
other = input('> ')
text += [other]
elif user_input == 'x':
try:
del text[-1]
except IndexError:
pass
else:
print('That\'s not an option!')

# if single text, ignore sequences generated w/ padding
# if not single text, strip the <s> meta_tokens