Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multiple frontend instances and a single Skyline server #29

Merged
merged 1 commit into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions protocol/innpv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ message InitializeRequest {
// compatability. If the protocol version is unsupported by the server, it
// will respond with a ProtocolError.
uint32 protocol_version = 1;

string project_root = 2;
string entry_point = 3;
}

message AnalysisRequest {
Expand Down
3 changes: 2 additions & 1 deletion skyline/analysis/request_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def _handle_analysis_request(self, analysis_request, context):
context.sequence_number,
*(context.address),
)
connection = self._connection_manager.get_connection(context.address)
analyzer = analyze_project(
Config.project_root, Config.entry_point, self._nvml)
connection.project_root, connection.entry_point, self._nvml)

# Abort early if the connection has been closed
if not context.state.connected:
Expand Down
7 changes: 0 additions & 7 deletions skyline/commands/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ def register_command(subparsers):
"interactive",
help="Start a new Skyline interactive profiling session.",
)
parser.add_argument(
"entry_point",
help="The entry point file in this project that contains the Skyline "
"provider functions.",
)
parser.add_argument(
"--host",
default="",
Expand Down Expand Up @@ -93,8 +88,6 @@ def signal_handler(signal, frame):
"Listening on port %d.",
port,
)
logger.info("Project Root: %s", Config.project_root)
logger.info("Entry Point: %s", Config.entry_point)

