Skip to content

Commit

Permalink
fix naming & code style & bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
xzyaoi committed Sep 9, 2018
1 parent 84c848e commit 5644116
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 20 deletions.
Empty file added cvpm/bootstrap.py
Empty file.
26 changes: 26 additions & 0 deletions cvpm/bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json


def is_jsonable(x):
try:
json.dumps(x)
return True
except:
return False


class Bundle(object):
__SOLVERS__ = []

def add_solver(self, solver):
self.__SOLVERS__.append(solver)

@classmethod
def members(self):
results = {}
for attr in dir(self):
if not attr.startswith("__"):
if is_jsonable(getattr(self, attr)):
results[attr] = getattr(self, attr)
print(results)
return results
30 changes: 18 additions & 12 deletions cvpm/Server.py → cvpm/server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import json
import logging
import os
import socket
import traceback
import logging

import gevent.pywsgi
from flask import g
from flask import Flask
from flask import request
from flask import Flask, g, request
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.utils import secure_filename

logger = logging.getLogger()
Expand All @@ -19,6 +19,7 @@

server.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER


def _isPortOpen(port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = s.connect_ex(('127.0.0.1', port))
Expand All @@ -38,10 +39,11 @@ def get_available_port(start=8080):
break
return port


def allowed_file(filename, phase):
ALLOWED_EXTENSIONS = ALLOWED_EXTENSIONS_TRAIN
if phase == 'infer':
ALLOWED_EXTENSIONS = ALLOWED_EXTENSIONS_INFER
ALLOWED_EXTENSIONS = ALLOWED_EXTENSIONS_INFER
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

Expand All @@ -56,23 +58,27 @@ def help():
def infer():
if request.method == 'POST':
results = {}
config = request.json
print(request)
config = ImmutableMultiDict(request.form)
if 'file' not in request.files:
return json.dumps({"error": "no file part!", "code": "400"}), 400
file = request.files['file']
if file and allowed_file(file.filename, 'infer'):
filename = secure_filename(file.filename)
file_abs_path = os.path.join(server.config['UPLOAD_FOLDER'], filename)
file_abs_path = os.path.join(server.config['UPLOAD_FOLDER'],
filename)
file.save(file_abs_path)
try:
results = server.solver.infer(file_abs_path, config)
results = server.solver.infer(file_abs_path, config.to_dict())
return json.dumps(results), 200
except Exception as e:
traceback.print_exc()
return json.dumps({"error": str(e), "code":"500"}), 500
return json.dumps({"error": str(e), "code": "500"}), 500
else:
return json.dumps({"error": "Forbidden Filename!", "code": "400"}), 400
return json.dumps({
"error": "Forbidden Filename!",
"code": "400"
}), 400


@server.route("/train", methods=["GET", "POST"])
def train():
Expand Down
20 changes: 15 additions & 5 deletions cvpm/Solver.py → cvpm/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import toml
import tqdm

from cvpm.Server import run_server
from cvpm.Utility import BundleAnalyzer, Downloader
from cvpm.bundle import Bundle
from cvpm.server import run_server
from cvpm.utility import BundleAnalyzer, Downloader


class Solver(object):
def __init__(self, toml_file=None):
self._isReady = False
self.bundle = {}
self.bundle = None
self._enable_train = False
if toml_file is None:
toml_file = "./pretrained/pretrained.toml"
Expand All @@ -27,8 +28,11 @@ def is_ready(self):

@property
def help_message(self):
ba = BundleAnalyzer(self.bundle)
return json.dumps(ba.load())
if self.is_ready:
members = self.bundle.members()
return json.dumps(members)
else:
return json.dumps({"error": "Initializing...", "code": "101"}), 101

def _prepare_models(self, toml_file):
parsed_toml = toml.load(toml_file)
Expand All @@ -39,6 +43,12 @@ def _prepare_models(self, toml_file):
def set_ready(self):
self._isReady = True

def set_bundle(self, bundle):
if issubclass(bundle, Bundle):
self.bundle = bundle
solver = self
bundle.add_solver(self=bundle, solver=solver)

def infer(self, input, config):
pass

Expand Down
4 changes: 1 addition & 3 deletions cvpm/Utility.py → cvpm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def __init__(self, bundle):
def load(self):
result = {}
for name, value in vars(self.bundle).items():
if name not in [
"__doc__", "__module__", "__dict__", "__weakref__"
]:
if not name.startsWith("__"):
result[name] = value
return result

0 comments on commit 5644116

Please sign in to comment.