Skip to content

Commit

Permalink
[Math/Matrix] Fixed the matrix/matrix multiplication operator
Browse files Browse the repository at this point in the history
- The left-hand side matrix had a height equal to the right-hand side's

- Moved the vector/matrix multiplication operator to Matrix.hpp/inl, as it makes more sense there
  • Loading branch information
Razakhel committed Dec 15, 2023
1 parent 25bc6ea commit de0574f
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 134 deletions.
57 changes: 34 additions & 23 deletions include/RaZ/Math/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

namespace Raz {

template <typename T, std::size_t Size>
class Vector;

template <typename T, std::size_t W, std::size_t H>
class Matrix;

template <typename T, std::size_t Size>
class Vector;

template <typename T, std::size_t W, std::size_t H>
std::ostream& operator<<(std::ostream& stream, const Matrix<T, W, H>& mat);

Expand Down Expand Up @@ -61,16 +61,16 @@ class Matrix {
static constexpr Matrix identity() noexcept;
/// Constructs a matrix from the given row vectors.
/// \note All vectors must be of the same inner type as the matrix's, and must have a size equal to the matrix's width.
/// \tparam Vecs Types of the vectors to construct the matrix with.
/// \tparam VecsTs Types of the vectors to construct the matrix with.
/// \param vecs Row vectors to construct the matrix with.
template <typename... Vecs>
static constexpr Matrix fromRows(Vecs&&... vecs) noexcept;
template <typename... VecsTs>
static constexpr Matrix fromRows(VecsTs&&... vecs) noexcept;
/// Constructs a matrix from the given column vectors.
/// \note All vectors must be of the same inner type as the matrix's, and must have a size equal to the matrix's height.
/// \tparam Vecs Types of the vectors to construct the matrix with.
/// \tparam VecsTs Types of the vectors to construct the matrix with.
/// \param vecs Column vectors to construct the matrix with.
template <typename... Vecs>
static constexpr Matrix fromColumns(Vecs&&... vecs) noexcept;
template <typename... VecsTs>
static constexpr Matrix fromColumns(VecsTs&&... vecs) noexcept;

/// Transposed matrix computation.
/// \return Transposed matrix.
Expand Down Expand Up @@ -211,36 +211,47 @@ class Matrix {
template <std::size_t WI, std::size_t HI, typename T2, typename... Args>
constexpr void setValues(T2&& val, Args&&... args) noexcept;

template <typename Vec, typename... Vecs>
constexpr void setRows(Vec&& vec, Vecs&&... args) noexcept;
template <typename VecT, typename... VecsTs>
constexpr void setRows(VecT&& vec, VecsTs&&... args) noexcept;

template <typename Vec, typename... Vecs>
constexpr void setColumns(Vec&& vec, Vecs&&... args) noexcept;
template <typename VecT, typename... VecsTs>
constexpr void setColumns(VecT&& vec, VecsTs&&... args) noexcept;

std::array<T, W * H> m_data {};
};

/// Matrix-matrix multiplication operator.
/// \tparam W Width of the left-hand matrix.
/// \tparam H Height of the left-hand matrix & width of the resulting one.
/// \tparam WI Width of the right-hand matrix & height of the resulting one.
/// \tparam HI Height of the right-hand matrix.
/// \param mat1 Left-hand matrix.
/// \param mat2 Right-hand matrix.
/// \tparam T Type of the matrices' data.
/// \tparam WL Width of the left-hand side matrix.
/// \tparam HL Height of the left-hand side matrix & width of the resulting one.
/// \tparam WR Width of the right-hand side matrix & height of the resulting one.
/// \tparam HR Height of the right-hand side matrix.
/// \param mat1 Left-hand side matrix.
/// \param mat2 Right-hand side matrix.
/// \return Result of the multiplied matrices.
template <typename T, std::size_t W, std::size_t H, std::size_t WI, std::size_t HI>
constexpr Matrix<T, H, WI> operator*(const Matrix<T, WI, H>& mat1, const Matrix<T, WI, HI>& mat2) noexcept;
template <typename T, std::size_t WL, std::size_t HL, std::size_t WR, std::size_t HR>
constexpr Matrix<T, HL, WR> operator*(const Matrix<T, WL, HL>& mat1, const Matrix<T, WR, HR>& mat2) noexcept;

/// Matrix-vector multiplication operator (assumes the vector to be vertical).
/// \tparam T Type of the matrix's & vector's data.
/// \tparam W Width of the matrix & size of the input vector.
/// \tparam H Height of the matrix & size of the resulting vector.
/// \param mat Left-hand matrix.
/// \param vec Right-hand vector.
/// \param mat Left-hand side matrix.
/// \param vec Right-hand side vector.
/// \return Result of the matrix-vector multiplication.
template <typename T, std::size_t W, std::size_t H>
constexpr Vector<T, H> operator*(const Matrix<T, W, H>& mat, const Vector<T, W>& vec) noexcept;

/// Vector-matrix multiplication operator (assumes the vector to be horizontal).
/// \tparam T Type of the vector's & matrix's data.
/// \tparam W Width of the matrix & size of the resulting vector.
/// \tparam H Height of the matrix & size of the input vector.
/// \param vec Left-hand side vector.
/// \param mat Right-hand side matrix.
/// \return Result of the vector-matrix multiplication.
template <typename T, std::size_t W, std::size_t H>
constexpr Vector<T, W> operator*(const Vector<T, H>& vec, const Matrix<T, W, H>& mat) noexcept;

// Aliases

template <typename T> using Mat2 = Matrix<T, 2, 2>;
Expand Down
117 changes: 66 additions & 51 deletions include/RaZ/Math/Matrix.inl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ constexpr Matrix<T, W, H>::Matrix(const Matrix<T, W + 1, H + 1>& mat) noexcept {
std::size_t widthStride = 0;

for (std::size_t heightIndex = 0; heightIndex < H; ++heightIndex) {
std::size_t resIndex = heightIndex * W;
const std::size_t resIndex = heightIndex * W;

for (std::size_t widthIndex = 0; widthIndex < W; ++widthIndex) {
const std::size_t finalIndex = resIndex + widthIndex;
Expand Down Expand Up @@ -219,7 +219,7 @@ template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::identity() noexcept {
static_assert(W == H, "Error: Matrix must be a square one.");

Matrix<T, W, H> res;
Matrix res;

for (std::size_t diagIndex = 0; diagIndex < W; ++diagIndex)
res[diagIndex * W + diagIndex] = 1.f;
Expand All @@ -228,22 +228,22 @@ constexpr Matrix<T, W, H> Matrix<T, W, H>::identity() noexcept {
}

template <typename T, std::size_t W, std::size_t H>
template <typename... Vecs>
constexpr Matrix<T, W, H> Matrix<T, W, H>::fromRows(Vecs&&... vecs) noexcept {
static_assert(sizeof...(Vecs) == H, "Error: A Matrix can't be constructed with more or less vectors than it can hold.");
template <typename... VecsTs>
constexpr Matrix<T, W, H> Matrix<T, W, H>::fromRows(VecsTs&&... vecs) noexcept {
static_assert(sizeof...(VecsTs) == H, "Error: A Matrix can't be constructed with more or less vectors than it can hold.");

Matrix res;
res.setRows(std::forward<Vecs>(vecs)...);
res.setRows(std::forward<VecsTs>(vecs)...);
return res;
}

template <typename T, std::size_t W, std::size_t H>
template <typename... Vecs>
constexpr Matrix<T, W, H> Matrix<T, W, H>::fromColumns(Vecs&&... vecs) noexcept {
static_assert(sizeof...(Vecs) == W, "Error: A Matrix can't be constructed with more or less vectors than it can hold.");
template <typename... VecsTs>
constexpr Matrix<T, W, H> Matrix<T, W, H>::fromColumns(VecsTs&&... vecs) noexcept {
static_assert(sizeof...(VecsTs) == W, "Error: A Matrix can't be constructed with more or less vectors than it can hold.");

Matrix res;
res.setColumns(std::forward<Vecs>(vecs)...);
res.setColumns(std::forward<VecsTs>(vecs)...);
return res;
}

Expand Down Expand Up @@ -306,7 +306,7 @@ constexpr Vector<T, H> Matrix<T, W, H>::recoverColumn(std::size_t columnIndex) c
}

template <typename T, std::size_t W, std::size_t H>
constexpr bool Matrix<T, W, H>::strictlyEquals(const Matrix<T, W, H>& mat) const noexcept {
constexpr bool Matrix<T, W, H>::strictlyEquals(const Matrix& mat) const noexcept {
return std::equal(m_data.cbegin(), m_data.cend(), mat.getData().cbegin());
}

Expand All @@ -322,124 +322,124 @@ constexpr std::size_t Matrix<T, W, H>::hash(std::size_t seed) const noexcept {

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator+(const Matrix& mat) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res += mat;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator+(T val) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res += val;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator-(const Matrix& mat) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res -= mat;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator-(T val) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res -= val;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator%(const Matrix& mat) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res %= mat;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator*(T val) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res *= val;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator/(const Matrix& mat) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res /= mat;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H> Matrix<T, W, H>::operator/(T val) const noexcept {
Matrix<T, W, H> res = *this;
Matrix res = *this;
res /= val;
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator+=(const Matrix<T, W, H>& mat) noexcept {
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator+=(const Matrix& mat) noexcept {
for (std::size_t i = 0; i < m_data.size(); ++i)
m_data[i] += mat[i];
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator+=(T val) noexcept {
for (T& it : m_data)
it += val;
for (T& elt : m_data)
elt += val;
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator-=(const Matrix<T, W, H>& mat) noexcept {
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator-=(const Matrix& mat) noexcept {
for (std::size_t i = 0; i < m_data.size(); ++i)
m_data[i] -= mat[i];
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator-=(T val) noexcept {
for (T& it : m_data)
it -= val;
for (T& elt : m_data)
elt -= val;
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator%=(const Matrix<T, W, H>& mat) noexcept {
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator%=(const Matrix& mat) noexcept {
for (std::size_t i = 0; i < m_data.size(); ++i)
m_data[i] *= mat[i];
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator*=(T val) noexcept {
for (T& it : m_data)
it *= val;
for (T& elt : m_data)
elt *= val;
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator/=(const Matrix<T, W, H>& mat) noexcept {
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator/=(const Matrix& mat) noexcept {
for (std::size_t i = 0; i < m_data.size(); ++i)
m_data[i] /= mat[i];
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator/=(T val) noexcept {
for (T& it : m_data)
it /= val;
for (T& elt : m_data)
elt /= val;
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator*=(const Matrix<T, W, H>& mat) noexcept {
constexpr Matrix<T, W, H>& Matrix<T, W, H>::operator*=(const Matrix& mat) noexcept {
*this = *this * mat;
return *this;
}

template <typename T, std::size_t W, std::size_t H>
constexpr bool Matrix<T, W, H>::operator==(const Matrix<T, W, H>& mat) const noexcept {
constexpr bool Matrix<T, W, H>::operator==(const Matrix& mat) const noexcept {
if constexpr (std::is_floating_point_v<T>)
return FloatUtils::areNearlyEqual(*this, mat);
else
Expand Down Expand Up @@ -488,46 +488,46 @@ constexpr void Matrix<T, W, H>::setValues(T2&& val, Args&&... args) noexcept {
}

template <typename T, std::size_t W, std::size_t H>
template <typename Vec, typename... Vecs>
constexpr void Matrix<T, W, H>::setRows(Vec&& vec, Vecs&&... args) noexcept {
static_assert(std::is_same_v<std::decay_t<Vec>, Vector<T, W>>, "Error: Rows must all be vectors of the same type & size.");
template <typename VecT, typename... VecsTs>
constexpr void Matrix<T, W, H>::setRows(VecT&& vec, VecsTs&&... args) noexcept {
static_assert(std::is_same_v<std::decay_t<VecT>, Vector<T, W>>, "Error: Rows must all be vectors of the same type & size.");

constexpr std::size_t firstIndex = H - sizeof...(args) - 1;

for (std::size_t widthIndex = 0; widthIndex < W; ++widthIndex)
m_data[firstIndex + widthIndex * H] = vec[widthIndex];
m_data[firstIndex + widthIndex * H] = std::forward<VecT>(vec)[widthIndex];

if constexpr (sizeof...(args) > 0)
setRows(std::forward<Vecs>(args)...);
setRows(std::forward<VecsTs>(args)...);
}

template <typename T, std::size_t W, std::size_t H>
template <typename Vec, typename... Vecs>
constexpr void Matrix<T, W, H>::setColumns(Vec&& vec, Vecs&&... args) noexcept {
static_assert(std::is_same_v<std::decay_t<Vec>, Vector<T, H>>, "Error: Columns must all be vectors of the same type & size.");
template <typename VecT, typename... VecsTs>
constexpr void Matrix<T, W, H>::setColumns(VecT&& vec, VecsTs&&... args) noexcept {
static_assert(std::is_same_v<std::decay_t<VecT>, Vector<T, H>>, "Error: Columns must all be vectors of the same type & size.");

constexpr std::size_t firstIndex = H * (W - sizeof...(args) - 1);

for (std::size_t heightIndex = 0; heightIndex < H; ++heightIndex)
m_data[firstIndex + heightIndex] = vec[heightIndex];
m_data[firstIndex + heightIndex] = std::forward<VecT>(vec)[heightIndex];

if constexpr (sizeof...(args) > 0)
setColumns(std::forward<Vecs>(args)...);
setColumns(std::forward<VecsTs>(args)...);
}

template <typename T, std::size_t W, std::size_t H, std::size_t WI, std::size_t HI>
constexpr Matrix<T, H, WI> operator*(const Matrix<T, W, H>& mat1, const Matrix<T, WI, HI>& mat2) noexcept {
static_assert(W == HI, "Error: The left-hand matrix's width must be equal to the right-hand matrix's height.");
template <typename T, std::size_t WL, std::size_t HL, std::size_t WR, std::size_t HR>
constexpr Matrix<T, HL, WR> operator*(const Matrix<T, WL, HL>& mat1, const Matrix<T, WR, HR>& mat2) noexcept {
static_assert(WL == HR, "Error: The left-hand side matrix's width must be equal to the right-hand side matrix's height.");

Matrix<T, H, WI> res;
Matrix<T, HL, WR> res;

for (std::size_t widthIndex = 0; widthIndex < WI; ++widthIndex) {
const std::size_t finalWidthIndex = widthIndex * H;
for (std::size_t widthIndex = 0; widthIndex < WR; ++widthIndex) {
const std::size_t finalWidthIndex = widthIndex * HL;

for (std::size_t heightIndex = 0; heightIndex < H; ++heightIndex) {
for (std::size_t heightIndex = 0; heightIndex < HL; ++heightIndex) {
T& val = res[finalWidthIndex + heightIndex];

for (std::size_t stride = 0; stride < W; ++stride)
for (std::size_t stride = 0; stride < WL; ++stride)
val += mat1.getElement(stride, heightIndex) * mat2.getElement(widthIndex, stride);
}
}
Expand All @@ -550,4 +550,19 @@ constexpr Vector<T, H> operator*(const Matrix<T, W, H>& mat, const Vector<T, W>&
return res;
}

template <typename T, std::size_t W, std::size_t H>
constexpr Vector<T, W> operator*(const Vector<T, H>& vec, const Matrix<T, W, H>& mat) noexcept {
// This multiplication is made assuming the vector to be horizontal
Vector<T, W> res;

for (std::size_t widthIndex = 0; widthIndex < W; ++widthIndex) {
const std::size_t finalWidthIndex = widthIndex * H;

for (std::size_t heightIndex = 0; heightIndex < H; ++heightIndex)
res[widthIndex] += vec[heightIndex] * mat[finalWidthIndex + heightIndex];
}

return res;
}

} // namespace Raz
2 changes: 1 addition & 1 deletion include/RaZ/Math/Quaternion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Quaternion {

/// Creates a quaternion representing an identity transformation.
/// \return Identity quaternion.
static constexpr Quaternion identity() noexcept { return Quaternion<T>(1, 0, 0, 0); }
static constexpr Quaternion identity() noexcept { return Quaternion(1, 0, 0, 0); }

/// Computes the dot product between quaternions.
/// \param quat Quaternion to compute the dot product with.
Expand Down
Loading

0 comments on commit de0574f

Please sign in to comment.