# Run the server until asked to terminate
should_shutdown.wait()
Expand Down
3 changes: 2 additions & 1 deletion skyline/commands/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ def actual_main(args):
'samples_per_second',
'memory_usage_bytes',
])
project_root = os.cwd()
for batch_size in args.batch_sizes:
for trial in range(args.trials):
session = AnalysisSession.new_from(
Config.project_root, Config.entry_point)
project_root, args.entry_point)
samples_per_second, memory_usage_bytes = make_measurements(
session, batch_size)
writer.writerow([
Expand Down
3 changes: 2 additions & 1 deletion skyline/commands/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def actual_main(args):
sys.exit(1)

try:
project_root = os.cwd()
session = AnalysisSession.new_from(
Config.project_root, Config.entry_point)
project_root, args.entry_point)
session.generate_memory_usage_report(
save_report_to=args.output,
)
Expand Down
3 changes: 2 additions & 1 deletion skyline/commands/prediction_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def actual_main(args):
'memory_usage_bytes_slope',
'memory_usage_bytes_bias',
])
project_root = os.cwd()
for batch_size in args.batch_sizes:
session = AnalysisSession.new_from(
Config.project_root, Config.entry_point)
project_root, args.entry_point)
memory_model, run_time_model = get_model(
session, batch_size)
writer.writerow([
Expand Down
3 changes: 2 additions & 1 deletion skyline/commands/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def actual_main(args):
sys.exit(1)

try:
project_root = os.cwd()
session = AnalysisSession.new_from(
Config.project_root, Config.entry_point)
project_root, args.entry_point)
session.generate_run_time_breakdown_report(
save_report_to=args.output,
)
Expand Down
7 changes: 0 additions & 7 deletions skyline/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ def __init__(self):
self.warm_up = 100
self.measure_for = 10

self.project_root = None
self.entry_point = None

def initialize_hints_config(self, hints_file):
if hints_file is None:
file_to_open = skyline.data.get_absolute_path('hints.yml')
Expand All @@ -32,9 +29,5 @@ def parse_args(self, args):
if 'measure_for' in args and args.measure_for is not None:
self.measure_for = args.measure_for

def set_project_paths(self, project_root, entry_point):
self.project_root = project_root
self.entry_point = entry_point


Config = _Config()
38 changes: 0 additions & 38 deletions skyline/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,7 @@ def initialize_skyline(args):
"""
from skyline.config import Config

project_root = os.getcwd()
entry_point = args.entry_point
if not _validate_paths(project_root, entry_point):
sys.exit(1)

Config.parse_args(args)
Config.set_project_paths(project_root, entry_point)


def _configure_logging(args):
kwargs = {
Expand Down Expand Up @@ -76,34 +69,3 @@ def _validate_gpu():
)
return False
return True


def _validate_paths(project_root, entry_point):
if not os.path.isabs(project_root):
logger.error(
"The project root that Skyline received is not an absolute path. "
"This is an unexpected error. Please report a bug."
)
logger.error("Current project root: %s", project_root)
return False

if os.path.isabs(entry_point):
logger.error(
"The entry point must be specified as a relative path to the "
"current directory. Please double check that the entry point you "
"are providing is a relative path.",
)
logger.error("Current entry point path: %s", entry_point)
return False

full_path = os.path.join(project_root, entry_point)
if not os.path.isfile(full_path):
logger.error(
"Either the specified entry point is not a file or its path was "
"specified incorrectly. Please double check that it exists and "
"that its path is correct.",
)
logger.error("Current absolute path to entry point: %s", full_path)
return False

return True
15 changes: 14 additions & 1 deletion skyline/io/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, socket, address, handler_function, closed_handler):
self._handler_function = handler_function
self._closed_handler = closed_handler
self._sentinel = Sentinel()
self._project_root = ""
self._entry_point = ""

def start(self):
self._sentinel.start()
Expand Down Expand Up @@ -91,7 +93,18 @@ def _socket_read(self):

except:
logger.exception("Connection unexpectedly stopping...")


@property
def project_root(self):
return self._project_root

@property
def entry_point(self):
return self._entry_point

def set_project_paths(self, project_root, entry_point):
self._project_root = project_root
self._entry_point = entry_point

class ConnectionState:
def __init__(self):
Expand Down
49 changes: 49 additions & 0 deletions skyline/protocol/message_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import logging
import os

from skyline.exceptions import AnalysisError, NoConnectionError

Expand All @@ -12,6 +13,36 @@
['address', 'state', 'sequence_number'],
)

def _validate_paths(project_root, entry_point):
if not os.path.isabs(project_root):
logger.error(
"The project root that Skyline received is not an absolute path. "
"This is an unexpected error. Please report a bug."
)
logger.error("Current project root: %s", project_root)
return False

if os.path.isabs(entry_point):
logger.error(
"The entry point must be specified as a relative path to the "
"current directory. Please double check that the entry point you "
"are providing is a relative path.",
)
logger.error("Current entry point path: %s", entry_point)
return False

full_path = os.path.join(project_root, entry_point)
if not os.path.isfile(full_path):
logger.error(
"Either the specified entry point is not a file or its path was "
"specified incorrectly. Please double check that it exists and "
"that its path is correct.",
)
logger.error("Current absolute path to entry point: %s", full_path)
return False

return True


class MessageHandler:
def __init__(
Expand Down Expand Up @@ -49,6 +80,23 @@ def _handle_initialize_request(self, message, context):
)
return


if not _validate_paths(message.project_root, message.entry_point):
# Change this to the error related to
self._message_sender.send_protocol_error(
pm.ProtocolError.ErrorCode.UNSUPPORTED_PROTOCOL_VERSION,
context,
)
self._connection_manager.remove_connection(context.address)
logger.error(
'Invalid project root or entry point.'
)
return
logger.info("Connection addr:(%s:%d)", *context.address)
logger.info("Project Root: %s", message.project_root)
logger.info("Entry Point: %s", message.entry_point)
self._connection_manager.get_connection(context.address).set_project_paths(message.project_root, message.entry_point)

context.state.initialized = True
self._message_sender.send_initialize_response(context)

Expand Down Expand Up @@ -107,3 +155,4 @@ def handle_message(self, raw_data, address):
'Processing message from (%s:%d) resulted in an exception.',
*address,
)

5 changes: 3 additions & 2 deletions skyline/protocol/message_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def __init__(self, connection_manager):

def send_initialize_response(self, context):
message = pm.InitializeResponse()
message.server_project_root = Config.project_root
message.entry_point.components.extend(Config.entry_point.split(os.sep))
connection = self._connection_manager.get_connection(context.address)
message.server_project_root = connection.project_root
message.entry_point.components.extend(connection.entry_point.split(os.sep))

# Populate hardware info
message.hardware.hostname = platform.node()
Expand Down
Loading