Skip to content

Commit

Permalink
fix bokeh server port collision issue
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored May 2, 2023
1 parent 8da9304 commit d11f625
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 235 deletions.
54 changes: 30 additions & 24 deletions Docs/torch_code_examples/visualization_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,35 +55,41 @@ def model_compression_with_visualization(eval_func):
"""
Code example for compressing a model with a visualization url provided.
"""
visualization_url, process = start_bokeh_server_session(8002)
process = None
try:
visualization_url, process = start_bokeh_server_session()

input_shape = (1, 3, 224, 224)
model = models.resnet18(pretrained=True).to(torch.device('cuda'))
input_shape = (1, 3, 224, 224)
model = models.resnet18(pretrained=True).to(torch.device('cuda'))

modules_to_ignore = [model.conv1]
modules_to_ignore = [model.conv1]

greedy_params = aimet_common.defs.GreedySelectionParameters(target_comp_ratio=Decimal(0.65),
num_comp_ratio_candidates=10,
saved_eval_scores_dict=
'../data/resnet18_eval_scores.pkl')
greedy_params = aimet_common.defs.GreedySelectionParameters(target_comp_ratio=Decimal(0.65),
num_comp_ratio_candidates=10,
saved_eval_scores_dict=
'../data/resnet18_eval_scores.pkl')

auto_params = aimet_torch.defs.SpatialSvdParameters.AutoModeParams(greedy_params,
modules_to_ignore=modules_to_ignore)
auto_params = aimet_torch.defs.SpatialSvdParameters.AutoModeParams(greedy_params,
modules_to_ignore=modules_to_ignore)

params = aimet_torch.defs.SpatialSvdParameters(aimet_torch.defs.SpatialSvdParameters.Mode.auto, auto_params,
multiplicity=8)
params = aimet_torch.defs.SpatialSvdParameters(aimet_torch.defs.SpatialSvdParameters.Mode.auto, auto_params,
multiplicity=8)

# If no visualization URL is provided, during model compression execution no visualizations will be published.
ModelCompressor.compress_model(model=model, eval_callback=eval_func, eval_iterations=5,
input_shape=input_shape,
compress_scheme=aimet_common.defs.CompressionScheme.spatial_svd,
cost_metric=aimet_common.defs.CostMetric.mac, parameters=params,
visualization_url=None)
# If no visualization URL is provided, during model compression execution no visualizations will be published.
ModelCompressor.compress_model(model=model, eval_callback=eval_func, eval_iterations=5,
input_shape=input_shape,
compress_scheme=aimet_common.defs.CompressionScheme.spatial_svd,
cost_metric=aimet_common.defs.CostMetric.mac, parameters=params,
visualization_url=None)

comp_ratios_file_path = './data/greedy_selection_comp_ratios_list.pkl'
eval_scores_path = '../data/resnet18_eval_scores.pkl'
comp_ratios_file_path = './data/greedy_selection_comp_ratios_list.pkl'
eval_scores_path = '../data/resnet18_eval_scores.pkl'

# A user can visualize the eval scores dictionary and optimal compression ratios by executing the following code.
compression_visualizations = VisualizeCompression(visualization_url)
compression_visualizations.display_eval_scores(eval_scores_path)
compression_visualizations.display_comp_ratio_plot(comp_ratios_file_path)
# A user can visualize the eval scores dictionary and optimal compression ratios by executing the following code.
compression_visualizations = VisualizeCompression(visualization_url)
compression_visualizations.display_eval_scores(eval_scores_path)
compression_visualizations.display_comp_ratio_plot(comp_ratios_file_path)
finally:
if process:
process.terminate()
process.join()
54 changes: 40 additions & 14 deletions TrainingExtensions/common/src/python/aimet_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import math
import os
import signal
import socket
import subprocess
import threading
import time
Expand Down Expand Up @@ -252,23 +251,50 @@ def kill_process_with_name_and_port_number(name: str, port_number: int):
break


def start_bokeh_server_session(port: int):
def start_bokeh_server_session(port: int = None):
"""
start a bokeh server programmatically. Used for testing purposes.
:param port: port number
:param port: Port number. If not specified, bokeh server will listen on an arbitrary free port.
:return: Returns the Bokeh Server URL and the process object used to create the child server process
"""

host_name = socket.gethostname()
bokeh_serve_command = "bokeh serve --allow-websocket-origin=" + \
host_name + ":" + str(port) + " --allow-websocket-origin=localhost:" + str(port) + " --port=" + str(port) + " &"
process = subprocess.Popen(bokeh_serve_command, # pylint: disable=subprocess-popen-preexec-fn
shell=True,
preexec_fn=os.setsid)
url = "http://" + host_name + ":" + str(port)
# Doesn't allow document to be added to server unless there is some sort of wait time.
time.sleep(4)
return url, process
from bokeh.server.server import Server
from bokeh.application import Application
import multiprocessing
manager = multiprocessing.Manager()
d = manager.dict()
server_started = manager.Event()

def start_bokeh_server(port: int = None):
os.setsid()

# If port is 0, server automatically finds and listens on an arbitrary free port.
port = port or 0
try:
server = Server({'/': Application()}, port=port)
server.start()
d['port'] = server.port
server_started.set()
server.run_until_shutdown()
except Exception as e:
d['exception'] = e
raise

proc = multiprocessing.Process(target=start_bokeh_server, args=(port,))

proc.start()
server_started.wait(timeout=3)

if 'port' not in d:
if 'exception' in d:
e = d['exception']
raise RuntimeError(f'Bokeh server failed with the following error: {e}')

raise RuntimeError('Bokeh Server failed with an unknown error')

port = d['port']
address = f'http://localhost:{port}'

return address, proc


def log_package_info():
Expand Down
Loading

0 comments on commit d11f625

Please sign in to comment.