Record telemetry including acceptance rate (deepjavalibrary#2088)
zachgk authored Jun 21, 2024
1 parent 44f9508 commit 9b6ff95
Showing 6 changed files with 64 additions and 3 deletions.
@@ -45,6 +45,7 @@ class LmiDistRbProperties(Properties):
     speculative_length: int = 5
     draft_model_tp_size: int = 1
     record_acceptance_rate: Optional[bool] = False
+    speculative_telemetry: Optional[bool] = True
     enable_lora: Optional[bool] = False
     max_loras: Optional[int] = 4
     max_lora_rank: Optional[int] = 16
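The properties change adds a speculative_telemetry flag that is on by default, alongside the existing opt-in record_acceptance_rate. Below is a hypothetical sketch of overriding both on the pydantic config object; the base fields supplied here (model_id_or_path, tensor_parallel_degree) are assumptions for illustration and are not part of this diff.

    # Hypothetical sketch only: LmiDistRbProperties is a pydantic model, so the new
    # flag can be set like any other field. The base fields used here are guesses.
    from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties

    props = LmiDistRbProperties(
        model_id_or_path="/opt/ml/model",  # assumed required base field
        tensor_parallel_degree=1,          # assumed base field
        record_acceptance_rate=True,       # opt-in: per-request log lines
        speculative_telemetry=True,        # on by default: aggregated telemetry
    )
    print(props.record_acceptance_rate, props.speculative_telemetry)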
@@ -11,6 +11,7 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import logging
+import os
 from typing import List
 from collections import OrderedDict, defaultdict
 
@@ -26,6 +27,7 @@
     get_speculative_decoding_metrics_record, update_request_cache_with_output,
     supports_speculative_decoding, get_lora_request_params, DTYPE_MAPPER,
     FINISH_REASON_MAPPER)
+from djl_python.telemetry import telemetry_manager
 from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
 
 _WARMUP_PREFILL_TOKENS = 4096
@@ -187,11 +189,17 @@ def inference(self,
                 self.request_cache, request_output, self.get_tokenizer())
             # Record SD metrics
             completion_output = request_output.outputs[0]
-            if self.lmi_dist_config.record_acceptance_rate and request_output.finished:
+            if (self.lmi_dist_config.record_acceptance_rate
+                    or self.lmi_dist_config.speculative_telemetry
+                ) and request_output.finished:
                 if self.supports_speculative_decoding and completion_output.acceptance_history:
                     record = get_speculative_decoding_metrics_record(
                         completion_output, request_output)
-                    logging.info(f"Speculative Decoding {record}")
+                    if self.lmi_dist_config.record_acceptance_rate:
+                        logging.info(f"Speculative Decoding {record}")
+                    if self.lmi_dist_config.speculative_telemetry and os.environ.get(
+                            "SAGEMAKER_SECURE_MODE") == "true":
+                        telemetry_manager.record_speculative(record)
                 else:
                     logging.warning(
                         f"Ignoring logging speculative decoding metrics")
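Taken together, the new condition means a finished request's speculative-decoding record can now be consumed in two independent ways: record_acceptance_rate keeps the existing per-request log line, while speculative_telemetry forwards the record to the telemetry manager only when the container runs with SAGEMAKER_SECURE_MODE=true. A standalone sketch of that dispatch follows; the helper name and simplified signature are illustrative, not part of the commit.

    import logging
    import os


    def dispatch_sd_record(record: dict, config, telemetry_sink) -> None:
        """Illustrative restatement of the gating added to inference()."""
        # Per-request logging stays behind the original opt-in flag.
        if config.record_acceptance_rate:
            logging.info(f"Speculative Decoding {record}")
        # Telemetry is forwarded only in SageMaker secure mode with the flag enabled.
        if config.speculative_telemetry and os.environ.get(
                "SAGEMAKER_SECURE_MODE") == "true":
            telemetry_sink.record_speculative(record)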
@@ -166,6 +166,8 @@ def get_speculative_decoding_metrics_record(
             completion_output.acceptance_history)
     else:
         record["mean_acceptance"] = 0
+    record["acceptance_history_len"] = len(
+        completion_output.acceptance_history)
     record["prompt_size"] = len(request_output.prompt_token_ids)
     record["output_size"] = len(completion_output.token_ids)
     return record
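The telemetry path relies on two of these fields: mean_acceptance and the newly added acceptance_history_len, which lets the aggregator weight each request by the number of speculation steps it contributed. A sketch of the record's shape with made-up values; only the fields visible in this diff are shown, and the function may set others earlier.

    record = {
        "mean_acceptance": 3.2,        # average acceptance per speculation step (0 if no history)
        "acceptance_history_len": 25,  # number of recorded speculation steps (new in this commit)
        "prompt_size": 128,            # len(request_output.prompt_token_ids)
        "output_size": 200,            # len(completion_output.token_ids)
    }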
5 changes: 4 additions & 1 deletion engines/python/setup/djl_python/sm_log_filter.py
@@ -19,7 +19,10 @@

 # https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
 class SMLogFilter(logging.Filter):
-    sm_log_markers = ['ModelServerError', 'UserScriptError', 'SysHealth']
+    sm_log_markers = [
+        'ModelServerError', 'UserScriptError', 'SysHealth',
+        'ModelServerTelemetry'
+    ]
     counter = defaultdict(int)
 
     def filter(self, record):
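Adding 'ModelServerTelemetry' to sm_log_markers lets the SageMaker log filter pick up the aggregated acceptance line emitted by the new telemetry module below. A minimal sketch of the substring match the marker list implies; how filter() actually aggregates matching records is not shown in this diff and is assumed.

    markers = [
        'ModelServerError', 'UserScriptError', 'SysHealth', 'ModelServerTelemetry'
    ]
    message = "ModelServerTelemetry: Speculative Decoding Mean Acceptance: 3.75 rate"
    print([m for m in markers if m in message])  # ['ModelServerTelemetry']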
44 changes: 44 additions & 0 deletions engines/python/setup/djl_python/telemetry.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+import logging
+import time
+
+SPECULATIVE_FREQUENCY_SEC = 30.0
+
+
+class TelemetryManager:
+
+    def __init__(self):
+        self.reset_speculative()
+
+    def record_speculative(self, data):
+        self.speculative_acceptance_rate_count = self.speculative_acceptance_rate_count + data[
+            "acceptance_history_len"]
+        self.speculative_acceptance_rate_total = self.speculative_acceptance_rate_total + data[
+            "mean_acceptance"] * data["acceptance_history_len"]
+        if time.time(
+        ) - self.speculative_sent_time > SPECULATIVE_FREQUENCY_SEC:
+            mean_acceptance = 1.0 * self.speculative_acceptance_rate_total / self.speculative_acceptance_rate_count
+            logging.info(
+                f"ModelServerTelemetry: Speculative Decoding Mean Acceptance: {mean_acceptance} rate"
+            )
+            self.reset_speculative()
+
+    def reset_speculative(self):
+        self.speculative_sent_time = time.time()
+        self.speculative_acceptance_rate_count = 0
+        self.speculative_acceptance_rate_total = 0.0
+
+
+telemetry_manager = TelemetryManager()
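TelemetryManager keeps a count-weighted running sum of acceptance rates and flushes a single ModelServerTelemetry log line at most once every SPECULATIVE_FREQUENCY_SEC (30 s), then resets its counters. A usage sketch against the module-level singleton, with made-up record values:

    import logging

    from djl_python.telemetry import telemetry_manager

    logging.basicConfig(level=logging.INFO)

    # Two (made-up) per-request speculative-decoding records.
    telemetry_manager.record_speculative(
        {"mean_acceptance": 3.0, "acceptance_history_len": 10})
    telemetry_manager.record_speculative(
        {"mean_acceptance": 4.0, "acceptance_history_len": 30})

    # Nothing is flushed yet: the log line is emitted by the first
    # record_speculative() call that arrives more than 30 seconds after the last
    # reset, at which point the weighted mean ((3.0*10 + 4.0*30) / 40 = 3.75) is
    # logged as "ModelServerTelemetry: Speculative Decoding Mean Acceptance: 3.75 rate"
    # and the counters are cleared.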
@@ -273,6 +273,9 @@ public void run() {
                 MODEL_METRIC.info("{}", Metric.parse(result.substring(metricLoc + 9)));
                 continue;
             }
+            if (result.contains("ModelServerTelemetry:")) {
+                continue;
+            }
 
             if (error) {
                 logger.warn("{}: {}", getName(), result);
