Skip to content

Commit

Permalink
fix rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Sep 26, 2024
1 parent 2dbda67 commit 15ca8cd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
11 changes: 5 additions & 6 deletions mlx/backend/common/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ collapse_contiguous_dims(
}

template <typename StrideT>
std::tuple<std::vector<int>, std::vector<StrideT>>
collapse_contiguous_dims_impl(
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
Expand All @@ -112,24 +111,24 @@ collapse_contiguous_dims_impl(
}
}

return std::make_tuple(collapsed_shape, collapsed_strides);
return std::make_pair(collapsed_shape, collapsed_strides);
}

std::tuple<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}

std::tuple<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}

std::tuple<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
Expand Down
6 changes: 3 additions & 3 deletions mlx/backend/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
}

// The single array version of the above.
std::tuple<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());

Expand Down
5 changes: 1 addition & 4 deletions mlx/backend/metal/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ void unary_op_gpu_inplace(

auto maybe_collapse = [contig, &in, &out]() {
if (!contig) {
auto [shape, strides] = collapse_contiguous_dims(
{in, out},
/* size_cap = */ INT32_MAX);
return std::make_pair(shape, strides[0]);
return collapse_contiguous_dims(in);
} else {
return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
}
Expand Down

0 comments on commit 15ca8cd

Please sign in to comment.