Add tensor_count property for ControlMessage (#2078)
For `ControlMessage`, `msg.tensors().count` is a common pattern, but calling `msg.tensors()` costs more than one might expect. Add a `tensor_count` accessor so the count can be read without that overhead.

Closes #1876 
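
A rough usage sketch of the change (hypothetical shapes; assumes `ControlMessage` and `TensorMemory` are importable from `morpheus.messages`, as the tests in this PR do):

```python
import cupy as cp

from morpheus.messages import ControlMessage, TensorMemory

msg = ControlMessage()
msg.tensors(TensorMemory(count=4, tensors={"probs": cp.zeros((4, 2))}))

count_before = msg.tensors().count  # old pattern: fetches the TensorMemory just to read its count
count_after = msg.tensor_count()    # new pattern: returns the count cached when tensors() was set

assert count_before == count_after == 4
```

Both the C++ and Python implementations cache the count when the tensor memory is assigned, so `tensor_count()` is a plain member read.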

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Yuchen Zhang (https://github.com/yczhang-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #2078
yczhang-nv authored Dec 5, 2024
1 parent 59aeaca commit 32c982e
Showing 17 changed files with 68 additions and 38 deletions.
24 changes: 12 additions & 12 deletions examples/log_parsing/inference.py
@@ -57,16 +57,16 @@ class TritonInferenceLogParsing(TritonInferenceWorker):
     """

     def build_output_message(self, msg: ControlMessage) -> ControlMessage:
-        seq_ids = cp.zeros((msg.tensors().count, 3), dtype=cp.uint32)
-        seq_ids[:, 0] = cp.arange(0, msg.tensors().count, dtype=cp.uint32)
+        seq_ids = cp.zeros((msg.tensor_count(), 3), dtype=cp.uint32)
+        seq_ids[:, 0] = cp.arange(0, msg.tensor_count(), dtype=cp.uint32)
         seq_ids[:, 2] = msg.tensors().get_tensor('seq_ids')[:, 2]

         memory = TensorMemory(
-            count=msg.tensors().count,
+            count=msg.tensor_count(),
             tensors={
-                'confidences': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
-                'labels': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
-                'input_ids': cp.zeros((msg.tensors().count, msg.tensors().get_tensor('input_ids').shape[1])),
+                'confidences': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
+                'labels': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
+                'input_ids': cp.zeros((msg.tensor_count(), msg.tensors().get_tensor('input_ids').shape[1])),
                 'seq_ids': seq_ids
             })

@@ -154,19 +154,19 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
     seq_offset = seq_ids[0, 0].item()
     seq_count = seq_ids[-1, 0].item() + 1 - seq_offset

-    input_ids[batch_offset:inf.tensors().count + batch_offset, :] = inf.tensors().get_tensor('input_ids')
-    out_seq_ids[batch_offset:inf.tensors().count + batch_offset, :] = seq_ids
+    input_ids[batch_offset:inf.tensor_count() + batch_offset, :] = inf.tensors().get_tensor('input_ids')
+    out_seq_ids[batch_offset:inf.tensor_count() + batch_offset, :] = seq_ids

     resp_confidences = res.get_tensor('confidences')
     resp_labels = res.get_tensor('labels')

     # Two scenarios:
-    if (inf.payload().count == inf.tensors().count):
+    if (inf.payload().count == inf.tensor_count()):
         assert seq_count == res.count
-        confidences[batch_offset:inf.tensors().count + batch_offset, :] = resp_confidences
-        labels[batch_offset:inf.tensors().count + batch_offset, :] = resp_labels
+        confidences[batch_offset:inf.tensor_count() + batch_offset, :] = resp_confidences
+        labels[batch_offset:inf.tensor_count() + batch_offset, :] = resp_labels
     else:
-        assert inf.tensors().count == res.count
+        assert inf.tensor_count() == res.count

     mess_ids = seq_ids[:, 0].get().tolist()
@@ -19,6 +19,7 @@

 #include "morpheus/export.h" // for MORPHEUS_EXPORT
 #include "morpheus/messages/meta.hpp" // for MessageMeta
+#include "morpheus/types.hpp"
 #include "morpheus/utilities/json_types.hpp" // for json_t

 #include <pybind11/pytypes.h> // for object, dict, list

@@ -197,6 +198,13 @@ class MORPHEUS_EXPORT ControlMessage
      */
     void tensors(const std::shared_ptr<TensorMemory>& tensor_memory);

+    /**
+     * @brief Get the length of tensors in the tensor memory.
+     *
+     * @return The length of tensors in the tensor memory.
+     */
+    TensorIndex tensor_count();
+
     /**
      * @brief Get the type of task associated with the control message.
      * @return An enum value indicating the task type.

@@ -262,6 +270,7 @@ class MORPHEUS_EXPORT ControlMessage
     ControlMessageType m_cm_type{ControlMessageType::NONE};
     std::shared_ptr<MessageMeta> m_payload{nullptr};
     std::shared_ptr<TensorMemory> m_tensors{nullptr};
+    TensorIndex m_tensor_count{0};

     morpheus::utilities::json_t m_tasks{};
     morpheus::utilities::json_t m_config{};
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/__init__.pyi
@@ -73,6 +73,7 @@ class ControlMessage():
     def task_type(self) -> ControlMessageType: ...
     @typing.overload
     def task_type(self, task_type: ControlMessageType) -> None: ...
+    def tensor_count(self) -> int: ...
     @typing.overload
     def tensors(self) -> TensorMemory: ...
     @typing.overload
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/module.cpp
@@ -290,6 +290,7 @@ PYBIND11_MODULE(messages, _module)
              py::arg("meta"))
         .def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
         .def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
+        .def("tensor_count", &ControlMessage::tensor_count)
         .def("remove_task", &ControlMessage::remove_task, py::arg("task_type"))
         .def("set_metadata", &ControlMessage::set_metadata, py::arg("key"), py::arg("value"))
         .def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type))
15 changes: 11 additions & 4 deletions python/morpheus/morpheus/_lib/src/messages/control.cpp
@@ -59,9 +59,10 @@ ControlMessage::ControlMessage(const morpheus::utilities::json_t& _config) :

 ControlMessage::ControlMessage(const ControlMessage& other)
 {
-    m_cm_type = other.m_cm_type;
-    m_payload = other.m_payload;
-    m_tensors = other.m_tensors;
+    m_cm_type      = other.m_cm_type;
+    m_payload      = other.m_payload;
+    m_tensors      = other.m_tensors;
+    m_tensor_count = other.m_tensor_count;

     m_config = other.m_config;
     m_tasks = other.m_tasks;

@@ -256,7 +257,13 @@ std::shared_ptr<TensorMemory> ControlMessage::tensors()

 void ControlMessage::tensors(const std::shared_ptr<TensorMemory>& tensors)
 {
-    m_tensors = tensors;
+    m_tensors      = tensors;
+    m_tensor_count = tensors ? tensors->count : 0;
 }

+TensorIndex ControlMessage::tensor_count()
+{
+    return m_tensor_count;
+}
+
 ControlMessageType ControlMessage::task_type()
@@ -58,7 +58,7 @@ static ShapeType get_seq_ids(const std::shared_ptr<ControlMessage>& message)
     auto seq_ids = message->tensors()->get_tensor("seq_ids");
     const auto item_size = seq_ids.dtype().item_size();

-    ShapeType host_seq_ids(message->tensors()->count);
+    ShapeType host_seq_ids(message->tensor_count());
     MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(),
                                 item_size,
                                 seq_ids.data(),

@@ -82,7 +82,7 @@ static TensorObject get_tensor(std::shared_ptr<ControlMessage> message, std::str

 static void reduce_outputs(std::shared_ptr<ControlMessage> const& message, TensorMap& output_tensors)
 {
-    if (message->payload()->count() == message->tensors()->count)
+    if (message->payload()->count() == message->tensor_count())
     {
         return;
     }
@@ -21,6 +21,7 @@
 #include "morpheus/messages/control.hpp" // for ControlMessage
 #include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory
 #include "morpheus/messages/meta.hpp" // for MessageMeta
+#include "morpheus/types.hpp"
 #include "morpheus/utilities/json_types.hpp" // for PythonByteContainer

 #include <gtest/gtest.h> // for Message, TestPartResult, AssertionResult, TestInfo

@@ -298,7 +299,8 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
 {
     auto msg = ControlMessage();

-    auto tensorMemory = std::make_shared<TensorMemory>(0);
+    TensorIndex count = 5;
+    auto tensorMemory = std::make_shared<TensorMemory>(count);
     // Optionally, modify tensorMemory here if it has any mutable state to test

     // Set the tensor memory

@@ -309,6 +311,7 @@

     // Verify that the retrieved tensor memory matches what was set
     EXPECT_EQ(tensorMemory, retrievedTensorMemory);
+    EXPECT_EQ(count, msg.tensor_count());
 }

 // Test setting TensorMemory to nullptr
5 changes: 5 additions & 0 deletions python/morpheus/morpheus/messages/control_message.py
@@ -46,6 +46,7 @@ def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = Non

         self._payload: MessageMeta = None
         self._tensors: TensorMemory = None
+        self._tensor_count: int = 0

         self._tasks: dict[str, deque] = defaultdict(deque)
         self._timestamps: dict[str, datetime] = {}

@@ -147,9 +148,13 @@ def payload(self, payload: MessageMeta = None) -> MessageMeta | None:
     def tensors(self, tensors: TensorMemory = None) -> TensorMemory | None:
         if tensors is not None:
             self._tensors = tensors
+            self._tensor_count = tensors.count

         return self._tensors

+    def tensor_count(self) -> int:
+        return self._tensor_count
+
     def task_type(self, new_task_type: ControlMessageType = None) -> ControlMessageType:
         if new_task_type is not None:
             self._type = new_task_type
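
A minimal sketch of the resulting Python-side behavior (assuming the pure-Python `ControlMessage` above, with both classes importable from `morpheus.messages`; shapes are illustrative):

```python
import cupy as cp

from morpheus.messages import ControlMessage, TensorMemory

msg = ControlMessage()
assert msg.tensor_count() == 0  # default before any tensor memory is set

msg.tensors(TensorMemory(count=3, tensors={"input_ids": cp.zeros((3, 5))}))
assert msg.tensor_count() == 3  # cached by the tensors() setter; no second tensors() call
```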
@@ -45,12 +45,12 @@ def __init__(self, inf_queue: ProducerConsumerQueue, c: Config):
         self._seq_length = c.feature_length

     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._seq_length)
+        return (msg.tensor_count(), self._seq_length)

     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):

         def tmp(batch: ControlMessage, f):
-            count = batch.tensors().count
+            count = batch.tensor_count()
             f(TensorMemory(
                 count=count,
                 tensors={'probs': cp.zeros((count, self._seq_length), dtype=cp.float32)},
6 changes: 3 additions & 3 deletions python/morpheus/morpheus/stages/inference/inference_stage.py
@@ -244,7 +244,7 @@ def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future):
                 nonlocal outstanding_requests
                 nonlocal batch_offset
                 mess = self._convert_one_response(output_message, inner_batch, resp, batch_offset)
-                batch_offset += inner_batch.tensors().count
+                batch_offset += inner_batch.tensor_count()
                 outstanding_requests -= 1

                 batch_future.set_result(mess)

@@ -359,13 +359,13 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset

         # Two scenarios:
-        if (inf.payload().count == inf.tensors().count):
+        if (inf.payload().count == inf.tensor_count()):
             assert seq_count == res.count

             # In message and out message have same count. Just use probs as is
             probs[seq_offset:seq_offset + seq_count, :] = resp_probs
         else:
-            assert inf.tensors().count == res.count
+            assert inf.tensor_count() == res.count

             mess_ids = seq_ids[:, 0].get().tolist()
@@ -73,7 +73,7 @@ def init(self):
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
         input_ids = msg.tensors().get_tensor("input_ids")
         input_mask = msg.tensors().get_tensor("input_mask")
-        count = msg.tensors().count
+        count = msg.tensor_count()
         # If we haven't cached the output dimension, do that here
         if (not self._output_size):
             test_intput_ids_shape = (self._max_batch_size, ) + input_ids.shape[1:]

@@ -91,7 +91,7 @@ def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
         input_ids = batch.tensors().get_tensor("input_ids")
         input_mask = batch.tensors().get_tensor("input_mask")
-        count = batch.tensors().count
+        count = batch.tensor_count()

         # convert from cupy to torch tensor using dlpack
         input_ids = from_dlpack(input_ids.astype(cp.float).toDlpack()).type(torch.long)
@@ -568,7 +568,7 @@ def create_wrapper():
             raise ex

     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._outputs[list(self._outputs.keys())[0]].shape[1])
+        return (msg.tensor_count(), self._outputs[list(self._outputs.keys())[0]].shape[1])

     def _build_response(
         self,
@@ -138,7 +138,7 @@ def _calc_drift(self, msg: ControlMessage):
         for label in range(len(self._labels), shifted.shape[1]):
             self._labels.append(str(label))

-        count = msg.tensors().count
+        count = msg.tensor_count()

        for i in list(range(0, count, self._batch_size)):
            start = i
@@ -41,7 +41,7 @@ def check_inf_message(msg: ControlMessage,
                      expected_input__0: cp.ndarray):
     assert isinstance(msg, ControlMessage)
     assert msg.payload().count == expected_mess_count
-    assert msg.tensors().count == expected_count
+    assert msg.tensor_count() == expected_count

     df = msg.payload().get_data()
     assert 'flow_id' in df
4 changes: 2 additions & 2 deletions tests/examples/log_parsing/test_inference.py
@@ -138,7 +138,7 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C
     msg = worker.build_output_message(input_msg)
     assert msg.payload() is input_msg.payload()
     assert msg.payload().count == mess_count
-    assert msg.tensors().count == count
+    assert msg.tensor_count() == count

     assert set(msg.tensors().tensor_names).issuperset(('confidences', 'labels', 'input_ids', 'seq_ids'))
     assert msg.tensors().get_tensor('confidences').shape == (count, 2)

@@ -187,7 +187,7 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis
     assert isinstance(output_msg, ControlMessage)
     assert output_msg.payload() is input_inf.payload()
     assert output_msg.payload().count == mess_count
-    assert output_msg.tensors().count == count
+    assert output_msg.tensor_count() == count

     assert (output_msg.tensors().get_tensor('seq_ids') == input_inf.tensors().get_tensor('seq_ids')).all()
     assert (output_msg.tensors().get_tensor('input_ids') == input_inf.tensors().get_tensor('input_ids')).all()
16 changes: 10 additions & 6 deletions tests/morpheus/messages/test_control_message.py
@@ -38,7 +38,7 @@ def _verify_metadata(msg: messages.ControlMessage, metadata: dict):

 @pytest.mark.gpu_and_cpu_mode
 def test_control_message_init(dataset: DatasetManager):
-    # Explicitly performing copies of the metadata, config and the dataframe, to ensure tha the original data is not
+    # Explicitly performing copies of the metadata, config and the dataframe, to ensure that the original data is not
     # being modified in place in some way.
     msg = messages.ControlMessage()
     assert msg.get_metadata() == {}  # pylint: disable=use-implicit-booleaness-not-comparison

@@ -318,9 +318,9 @@ def test_tensors_setting_and_getting(config: Config):

     message.tensors(tensor_memory)

-    retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == data["input_ids"].shape[0], "Tensor count mismatch."
+    assert message.tensor_count() == data["input_ids"].shape[0], "Tensor count mismatch."

+    retrieved_tensors = message.tensors()
     for key, val in data.items():
         assert array_pkg.allclose(retrieved_tensors.get_tensor(key), val), f"Mismatch in tensor data for {key}."

@@ -363,6 +363,7 @@ def test_tensor_manipulation_after_retrieval(config: Config):
     new_tensor = array_pkg.array([4, 5, 6])
     retrieved_tensors.set_tensor("new_tensor", new_tensor)

+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"
     assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"), new_tensor), "New tensor data mismatch."

@@ -389,8 +390,9 @@ def test_tensor_update(config: Config):

     tensor_memory.set_tensors(new_tensors)

-    updated_tensors = message.tensors()
+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"

+    updated_tensors = message.tensors()
     for key, val in new_tensors.items():
         assert array_pkg.allclose(updated_tensors.get_tensor(key), val), f"Mismatch in updated tensor data for {key}."

@@ -408,6 +410,7 @@ def test_update_individual_tensor(config: Config):
     tensor_memory.set_tensor("input_ids", update_data["input_ids"])
     retrieved_tensors = message.tensors()

+    assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch"
     # Check updated tensor
     assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
                               update_data["input_ids"]), "Input IDs update mismatch."

@@ -422,8 +425,9 @@ def test_behavior_with_empty_tensors():
     tensor_memory = TensorMemory(count=0)
     message.tensors(tensor_memory)

+    assert message.tensor_count() == 0, "Tensor count should be 0 for empty tensor memory."
+
     retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == 0, "Tensor count should be 0 for empty tensor memory."
     assert len(retrieved_tensors.tensor_names) == 0, "There should be no tensor names for empty tensor memory."

@@ -442,8 +446,8 @@ def test_consistency_after_multiple_operations(config: Config):
     new_tensor = {"new_tensor": array_pkg.array([7, 8, 9])}
     tensor_memory.set_tensor("new_tensor", new_tensor["new_tensor"])

+    assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch."
     retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == 3, "Tensor count mismatch after multiple operations."
     assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
                               array_pkg.array([4, 5, 6])), "Mismatch in input_ids after update."
     assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"),
2 changes: 1 addition & 1 deletion tests/morpheus/stages/test_inference_stage.py
@@ -110,7 +110,7 @@ def test_convert_one_response():
     cm = InferenceStageT._convert_one_response(output, inf, res, batch_offset)
     assert cm.payload() == inf.payload()
     assert cm.payload().count == 4
-    assert cm.tensors().count == 4
+    assert cm.tensor_count() == 4
     assert cp.all(cm.tensors().get_tensor("probs") == res.get_tensor("probs"))

     # Test for the second branch
