Skip to content

Commit

Permalink
1. Allocate correct byte size for CORRID string
Browse files Browse the repository at this point in the history
2. Validate CORRID dtype matches override value dtype
  • Loading branch information
yinggeh committed Jan 7, 2025
1 parent fc02544 commit fa343be
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
60 changes: 59 additions & 1 deletion src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1228,11 +1228,16 @@ InferenceRequest::Normalize()
}
}
}

if (model_config.has_sequence_batching()) {
RETURN_IF_ERROR(ValidateOverrideInputs());
}

return Status::Success;
}

Status
InferenceRequest::ValidateRequestInputs()
InferenceRequest::ValidateRequestInputs() const
{
const inference::ModelConfig& model_config = model_raw_->Config();
if ((original_inputs_.size() > (size_t)model_config.input_size()) ||
Expand Down Expand Up @@ -1404,6 +1409,59 @@ InferenceRequest::ValidateBytesInputs(
return Status::Success;
}

// Checks that the data type of this request's correlation ID is compatible
// with the data type the model configuration declares for the
// CONTROL_SEQUENCE_CORRID sequence batching control tensor.
//
// Returns Status::Success when no such control tensor name is reported or
// when the types are compatible; returns an INVALID_ARG status describing
// the mismatch otherwise. Any error from looking up the control properties
// is propagated unchanged.
Status
InferenceRequest::ValidateOverrideInputs() const
{
  using CorrIdType = InferenceRequest::SequenceId::DataType;

  const inference::ModelConfig& config = model_raw_->Config();
  const std::string& target_model = ModelName();

  // The CORRID control is looked up as optional ('required' is false).
  std::string corrid_tensor;
  inference::DataType expected_dtype;
  RETURN_IF_ERROR(GetTypedSequenceControlProperties(
      config.sequence_batching(), config.name(),
      inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_CORRID,
      false /* required */, &corrid_tensor, &expected_dtype));

  // No control tensor name was reported, so there is nothing to validate.
  if (corrid_tensor.empty()) {
    return Status::Success;
  }

  // Work out whether the request's correlation ID type conflicts with the
  // configured control data type, remembering the request-side type for the
  // error message.
  bool mismatch = false;
  inference::DataType request_dtype = inference::DataType::TYPE_STRING;
  switch (CorrelationId().Type()) {
    case CorrIdType::STRING:
      // A string correlation ID is only valid for a TYPE_STRING control.
      if (expected_dtype != inference::DataType::TYPE_STRING) {
        mismatch = true;
        request_dtype = inference::DataType::TYPE_STRING;
      }
      break;

    case CorrIdType::UINT64:
      // An unsigned 64-bit correlation ID is accepted for any of the
      // 32/64-bit signed or unsigned integer control types.
      switch (expected_dtype) {
        case inference::DataType::TYPE_UINT64:
        case inference::DataType::TYPE_INT64:
        case inference::DataType::TYPE_UINT32:
        case inference::DataType::TYPE_INT32:
          break;
        default:
          mismatch = true;
          request_dtype = inference::DataType::TYPE_UINT64;
          break;
      }
      break;

    default:
      break;
  }

  if (!mismatch) {
    return Status::Success;
  }

  const std::string request_dtype_name =
      triton::common::DataTypeToProtocolString(request_dtype);
  const std::string expected_dtype_name = std::string(
      triton::common::DataTypeToProtocolString(expected_dtype));

  return Status(
      Status::Code::INVALID_ARG,
      LogRequest() + "sequence batching control '" + corrid_tensor +
          "' data-type is '" + request_dtype_name + "', but model '" +
          target_model + "' expects '" + expected_dtype_name + "'");
}

#ifdef TRITON_ENABLE_STATS

void
Expand Down
4 changes: 3 additions & 1 deletion src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,13 +771,15 @@ class InferenceRequest {
Status Normalize();

// Helper for validating Inputs
Status ValidateRequestInputs();
Status ValidateRequestInputs() const;

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helper for validating that the request correlation ID data type matches
// the data type of the model's sequence batching CORRID control
Status ValidateOverrideInputs() const;

// Helpers for pending request metrics
void IncrementPendingRequestCount();
void DecrementPendingRequestCount();
Expand Down
6 changes: 3 additions & 3 deletions src/sequence_batch_scheduler/sequence_batch_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1343,9 +1343,9 @@ SequenceBatch::SetControlTensors(
auto& seq_corr_id = seq_slot_corrid_override_;
size_t size_p = triton::common::GetDataTypeByteSize(seq_corr_id->DType());
if (seq_corr_id->DType() == inference::DataType::TYPE_STRING) {
// 4 bytes for length of string plus pre-defined max string correlation id
// length in bytes
size_p = 4 + triton::core::STRING_CORRELATION_ID_MAX_LENGTH_BYTES;
// 4 bytes for length of string plus string correlation id length in
// bytes.
size_p = 4 + corrid.StringValue().length();
}

TRITONSERVER_MemoryType memory_type;
Expand Down

0 comments on commit fa343be

Please sign in to comment.