diff --git a/requirements.txt b/requirements.txt index 42c7166..c3727b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ pandas==1.4.0 -pymongo==4.0.1 matplotlib==3.5.1 requests==2.27.1 requests_futures==1.0.0 @@ -7,4 +6,4 @@ boto3==1.20.44 sklearn numpy==1.23.4 lmfit==1.0.3 -scipy==1.9.3 \ No newline at end of file +scipy==1.9.3 diff --git a/spot/Spot.py b/spot/Spot.py index 0f0dc91..1ffeee9 100644 --- a/spot/Spot.py +++ b/spot/Spot.py @@ -3,32 +3,27 @@ import time import os import numpy as np +import pickle from datetime import datetime from spot.prices.aws_price_retriever import AWSPriceRetriever from spot.logs.aws_log_retriever import AWSLogRetriever -from spot.invocation.aws_function_invocator import AWSFunctionInvocator from spot.invocation.aws_lambda_invoker import AWSLambdaInvoker -from spot.configs.aws_config_retriever import AWSConfigRetriever -from spot.mlModel.linear_regression import LinearRegressionModel -from spot.invocation.config_updater import ConfigUpdater -from spot.db.db import DBClient +from spot.context import Context from spot.benchmark_config import BenchmarkConfig -from spot.constants import ROOT_DIR -from spot.visualize.Plot import Plot from spot.recommendation_engine.recommendation_engine import RecommendationEngine from spot.constants import * -from spot.mlModel.polynomial_regression import PolynomialRegressionModel -from spot.logs.log_propagation_waiter import LogPropagationWaiter class Spot: - def __init__(self, config_dir: str, model: str): + def __init__(self, config_dir: str, aws_session): # Load configuration values from config.json self.config: BenchmarkConfig self.path: str = config_dir self.workload_file_path = os.path.join(self.path, "workload.json") self.config_file_path = os.path.join(self.path, "config.json") - self.db = DBClient() + + # TODO: implement checkpoint & restore on Context (loading from pickle?). 
+ self.ctx = Context() with open(self.config_file_path) as f: self.config = BenchmarkConfig() @@ -37,133 +32,46 @@ def __init__(self, config_dir: str, model: str): json.dump(self.config.workload, json_file, indent=4) self.benchmark_dir = self.path - self.log_prop_waiter = LogPropagationWaiter(self.config.function_name) - try: - self.last_log_timestamp = self.db.execute_max_value( - self.config.function_name, DB_NAME_LOGS, "timestamp" - ) - except: - print( - "No data for the serverless function found yet. Setting last timestamp for the serverless function to 0.", - ) - self.last_log_timestamp = 0 + # try: + # self.last_log_timestamp = self.ctx.execute_max_value( + # self.config.function_name, DB_NAME_LOGS, "timestamp" + # ) + # except: + # print( + # "No data for the serverless function found yet. Setting last timestamp for the serverless function to 0.", + # ) + # self.last_log_timestamp = None + self.last_log_timestamp = None - # Create function db if not exists - self.db.create_function_db(self.config.function_name) + self.ctx.create_function_df(self.config.function_name) # Instantiate SPOT system components - self.price_retriever = AWSPriceRetriever(self.db, self.config.region) - self.log_retriever = AWSLogRetriever(self.config.function_name) - # self.function_invocator = AWSFunctionInvocator( - # self.workload_file_path, - # self.config.function_name, - # self.config.mem_size, - # self.config.region, - # self.db, - # ) - self.function_invoker = AWSLambdaInvoker(lambda_name=self.config.function_name) - self.config_retriever = AWSConfigRetriever(self.config.function_name, self.db) - self.sampler = RecommendationEngine( - self.function_invoker, self.workload_file_path, self.config.workload + self.price_retriever = AWSPriceRetriever(self.ctx, self.config.region) + self.log_retriever = AWSLogRetriever( + self.ctx, aws_session, self.config.function_name + ) + function_invoker = AWSLambdaInvoker( + self.ctx, aws_session, self.config.function_name + ) + 
self.recommendation_engine = RecommendationEngine( + function_invoker, self.workload_file_path, self.config.workload ) - # self.ml_model = self.select_model(model) - # self.recommendation_engine = RecommendationEngine( - # self.config_file_path, - # self.config, - # self.ml_model, - # self.db, - # self.benchmark_dir, - # ) - - def invoke(self): - # fetch configs and most up to date prices - self.config_retriever.get_latest_config() - self.price_retriever.fetch_current_pricing() - - # invoke function - start = datetime.now().timestamp() - self.function_invocator.invoke_all() - self.log_prop_waiter.wait_by_count(start, self.function_invocator.invoke_cnt) def optimize(self): - self.sampler.run() + self.recommendation_engine.run() def collect_data(self): # retrieve latest config, logs, pricing scheme - self.config_retriever.get_latest_config() self.price_retriever.fetch_current_pricing() - # FIXME: now AWSLogRetriever::get_logs returns a pandas DataFrame. - self.last_log_timestamp = self.log_retriever.get_logs() - - def train_model(self): - self.ml_model.fetch_data() - self.ml_model.train_model() - - def select_model(self, model): - if model == "LinearRegression": - return LinearRegressionModel( - self.config.function_name, - self.config.vendor, - self.db, - self.last_log_timestamp, - self.benchmark_dir, - ) - if model == "polynomial": - return PolynomialRegressionModel( - self.config.function_name, - self.config.vendor, - self.db, - self.last_log_timestamp, - self.benchmark_dir, - self.config.mem_bounds, - ) + self.last_log_timestamp = self.log_retriever.get_logs(self.last_log_timestamp) - # Runs the workload with different configs to profile the serverless function - def profile(self): - mem_size = self.config.mem_bounds[0] - start = datetime.now().timestamp() - invoke_cnt = 0 - while mem_size <= self.config.mem_bounds[1]: - print("Invoking sample workload with mem_size: ", mem_size) - # fetch configs and most up to date prices - 
self.config_retriever.get_latest_config() - self.price_retriever.fetch_current_pricing() - self.function_invocator.invoke_all(mem_size) - invoke_cnt += self.function_invocator.invoke_cnt - mem_size *= 2 - self.log_prop_waiter.wait_by_count(start, invoke_cnt) + def invoke(self, memory_mb): + self.recommendation_engine.invoke_once(memory_mb) - def update_config(self): - self.recommendation_engine.update_config() - - def plot_error_vs_epoch(self): - self.recommendation_engine.plot_error_vs_epoch() - - def plot_config_vs_epoch(self): - self.recommendation_engine.plot_config_vs_epoch() - - def plot_memsize_vs_cost(self): - self.ml_model.plot_memsize_vs_cost() - - def recommend(self): - self.recommendation = self.recommendation_engine.recommend() - - def get_prediction_error_rate(self): - # TODO: ensure it's called after update_config, or ensure memory is updated in invoke() - self.invoke() - self.collect_data() - - log_cnt = self.function_invocator.invoke_cnt - self.ml_model.fetch_data(log_cnt) - - # only take the last few because _df may have already contain data - costs = self.ml_model._df["Cost"].values[-log_cnt:] - pred = self.recommendation_engine.get_pred_cost() - err = sum([(cost - pred) ** 2 for cost in costs]) / len(costs) - print(f"{err=}") - self.db.add_document_to_collection( - self.config.function_name, DB_NAME_ERROR, {ERR_VAL: err} - ) - self.recommendation_engine.plot_config_vs_epoch() - self.recommendation_engine.plot_error_vs_epoch() + def teardown(self): + # Just saving the Context for now. 
+ os.makedirs(CTX_DIR, exist_ok=True) + ctx_file = os.path.join(CTX_DIR, f"{int(time.time() * 1000)}.pkl") + with open(ctx_file, "wb") as f: + pickle.dump(self.ctx, f) diff --git a/spot/configs/aws_config_retriever.py b/spot/configs/aws_config_retriever.py index 0bb16be..d3828f2 100644 --- a/spot/configs/aws_config_retriever.py +++ b/spot/configs/aws_config_retriever.py @@ -1,11 +1,9 @@ import boto3 -from spot.db.db import DBClient import datetime class AWSConfigRetriever: - def __init__(self, function_name, db: DBClient): - self.DBClient = db + def __init__(self, function_name): self.function_name = function_name def get_latest_config(self): @@ -17,13 +15,4 @@ def get_latest_config(self): ) last_modified_ms = int(last_modified.timestamp() * 1000) config["LastModifiedInMs"] = int(last_modified_ms) - config["Architectures"] = config["Architectures"][0] - self.DBClient.add_new_config_if_changed(self.function_name, "config", config) - - def print_configs(self): - iterator = self.DBClient.get_all_collection_documents( - self.function_name, "config" - ) - for config in iterator: - print(config) diff --git a/spot/constants.py b/spot/constants.py index e4119bb..1e49bea 100644 --- a/spot/constants.py +++ b/spot/constants.py @@ -16,8 +16,6 @@ DURATION_PRICE = "duration_price" REQUEST_PRICE = "request_price" -DB_NAME_PRICING = "pricing" -DB_NAME_CONFIG = "config" DB_NAME_LOGS = "logs" # TODO: maybe put config prediction and error into the same database? 
class Context:
    """In-memory store for the data gathered during a SPOT run.

    Keeps one pandas DataFrame of invocation results per serverless
    function, plus a single DataFrame of recorded pricing snapshots.
    """

    def __init__(self):
        # Maps function name -> DataFrame of accumulated invocation results.
        self.function_dfs = {}
        # One row per recorded pricing snapshot.
        self.pricing_df = pd.DataFrame()

    def create_function_df(self, function_name):
        """Create an empty DataFrame for *function_name* if none exists yet.

        Rows already stored for the function are left untouched, so calling
        this repeatedly never discards collected data.
        """
        self.function_dfs.setdefault(function_name, pd.DataFrame())

    def save_invokation_result(self, function_name, result_df):
        """Append *result_df* to the rows stored for *function_name*.

        A function that was never registered starts from an empty DataFrame.
        """
        old = self.function_dfs.get(function_name, pd.DataFrame())
        self.function_dfs[function_name] = pd.concat([old, result_df])

    def record_pricing(self, row):
        """Append a pricing record to ``pricing_df``.

        Args:
            row: Either anything the ``pd.DataFrame`` constructor accepts, or
                a flat mapping of column name to scalar value describing a
                single record.
        """
        try:
            df = pd.DataFrame(row)
        except ValueError:
            # pandas rejects a dict made only of scalars because it cannot
            # infer an index; such a dict is one record, so wrap it in a list.
            df = pd.DataFrame([row])
        self.pricing_df = pd.concat([self.pricing_df, df])
