Skip to content

Commit

Permalink
[Neo] Fix Neo Quantization properties output. Add some additional con…
Browse files Browse the repository at this point in the history
…figuration. (deepjavalibrary#2077)
  • Loading branch information
a-ys authored and sindhuvahinis committed Jun 18, 2024
1 parent 8b26687 commit 2156960
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
42 changes: 38 additions & 4 deletions serving/docker/partition/sm_neo_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from sm_neo_utils import (CompilationFatalError, write_error_to_file,
get_neo_env_vars)
from utils import extract_python_jar
from utils import extract_python_jar, load_properties
from properties_manager import PropertiesManager
from partition import PartitionService

Expand All @@ -38,13 +38,14 @@ def __init__(self):
self.INPUT_MODEL_DIRECTORY: Final[str] = env[1]
self.OUTPUT_MODEL_DIRECTORY: Final[str] = env[2]
self.COMPILATION_ERROR_FILE: Final[str] = env[3]
self.COMPILER_CACHE_LOCATION: Final[str] = env[4]
self.HF_CACHE_LOCATION: Final[str] = env[5]
self.TARGET_INSTANCE_TYPE: Final[str] = env[6]

def update_dataset_cache_location(self):
    """
    Point the HuggingFace Datasets cache at the configured HF cache directory.

    Sets the HF_DATASETS_CACHE environment variable to
    self.HF_CACHE_LOCATION. When no location is configured (the attribute
    is None or empty), the environment is left untouched, because
    os.environ only accepts string values and assigning None would raise
    TypeError.
    """
    hf_cache_dir = self.HF_CACHE_LOCATION
    if not hf_cache_dir:
        logging.info(
            "No HuggingFace cache directory configured. "
            "Leaving HF_DATASETS_CACHE unchanged.")
        return

    logging.info(
        f"Updating HuggingFace Datasets cache directory to: {hf_cache_dir}"
    )
    os.environ['HF_DATASETS_CACHE'] = hf_cache_dir
    #os.environ['HF_DATASETS_OFFLINE'] = "1"

def initialize_partition_args_namespace(self):
Expand Down Expand Up @@ -72,6 +73,8 @@ def construct_properties_manager(self):
given serving.properties
"""
# Default to awq quantization
# TODO: update this when new quantization methods are added,
# since envvar overrides customer serving.properties
os.environ['OPTION_QUANTIZE'] = 'awq'
logging.debug("Constructing PropertiesManager from "
f"serving.properties\nargs:{self.args}\n")
Expand All @@ -89,11 +92,42 @@ def run_quantization(self) -> str:
raise CompilationFatalError(
f"Encountered an error during quantization: {exc}")

def write_properties(self):
    """
    Updates outputted serving.properties.
    If a user passes in tensor_parallel_degree, it is passed through to the output.
    Otherwise, tensor_parallel_degree is not outputted so that it can be defined
    during serving.

    The OPTION_TENSOR_PARALLEL_DEGREE environment variable, when set to a
    non-empty value, takes precedence over option.tensor_parallel_degree
    read from the serving.properties in the input model directory.
    """
    customer_properties = load_properties(self.INPUT_MODEL_DIRECTORY)
    # A non-empty environment variable overrides the customer's file value.
    user_tensor_parallel_degree = (
        os.environ.get("OPTION_TENSOR_PARALLEL_DEGREE")
        or customer_properties.get("option.tensor_parallel_degree"))

    # NOTE(review): assumes properties is a dict-like mutable mapping, as the
    # item assignment and deletion below already require.
    output_properties = self.properties_manager.properties
    if user_tensor_parallel_degree:
        logging.info(
            f"User passed tensor_parallel_degree={user_tensor_parallel_degree}"
        )
        output_properties[
            "option.tensor_parallel_degree"] = user_tensor_parallel_degree
    else:
        logging.info(
            "User did not pass tensor_parallel_degree. Outputted serving.properties "
            "will not include this field.")
        # pop() with a default tolerates the key being absent; a bare `del`
        # would raise KeyError in that case.
        output_properties.pop("option.tensor_parallel_degree", None)

    self.properties_manager.properties = output_properties
    self.properties_manager.generate_properties_file()

def neo_quantize(self):
    """
    Run the quantization workflow end to end.

    The steps are invoked strictly in this order: update the dataset cache
    location, initialize the partition args namespace, construct the
    properties manager, run quantization, then write the output properties.
    Any value returned by the individual steps is discarded.
    """
    self.update_dataset_cache_location()
    self.initialize_partition_args_namespace()
    self.construct_properties_manager()
    self.run_quantization()
    # Written last so the emitted serving.properties reflects the finished run.
    self.write_properties()


def main():
Expand Down
3 changes: 2 additions & 1 deletion serving/docker/partition/sm_neo_trt_llm_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def main():
compilation_error_file = None
try:
(compiler_options, input_model_directory, compiled_model_directory,
compilation_error_file, neo_cache_dir) = get_neo_env_vars()
compilation_error_file, neo_cache_dir,
neo_hf_cache_dir) = get_neo_env_vars()

# Neo requires that serving.properties is in same dir as model files
serving_properties = load_properties(input_model_directory)
Expand Down
7 changes: 5 additions & 2 deletions serving/docker/partition/sm_neo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ def write_error_to_file(error_message, error_file):
def get_neo_env_vars():
"""
Get environment variables required by the SageMaker Neo interface
TODO: Update the return type to a dictionary to allow for easier changes
"""
try:
compiler_options = os.environ.get("COMPILER_OPTIONS")
input_model_directory = os.environ["SM_NEO_INPUT_MODEL_DIR"]
compiled_model_directory = os.environ["SM_NEO_COMPILED_MODEL_DIR"]
compilation_error_file = os.environ["SM_NEO_COMPILATION_ERROR_FILE"]
neo_cache_dir = os.environ["SM_NEO_CACHE_DIR"]
neo_cache_dir = os.environ.get("SM_NEO_CACHE_DIR")
neo_hf_cache_dir = os.environ.get("SM_NEO_HF_CACHE_DIR")
return (compiler_options, input_model_directory,
compiled_model_directory, compilation_error_file,
neo_cache_dir)
neo_cache_dir, neo_hf_cache_dir)
except KeyError as exc:
raise InputConfiguration(
f"SageMaker Neo environment variable '{exc.args[0]}' expected but not found"
Expand Down

0 comments on commit 2156960

Please sign in to comment.