Skip to content

Commit

Permalink
matrix chain multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhufei Chu committed Sep 15, 2024
1 parent 24a6ac0 commit 6e0f1d8
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 21 deletions.
67 changes: 63 additions & 4 deletions docs/stp_compute.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
STP Computation
===============

Semi-Tensor Product (STP) Computation based on Eigen Library
The basic Semi-Tensor Product (STP) Computation based on Eigen Library

Basic STP Computation
----------------------

Native Definition
-----------------
^^^^^^^^^^^^^^^^^^^^^

Given two matrices :math:`A_{m \times n}` and :math:`B_{p \times q}`, the STP
of :math:`A` and :math:`B` is defined as
Expand Down Expand Up @@ -58,7 +61,7 @@ According to the definition,
\end{align}
Copy Definition
-----------------
^^^^^^^^^^^^^^^^^^^^^
Continue with the above example, we can view :math:`A` be composed by two
submatrices :math:`A_{left}=\begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix}`
and :math:`A_{right} = \begin{bmatrix}0 & 0 \\ 1 & 1\end{bmatrix}`. When we
Expand All @@ -69,7 +72,7 @@ partial result; otherwise, we copy :math:`A_{right}`. One can verify the
results are exactly the same as the ones computed by native definition.

Functions
----------------
^^^^^^^^^^^^^^^^^^^^^
In header file ``stp/stp_eigen.hpp``, we provide function::

matrix semi_tensor_product( const matrix& A, const matrix& B
Expand All @@ -96,3 +99,59 @@ Example
auto result = stp::semi_tensor_product( A, B, true, stp::stp_method::native_method );

One can find more examples or test cases in ``examples/stp_eigen.cpp`` and ``test/stp_eigen.cpp``.

Matrix Chain STP Computation
----------------------------
When we have :math:`n` matrices multiplication and :math:`n \ge 3`, we call
this as matrix chain STP computation.

Sequence
^^^^^^^^^^^^^^^^^^^^^
The matrix are multiplied one by one in sequence. For example, we have 4
matrices :math:`A`, :math:`B`, :math:`C`, and :math:`D`. The parenthesis of
the matrix chain is

.. math::
ABCD = (((AB)C)D).
Dynamic Programming
^^^^^^^^^^^^^^^^^^^^^
As the computation complexity is distinct if we use different parenthesis
method, we also propose a dynamic programming method for matrix chain STP
computation. We may have an optimal parenthesis for the matrix chain as

.. math::
ABCD = ((AB)(CD)).
Multi-threads
^^^^^^^^^^^^^^^^^^^^^
Once we obtained the computation orders based on dynamic programming, the
computation can also invoke multi-threads to accerlerate.

Functions
^^^^^^^^^^^^^^^^^^^^^
In header file ``stp/stp_eigen.hpp``, we provide function::

matrix matrix_chain_multiply( const matrix_chain& mc,
const bool verbose = false,
const mc_multiply_method method = mc_multiply_method::dynamic_programming )

to compute the STP of matrix chain :math:`mc`, where toggle ``verbose`` is off and toggle ``mc_multiply_method``
is used by the dynamic programming by default.

Example

.. code-block:: c++

matrix_chain mc;

//default
auto result = stp::matrix_chain_multiply( mc );

//print verbose information
auto result = stp::matrix_chain_multiply( mc, true );

//use sequence method for matrix chain STP computation
auto result = stp::matrix_chain_multiply( mc, false, mc_multiply_method::sequence );

One can find more test cases in ``test/stp_eigen.cpp``.
17 changes: 4 additions & 13 deletions examples/mc_mul_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,12 @@ void print(const matrix_chain& mc)
}
}

matrix random_generation(int row, int col)
{
matrix result(row, col);
for(int i = 0; i < row; i++)
for(int j = 0; j < col; j++)
result(i, j) = rand() % 2;
return result;
}