\ No newline at end of file diff --git a/spot/db/db.py b/spot/db/db.py deleted file mode 100644 index cde54a3..0000000 --- a/spot/db/db.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -import subprocess -import json -from pymongo import MongoClient -import pymongo - - -class DBClient: - def __init__(self, url="localhost", port=27017): - self.url = url - self.port = port - self.client = MongoClient(self.url, self.port) - - # Creates database for the function name if the doesnt exist already - def create_function_db(self, function_name): - new_function_db = self.client[function_name] - return - - def add_collection(self, function_name, collection_name): - function_db = self.client[function_name] - new_collection = function_db[collection_name] - new_collection.insert_one({}) - return - - def get_all_collection_documents(self, function_name, collection_name): - function_db = self.client[function_name] - collection = function_db[collection_name] - return collection.find() - - def add_document_to_collection(self, function_name, collection_name, document): - function_db = self.client[function_name] - collection = function_db[collection_name] - collection.insert_one(document) - return - - def add_document_to_collection_if_not_exists( - self, function_name, collection_name, document, criteria - ): - function_db = self.client[function_name] - collection = function_db[collection_name] - if not collection.find_one(criteria): - collection.insert_one(document) - - def remove_document_from_collection(self, function_name, collection_name, query): - function_db = self.client[function_name] - collection = function_db[collection_name] - collection.delete_one(query) - - def add_new_config_if_changed(self, function_name, collection_name, document): - function_db = self.client[function_name] - collection = function_db[collection_name] - - latest_saved_config = collection.find_one(sort=[("_id", pymongo.DESCENDING)]) - - # Delete unique identifier fields to be able to configure current and 
most recent config - if latest_saved_config: - del latest_saved_config["_id"] - del latest_saved_config["LastModified"] - del latest_saved_config["RevisionId"] - del latest_saved_config["LastModifiedInMs"] - del latest_saved_config["ResponseMetadata"] - - current_config = document.copy() - del current_config["LastModified"] - del current_config["LastModifiedInMs"] - del current_config["RevisionId"] - del current_config["ResponseMetadata"] - - if not latest_saved_config == current_config: - collection.insert_one(document) - elif not current_config.keys() == latest_saved_config.keys(): - print("Warning: AWS might have changed configuration parameters") - - def execute_query( - self, function_name, collection_name, select_fields, display_fields - ): - function_db = self.client[function_name] - collection = function_db[collection_name] - return collection.find(select_fields, display_fields) - - def execute_max_value(self, function_name: str, collection_name: str, field: str): - function_db = self.client[function_name] - collection = function_db[collection_name] - return collection.find().sort(f"{field}", -1).limit(1)[0][field] - - def get_top_docs(self, function_name: str, collection_name: str, doc_cont: int): - function_db = self.client[function_name] - collection = function_db[collection_name] - return collection.find(sort=[("_id", pymongo.DESCENDING)]).limit(doc_cont) diff --git a/spot/invocation/aws_function_invocator.py b/spot/invocation/aws_function_invocator.py deleted file mode 100644 index 9bea646..0000000 --- a/spot/invocation/aws_function_invocator.py +++ /dev/null @@ -1,146 +0,0 @@ -import concurrent.futures -import json -import os -import time -import threading -import boto3 -import botocore - -from spot.invocation.JSONConfigHelper import CheckJSONConfig, ReadJSONConfig -from spot.invocation.WorkloadChecker import CheckWorkloadValidity -from spot.invocation.EventGenerator import GenericEventGenerator -from spot.invocation.config_updater import 
ConfigUpdater -from spot.db.db import DBClient - - -class InvalidWorkloadFileException(Exception): - pass - - -class AWSFunctionInvocator: - """ - Invokes a function based on the parameters specified in a workload json file - - Args: - workload_path: file path to workload specifications - function_name: name of the function on lambda - mem_size: the initial memory size for the serverless function to run - region: region of the serverless function - - Attributes: - invoke_cnt: the number of total function invocations of a certain workload setting - - Raises: - InvalidWorkloadFileException: if the workload file is not of json format or some fields have wrong types - """ - - def __init__( - self, - workload_path: str, - function_name: str, - mem_size: int, - region: str, - db: DBClient, - ) -> None: - self._read_workload(workload_path) - self._workload_path: str = os.path.dirname(workload_path) - self._config = ConfigUpdater(function_name, mem_size, region) - self._config.set_mem_size(mem_size) - self._all_events, _ = GenericEventGenerator(self._workload) - self.function_name = function_name - self._futures = [] - self._thread = [] - self.DBClient = db - self.invoke_cnt = 0 - - def _read_workload(self, path: str) -> None: - if not CheckJSONConfig(path): - raise InvalidWorkloadFileException - workload = ReadJSONConfig(path) - if not CheckWorkloadValidity(workload=workload): - raise InvalidWorkloadFileException - self._workload = workload - - def _append_threads(self, instance: str, instance_times: list) -> None: - payload_file = self._workload["instances"][instance]["payload"] - application = self._workload["instances"][instance]["application"] - client = boto3.client("lambda") - - try: - f = open(os.path.join(self._workload_path, payload_file), "r") - except IOError: - f = None - payload = json.load(f) if f else None - self.payload = payload - - self._threads.append( - threading.Thread( - target=self._invoke, args=[client, application, payload, instance_times] - ) - ) 
- - def _invoke( - self, - client: "botocore.client.logs", - function_name: str, - payload: list, - instance_times: list, - ) -> bool: - with concurrent.futures.ThreadPoolExecutor(max_workers=15) as executor: - # TODO: store input and invocation info to db - st = 0 - after_time, before_time = 0, 0 - cnt = 0 - - for t in instance_times: - st = t - (after_time - before_time) - if st > 0: - time.sleep(st) - input_data = json.dumps( - payload[cnt % len(payload)] if payload else None - ) - cnt += 1 - before_time = time.time() - future = executor.submit( - client.invoke, FunctionName=function_name, Payload=input_data - ) - self.invoke_cnt += 1 - self._futures.append(future) - after_time = time.time() - - return True - - def invoke_all(self, mem: int = -1) -> None: - """Invoke the function with user specified inputs and parameters asynchronously""" - self.invoke_cnt = 0 - self._threads = [] - request_ids = [] - for (instance, instance_times) in self._all_events.items(): - self._config.set_instance( - self._workload["instances"][instance]["application"] - ) - if mem != -1: - self._config.set_mem_size(mem) - self._append_threads(instance, instance_times) - for thread in self._threads: - thread.start() - for thread in self._threads: - thread.join() - for future in self._futures: - res = future.result() - req_id = res["ResponseMetadata"]["RequestId"] - status = res["StatusCode"] - error = False - if status < 200 or status >= 300: - print(f"WARNING: Status code {status} for request id {req_id}") - if "FunctionError" in res: - error = True - print( - f"WARNING: Function error for request id {req_id}. 
The memory configuration being used may be too low" - ) - print(res["FunctionError"]) - request_ids.append({"_id": req_id, "status": status, "error": error}) - for request in request_ids: - self.DBClient.add_document_to_collection_if_not_exists( - self.function_name, "requests", request, {"_id": request["_id"]} - ) diff --git a/spot/invocation/aws_lambda_invoker.py b/spot/invocation/aws_lambda_invoker.py index 2784aa1..30e06d2 100644 --- a/spot/invocation/aws_lambda_invoker.py +++ b/spot/invocation/aws_lambda_invoker.py @@ -2,7 +2,6 @@ import base64 import time -import boto3 import re import pandas as pd from concurrent.futures import ThreadPoolExecutor @@ -13,11 +12,19 @@ class AWSLambdaInvoker: Invokes AWS Lambda with the specified config. """ - def __init__(self, lambda_name): + def __init__(self, ctx, aws_session, lambda_name): self.lambda_name = lambda_name - self.client = boto3.client("lambda") - - def invoke(self, invocation_count, parallelism, memory_mb, payload_filename): + self.client = aws_session.client("lambda") + self.ctx = ctx + + def invoke( + self, + invocation_count, + parallelism, + memory_mb, + payload_filename, + save_to_ctx=True, + ): """ Invokes the specified lambda with given memory config. 
Returns pandas DataFrame representing the execution logs @@ -62,7 +69,10 @@ def invoke_sequential(count): if len(errors) != 0: raise LambdaInvocationError(errors) - return pd.DataFrame.from_dict(results) + result_df = pd.DataFrame.from_dict(results) + if save_to_ctx: + self.ctx.save_invokation_result(self.lambda_name, result_df) + return result_df def _check_and_set_memory_value(self, memory_mb): config = self.client.get_function_configuration(FunctionName=self.lambda_name) diff --git a/spot/logs/aws_log_retriever.py b/spot/logs/aws_log_retriever.py index 0894d79..862b62c 100644 --- a/spot/logs/aws_log_retriever.py +++ b/spot/logs/aws_log_retriever.py @@ -1,13 +1,17 @@ -import boto3 -import pandas as pd +import time import re +import pandas as pd + +from spot.context import Context + class AWSLogRetriever: - def __init__(self, function_name, max_log_count=None): + def __init__(self, ctx: Context, aws_session, function_name, max_log_count=None): self.function_name = function_name - self.client = boto3.client("logs") + self.client = aws_session.client("logs") self.max_log_count = max_log_count + self.ctx = ctx def get_logs(self, start_timestamp=None): path = f"/aws/lambda/{self.function_name}" @@ -52,7 +56,9 @@ def get_logs(self, start_timestamp=None): if not is_newly_added: break - return self._parse_logs(response) + logs = self._parse_logs(response) + ctx.save_logs(logs) + return int(time.time() * 1000) def _parse_logs(self, response): df = pd.DataFrame(response, columns=["timestamp", "message", "ingestionTime"]) diff --git a/spot/main.py b/spot/main.py index 38b1e76..1816eeb 100644 --- a/spot/main.py +++ b/spot/main.py @@ -1,4 +1,5 @@ import time +import boto3 import argparse import os from spot.constants import ROOT_DIR @@ -19,133 +20,47 @@ def main(): action="store_true", help="Return best memory configuration for lowest cost", ) - parser.add_argument( - "--invoke", - "-i", - action="store_true", - help="Run the function with the given workload", - ) 
parser.add_argument( "--fetch", "-f", action="store_true", help="Fetch log and config data from AWS" ) parser.add_argument( - "--train", - "-t", - action="store_true", - help="Train the model based on the fetched log and config data", - ) - parser.add_argument( - "--recommend", - "-r", - action="store_true", - help="Recommend a memory config based on the trained model", - ) - parser.add_argument( - "--profile", - "-p", - action="store_true", - help="Test multiple memory configs to determine the optimal one", - ) - parser.add_argument( - "--model", - "-m", - type=str, - help="The ML model to use to train the model", - ) - parser.add_argument( - "--update_config", - "-u", - action="store_true", - help="Update lambda function config with the optimal config current model suggests", - ) - parser.add_argument( - "--plot_error_vs_epoch", - "-ee", - action="store_true", - help="Plot error vs epoch", - ) - parser.add_argument( - "--plot_config_vs_epoch", - "-ce", - action="store_true", - help="Plot config vs epoch", + "--invoke", "-i", action="store_true", help="Invoke the function" ) parser.add_argument( - "--plot_memsize_vs_cost", - "-mc", - action="store_true", - help="Plot Memory Size vs Cost", - ) - parser.add_argument( - "--full", - action="store_true", - help="End-to-end execution of full lifecycle: profiling then fetching newly created logs, then training the model, then recommending the optimal config and updating the serverless function config with the new config", + "--memory_mb", "-m", type=int, help="Memory (MB) of the function" ) + parser.add_argument("--aws_profile", "-p", type=str, help="AWS profile") args = parser.parse_args() - if args.full: - """ - End-to-end execution of full lifecycle: - 1. profiling - 2. fetching newly created logs - 3. training the model - 4. recommending the optimal config - 5. 
updating the serverless function config with the new config - """ - if args.model: - args.profile = args.fetch = args.train = args.update_config = True - else: - print("Please specify model") - return - - if args.function: - path = os.path.join(ROOT_DIR, "../", FUNCTION_DIR, args.function) - if os.path.isdir(path): - function = Spot(path, args.model) - if args.optimize: - function.optimize() - if args.invoke: - function.invoke() - if args.profile: - function.profile() - if args.fetch: - function.collect_data() - if args.train: - if args.model: - function.train_model() - else: - print("Please specify model") - return - if args.recommend: - if args.model: - function.recommend() - else: - print("Please specify model") - return - if args.update_config: - if args.model: - function.update_config() - function.get_prediction_error_rate() - else: - print("Please specify model") - return - if args.plot_error_vs_epoch: - function.plot_error_vs_epoch() - if args.plot_config_vs_epoch: - function.plot_config_vs_epoch() - if args.plot_memsize_vs_cost: - if (args.train or args.full) and args.model: - function.plot_memsize_vs_cost() - else: - print("Memsize vs Cost plot can be generated only after training") - return - else: - print( - f"Could not find the serverless function {args.function} in '{path}'. Functions are case sensitive" - ) + if args.aws_profile: + session = boto3.Session(profile_name=args.aws_profile) else: + session = boto3.Session() + + if not args.function: print(f"Please specify a serverless function from the {FUNCTION_DIR} directory") + exit(1) + + path = os.path.join(ROOT_DIR, "../", FUNCTION_DIR, args.function) + if not os.path.isdir(path): + print( + f"Could not find the serverless function {args.function} in '{path}'. 
class AWSPriceRetriever(PriceRetriever):
    """Fetches the current AWS Lambda prices for one region.

    Every fetched price snapshot is also recorded in the shared Context.
    """

    def __init__(self, ctx: Context, region):
        """Store the Context to record prices in and the AWS region to query."""
        super().__init__()
        self.ctx = ctx
        self.region = region

    def fetch_current_pricing(self) -> dict:
        """Fetch the per-request and per-duration Lambda prices and record them.

        Returns:
            A dict with the provider name, the request price, the duration
            price, the fetch time as a millisecond Unix timestamp, and the
            region. The same dict is passed to ``Context.record_pricing``.
        """
        parameters = {
            "vendor": "aws",
            "service": "AWSLambda",
            "family": "Serverless",
            "region": self.region,
            "type": "AWS-Lambda-Requests",
            "purchaseOption": "on_demand",
        }
        # _current_price is not defined in this class; presumably inherited
        # from PriceRetriever -- confirm against the base class.
        request_price = self._current_price(parameters)
        # The two prices differ only in the "type" dimension of the query.
        # Querying both with the same type would report the duration price
        # as the request price.
        duration_price = self._current_price(
            {**parameters, "type": "AWS-Lambda-Duration"}
        )
        current_pricing = {
            "provider": "AWS",
            REQUEST_PRICE: request_price,
            DURATION_PRICE: duration_price,
            TIMESTAMP: int(time.time() * 1000),
            REGION: self.region,
        }
        self.ctx.record_pricing(
            current_pricing,
        )
        return current_pricing
def invoke_once(self, memory_mb):
    """Invoke the function a single time at the given memory size.

    Delegates to the invoker's ``invoke`` with one invocation, a parallelism
    of one and the engine's payload file, and returns whatever it returns.
    """
    single_run = {
        "invocation_count": 1,
        "parallelism": 1,
        "memory_mb": memory_mb,
        "payload_filename": self.payload,
    }
    return self.function_invocator.invoke(**single_run)