forked from atpaino/deep-text-corrector
-
Notifications
You must be signed in to change notification settings - Fork 1
/
dtc_lambda.py
85 lines (63 loc) · 2.46 KB
/
dtc_lambda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import print_function
import json
import os
import pickle
import boto3
import tensorflow as tf
from correct_text import create_model, DefaultMovieDialogConfig, decode_sentence
from text_corrector_data_readers import MovieDialogReader
def safe_mkdir(path):
    """Create the directory `path`, tolerating the case where it already exists.

    Args:
        path: Directory to create. Parent directories are not created.

    Raises:
        OSError: If the directory could not be created and does not exist
            afterwards (e.g. permission denied, missing parent directory, or
            `path` names a regular file).
    """
    try:
        os.mkdir(path)
    except OSError:
        # Only "the directory is already there" is a benign failure; any other
        # error would otherwise surface later as a confusing write failure.
        if not os.path.isdir(path):
            raise
def download(client, filename, local_path=None, s3_path=None):
    """Copy one object from the BUCKET_NAME bucket to local disk.

    Args:
        client: Object exposing `download_file(bucket, key, destination)`;
            the module passes a boto3 S3 client.
        filename: Name of the file; used in the log line and to build the
            default remote key and local destination.
        local_path: Destination on local disk. Defaults to `filename` inside
            MODEL_PATH.
        s3_path: Key inside the bucket. Defaults to
            MODEL_PARAMS_DIR + "/" + filename.
    """
    remote_key = (
        s3_path if s3_path is not None else MODEL_PARAMS_DIR + "/" + filename
    )
    destination = (
        local_path if local_path is not None
        else os.path.join(MODEL_PATH, filename)
    )
    print("Downloading " + filename)
    client.download_file(BUCKET_NAME, remote_key, destination)
# Define resources on S3.
BUCKET_NAME = "deeptextcorrecter"
ROOT_DATA_PATH = "/tmp/"
MODEL_PARAMS_DIR = "model_params"
MODEL_PATH = os.path.join(ROOT_DATA_PATH, MODEL_PARAMS_DIR)

# Create tmp dirs for storing data locally.
safe_mkdir(ROOT_DATA_PATH)
safe_mkdir(MODEL_PATH)

# Download files from S3 to local disk. This runs at import time, so the
# files are fetched once per process rather than once per handled event.
s3_client = boto3.client('s3')

# The three checkpoint files go to MODEL_PATH (the default destination of
# download()), which is the directory later handed to create_model().
model_ckpt = "41900"
tf_meta_filename = "translate.ckpt-{}.meta".format(model_ckpt)
download(s3_client, tf_meta_filename)
tf_params_filename = "translate.ckpt-{}".format(model_ckpt)
download(s3_client, tf_params_filename)
tf_ckpt_filename = "checkpoint"
download(s3_client, tf_ckpt_filename)

# The pickled vocabulary files are stored directly under ROOT_DATA_PATH.
corrective_tokens_filename = "corrective_tokens.pickle"
corrective_tokens_path = os.path.join(ROOT_DATA_PATH,
                                      corrective_tokens_filename)
download(s3_client, corrective_tokens_filename,
         local_path=corrective_tokens_path)
token_to_id_filename = "token_to_id.pickle"
token_to_id_path = os.path.join(ROOT_DATA_PATH, token_to_id_filename)
download(s3_client, token_to_id_filename, local_path=token_to_id_path)

# Load model.
config = DefaultMovieDialogConfig()
sess = tf.Session()
print("Loading model")
# NOTE(review): the positional `True` is presumably an inference-only
# (forward_only) flag -- confirm against correct_text.create_model.
model = create_model(sess, True, MODEL_PATH, config=config)
print("Loaded model")

# NOTE: pickle.load can execute arbitrary code while deserializing. These
# files come from the BUCKET_NAME bucket, so that bucket must be trusted.
# Pickle data is binary, so the files are opened in "rb" mode; text mode
# fails on Python 3 and is unsafe for binary pickle protocols.
with open(corrective_tokens_path, "rb") as f:
    corrective_tokens = pickle.load(f)
with open(token_to_id_path, "rb") as f:
    token_to_id = pickle.load(f)
data_reader = MovieDialogReader(config, token_to_id=token_to_id)
print("Done initializing.")
def process_event(event, context):
    """Run the text corrector on `event["text"]` and return the result.

    Presumably the AWS Lambda entry point, given the (event, context)
    signature -- confirm against the deployment configuration.

    Args:
        event: JSON-serializable mapping that must contain a "text" key
            holding the sentence to correct.
        context: Unused.

    Returns:
        A dict with the original text under "input" and the decoded tokens
        joined by single spaces under "output".

    Raises:
        KeyError: If `event` has no "text" entry.
    """
    print("Received event: " + json.dumps(event, indent=2))
    text = event["text"]
    decoded_tokens = decode_sentence(
        sess, model, data_reader, text,
        corrective_tokens=corrective_tokens,
        verbose=False)
    corrected = " ".join(decoded_tokens)
    return {"input": text, "output": corrected}