void test1()
{
matrix_chain mc;
for(int i = 0; i < 22; i++)
{
mc.push_back(random_generation(2, 4));
mc.push_back(stp::matrix_random_generation(2, 4));
}
// print(mc);
matrix r1 = stp::matrix_chain_multiply(mc, true);
Expand All @@ -40,7 +31,7 @@ void test2()
matrix_chain mc;
for(int i = 0; i < 22; i++)
{
mc.push_back(random_generation(4, 2));
mc.push_back(stp::matrix_random_generation(4, 2));
}
// print(mc);
matrix r1 = stp::matrix_chain_multiply(mc, true);
Expand All @@ -53,8 +44,8 @@ void test3()
matrix_chain mc;
for(int i = 0; i < 100; i++)
{
if(rand() % 2 == 1) mc.push_back(random_generation(4, 2));
else mc.push_back(random_generation(2, 4));
if(rand() % 2 == 1) mc.push_back(stp::matrix_random_generation(4, 2));
else mc.push_back(stp::matrix_random_generation(2, 4));
}
// print(mc);
matrix r1 = stp::matrix_chain_multiply(mc, true);
Expand Down
7 changes: 5 additions & 2 deletions include/stp/stp_eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ namespace stp

if(verbose)
{
print();
report();
}

return result;
Expand Down Expand Up @@ -492,11 +492,13 @@ namespace stp
return stacks[0];
}

void print()
void report()
{
std::cout << "------------------Matrix Chain STP Computation-----------------\n";
if( method == mc_multiply_method::dynamic_programming )
{
std::cout << "Use dynamic programming method for matrix chain multiply.\n";
std::cout << "The parenthesis are added as shown in the following.\n";
for( int t : orders )
{
if( t == -1 ) std::cout << "(";
Expand All @@ -511,6 +513,7 @@ namespace stp
std::cout << "Use sequence method for matrix chain multiply.\n";
std::cout << "Total time: " << to_millisecond( time ) << "ms\n";
}
std::cout << "-------------------------------------------------------------\n";
}

private:
Expand Down
45 changes: 43 additions & 2 deletions test/stp_eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,47 @@ TEST_CASE( "STP calculation by two methods", "[eigen]" )
int p2 = 4 << 5;
matrix A2 = stp::matrix_random_generation( m, n2 );
matrix B2 = stp::matrix_random_generation( p2, q );
CHECK( stp::semi_tensor_product( A2, B2, false, stp::stp_method::copy_method ) ==
stp::semi_tensor_product( A2, B2, false, stp::stp_method::native_method ) );

matrix r1 = stp::semi_tensor_product( A2, B2, false, stp::stp_method::copy_method );
matrix r2 = stp::semi_tensor_product( A2, B2, false, stp::stp_method::native_method );
CHECK( r1 == r2 );
}

TEST_CASE( "Matrix chain STP calculation by two methods", "[eigen]" )
{
matrix_chain mc1, mc2, mc3;

for( int i = 0; i < 100; i++ )
{
if( rand() % 2 == 1 )
{
mc1.push_back( matrix_random_generation( 4, 2 ) );
}
else
{
mc1.push_back( matrix_random_generation( 2, 4 ) );
}
}

matrix r1 = stp::matrix_chain_multiply( mc1, false, stp::mc_multiply_method::dynamic_programming );
matrix r2 = stp::matrix_chain_multiply( mc1, false, stp::mc_multiply_method::sequence );
CHECK( r1 == r2 );

for( int i = 0; i < 22; i++ )
{
mc2.push_back( stp::matrix_random_generation( 2, 4 ) );
}

matrix r3 = stp::matrix_chain_multiply( mc2, false, stp::mc_multiply_method::dynamic_programming );
matrix r4 = stp::matrix_chain_multiply( mc2, false, stp::mc_multiply_method::sequence );
CHECK( r3 == r4 );

for( int i = 0; i < 22; i++ )
{
mc3.push_back( stp::matrix_random_generation( 4, 2 ) );
}

matrix r5 = stp::matrix_chain_multiply( mc3, false, stp::mc_multiply_method::dynamic_programming );
matrix r6 = stp::matrix_chain_multiply( mc3, false, stp::mc_multiply_method::sequence );
CHECK( r5 == r6 );
}

0 comments on commit 6e0f1d8

Please sign in to comment.