[Unity][lm_support] window kvcache sink (#16240)
* attention sinks with correctness test

* fix override sink
davidpissarra authored Dec 15, 2023
1 parent e1964ec commit cd9445d
Showing 2 changed files with 71 additions and 14 deletions.
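
For orientation before the diff: the new num_attention_sinks argument pins the first few cache rows ("attention sinks", per the paper linked in the code comment) so they are never evicted, while the rest of the cache keeps behaving as a sliding window that wraps around when full. The NumPy sketch below is illustrative only — it is not part of the commit, and the function and variable names are ours — but it mirrors the index arithmetic of WindowOverride shown in the diff.

import numpy as np

def window_override(cache, pos, fill, value, max_cache_size, num_attention_sinks=0):
    """Append `value` rows to `cache`; once full, wrap around without touching the sink rows."""
    n = value.shape[0]
    assert n <= max_cache_size - num_attention_sinks
    num_to_copy = min(n, max_cache_size - pos)  # room left before the end of the window
    if num_to_copy > 0:
        cache[pos : pos + num_to_copy] = value[:num_to_copy]
    fill = min(fill + n, max_cache_size)
    pos = min(pos + n, max_cache_size)
    if num_to_copy < n:  # wrap: the remainder lands just after the sink rows
        remainder = n - num_to_copy
        cache[num_attention_sinks : num_attention_sinks + remainder] = value[num_to_copy:]
        pos = remainder + num_attention_sinks
    return cache, pos, fill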
44 changes: 30 additions & 14 deletions src/runtime/relax_vm/lm_support.cc
@@ -116,10 +116,12 @@ class AttentionKVCacheObj : public Object {
   /*!
    * \brief Append value to the cache, overrides if full.
    * \param value The value to override previous elements.
+   * \param max_cache_size max size of the cache.
+   * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453).
    */
-  void WindowOverride(NDArray value, int64_t max_cache_size) {
+  void WindowOverride(NDArray value, int64_t max_cache_size, int64_t num_attention_sinks = 0) {
     CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
-    CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
+    CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 of value too large";
     // reallocate cache
     if (fill_count + value->shape[0] <= max_cache_size) {
       int64_t reserved_slots = data->shape[0];
@@ -148,20 +150,22 @@ class AttentionKVCacheObj : public Object {
       shape.push_back(data->shape[i]);
     }
     int64_t num_filled_elements = window_attention_current_pos * num_elements_p_entry;
-
-    DLTensor copy_dst = *(data.operator->());
-    copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
-    copy_dst.shape = &shape[0];
-
-    DLTensor copy_src = *(value.operator->());
-    copy_src.byte_offset = 0;
-    copy_src.shape = &shape[0];
-
-    NDArray::CopyFromTo(&copy_src, &copy_dst);
     this->fill_count = std::min(this->fill_count + value->shape[0], max_cache_size);
     this->window_attention_current_pos =
         std::min(this->window_attention_current_pos + value->shape[0], max_cache_size);

+    if (num_elements_to_copy > 0) {
+      DLTensor copy_dst = *(data.operator->());
+      copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
+      copy_dst.shape = &shape[0];
+
+      DLTensor copy_src = *(value.operator->());
+      copy_src.byte_offset = 0;
+      copy_src.shape = &shape[0];
+
+      NDArray::CopyFromTo(&copy_src, &copy_dst);
+    }
+
     // copy the remainder to the beginning of the cache
     if (num_elements_to_copy < value->shape[0]) {
       ICHECK_EQ(this->fill_count, max_cache_size);
@@ -171,7 +175,8 @@
       num_filled_elements = num_elements_to_copy * num_elements_p_entry;

       DLTensor copy_dst = *(data.operator->());
-      copy_dst.byte_offset = 0;
+      copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
+                             ((data->dtype.bits * data->dtype.lanes + 7) / 8);
       copy_dst.shape = &shape[0];

       DLTensor copy_src = *(value.operator->());
@@ -180,7 +185,8 @@
       copy_src.shape = &shape[0];

       NDArray::CopyFromTo(&copy_src, &copy_dst);
-      this->window_attention_current_pos = value->shape[0] - num_elements_to_copy;
+      this->window_attention_current_pos =
+          value->shape[0] - num_elements_to_copy + num_attention_sinks;
     }
   }

@@ -277,6 +283,16 @@ AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
     .set_body_typed(AttentionKVCacheWindowOverride);

+AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache, NDArray value,
+                                                         int64_t max_cache_size,
+                                                         int64_t num_attention_sinks) {
+  cache->WindowOverride(value, max_cache_size, num_attention_sinks);
+  return cache;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    .set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
+
 NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
   return cache->View(shape);
 }
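The registration above exposes the sink-aware override through TVM's global function registry, which is also how the new test below drives it. A minimal usage sketch follows, assuming a TVM build that includes this commit; the cache shape, dtype, and sink count here are arbitrary choices for illustration.

import numpy as np
import tvm

fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

# A 16-slot cache of 2-wide entries, initially empty (fill count 0).
cache = fcreate(
    tvm.nd.array(np.zeros((16, 2), "int32")),
    tvm.runtime.ShapeTuple([16, 2]),
    0,
)

# Append one row, keeping the first 4 slots reserved as attention sinks.
cache = foverride(cache, tvm.nd.array(np.ones((1, 2), "int32")), 16, 4)

# View the filled portion of the cache.
print(fview(cache, tvm.runtime.ShapeTuple((1, 2))).numpy())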
41 changes: 41 additions & 0 deletions tests/python/relax/test_runtime_builtin.py
@@ -217,5 +217,46 @@ def test_attention_kv_cache_window_override():
     ).all()


+def test_attention_kv_cache_window_override_with_sinks():
+    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
+    foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+    num_attention_sinks = 2
+    has_sink = False
+    current_pos = 0
+
+    cache = fcreate(
+        tvm.nd.array(np.full((16, 2), -1).astype("int32")),
+        tvm.runtime.ShapeTuple([16, 2]),
+        current_pos,
+    )
+    np_all_arrays = np.zeros((0, 2)).astype("int32")
+
+    num_steps = 40
+    for i in range(num_steps):
+        np_array = i * np.ones((1, 2)).astype("int32")
+        np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
+        cache = foverride(cache, tvm.nd.array(np_array), 16, num_attention_sinks)
+
+        if has_sink:
+            current_pos = max((current_pos + 1) % 16, num_attention_sinks)
+        else:
+            current_pos += 1
+        has_sink = current_pos >= num_attention_sinks
+
+    res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+
+    # unrotate cache and assert cache matches last 16 elements
+    assert (
+        np.concatenate(
+            (np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 + num_attention_sinks :, :])
+        )
+        == np.concatenate(
+            (res[:num_attention_sinks], res[current_pos:], res[num_attention_sinks:current_pos])
+        )
+    ).all()
+
+
 if __name__ == "__main__":
     tvm.testing.main()
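As a plain-NumPy sanity check of what the assertion above encodes (our summary, not part of the commit): with a 16-slot window, 2 sinks, and 40 single-row appends of the values 0..39, the cache should permanently hold 0 and 1 in its first two rows and the most recent 14 values (26..39) in the remaining ring-buffer slots; the two np.concatenate calls simply unrotate the ring portion before comparing against the reference rows.

import numpy as np

window, sinks, steps = 16, 2, 40
values = np.arange(steps, dtype="int32").repeat(2).reshape(steps, 2)  # the rows appended by the test
expected = np.concatenate((values[:sinks], values[sinks - window:]))  # sink rows [0, 1] + last 14 rows [26..39]
print(expected[:, 0])  # [ 0  1 26 27 ... 38 39]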
