-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
92 lines (71 loc) · 2.92 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from flask import Flask, abort, jsonify, request
import json
from inference import InferenceWrapper
from aethel.mill.serialization import serial_proof_to_json, serialize_proof
from aethel.utils.tex import sample_to_tex
import logging
# Configure root logging at import time (only takes effect if the root logger
# has no handlers yet); DEBUG so all messages below are emitted.
logging.basicConfig(level="DEBUG")
# Named logger used by the app factory, its routes and the script entry point.
log = logging.getLogger(name="spindle")
def create_app():
    """Build the Spindle Flask application.

    The inference model is loaded once, on CPU, from files under ``./data``
    (paths are relative to the current working directory) and is shared by
    every request handled by the returned app.

    Routes:
        GET  /status/ -- liveness check; always responds ``{"ok": true}``.
        POST /        -- expects a JSON object body with a string ``input``
                         field; responds with ``tex``, ``proof`` and
                         ``lexical_phrases`` built from the first analysis
                         result.

    Returns:
        The configured ``Flask`` application.
    """
    log.info('Loading model')
    inferer = InferenceWrapper(
        weight_path="./data/model_weights.pt",
        atom_map_path="./data/atom_map.tsv",
        config_path="./data/bert_config.json",
        device="cpu",  # change to a GPU device (e.g. 'cuda') for GPU acceleration
    )
    log.info('Loaded model')
    app = Flask(__name__)

    @app.route("/status/", methods=["GET"])
    def status():
        """Report that the server is up."""
        return jsonify(dict(ok=True))

    @app.route("/", methods=["POST"])
    def handle_request():
        """Analyze the request's ``input`` string and return the result as JSON.

        Aborts with 400 when the body is not UTF-8 JSON, is not a JSON object,
        lacks an ``input`` field, or ``input`` is not a string. Aborts with 500
        when the model returns no results, stores an exception as the proof,
        or the result cannot be converted to TeX.
        """
        log.info("Request received!")
        try:
            request_body_json = json.loads(request.data.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Undecodable bytes are as much a client error as malformed JSON.
            log.error("Failed to parse request body as JSON.")
            abort(400)
        # Valid JSON may still be a list/string/number; indexing those with
        # "input" would raise TypeError and surface as a 500.
        if not isinstance(request_body_json, dict):
            log.error("Request body is not a JSON object.")
            abort(400)
        if "input" not in request_body_json:
            log.error("Request body does not contain 'input' field.")
            abort(400)
        spindle_input = request_body_json["input"]
        if not isinstance(spindle_input, str):
            log.error("Input is not a string.")
            abort(400)

        # Lazy %-formatting: the input must go through a placeholder, otherwise
        # logging fails to format the record and the input is never logged.
        log.info("Starting analysis with input: %s", spindle_input)
        results = inferer.analyze([spindle_input])
        log.info("Analysis complete!")
        log.info("Results: %s", results)
        if len(results) < 1:
            log.error("Got no results")
            abort(500)
        analysis = results[0]
        # spindle will store an exception value in the proof variable, at least in some failure modes
        if isinstance(analysis.proof, Exception):
            log.error("Error in analysis", exc_info=analysis.proof)
            abort(500)
        try:
            tex_from_sample = sample_to_tex(analysis)
        except Exception:
            log.exception("Failed to convert result to TeX.")
            abort(500)
            return  # not necessary given abort, but helps type-checker understand that we leave the function here
        log.info("TeX conversion successful.")
        log.info("TeX: %s", tex_from_sample)

        # prepare json-ready version of proof and lexical phrases
        proof = serial_proof_to_json(serialize_proof(analysis.proof))
        lexical_phrases = [phrase.json() for phrase in analysis.lexical_phrases]
        response = dict(
            tex=tex_from_sample,
            proof=proof,
            lexical_phrases=lexical_phrases,
        )
        return jsonify(response)

    log.info('App is ready')
    return app
if __name__ == "__main__":
    # Script entry point: build the app (this loads the model) and serve it
    # with Flask's built-in server using its default host and port.
    log.info("Starting Spindle Server!")
    spindle_app = create_app()
    spindle_app.run()