-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More efficient Matrix data structure #45
base: brakedown
Are you sure you want to change the base?
Changes from all commits
04f4296
ddfd74d
8e0de99
453df5b
e20f67e
cf2af29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
|
||
use ark_ff::Field; | ||
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; | ||
use ark_std::vec::Vec; | ||
|
||
/// Takes as input a struct, and converts them to a series of bytes. All traits | ||
/// that implement `CanonicalSerialize` can be automatically converted to bytes | ||
|
@@ -47,14 +47,22 @@ | |
|
||
#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)] | ||
#[derivative(Default(bound = ""), Clone(bound = ""), Debug(bound = ""))] | ||
pub struct Matrix<F: Field> { | ||
pub struct RowMajorMatrix<F: Field> { | ||
pub(crate) n: usize, | ||
pub(crate) m: usize, | ||
entries: Vec<Vec<F>>, | ||
rows: Vec<Vec<F>>, | ||
} | ||
|
||
impl<F: Field> Matrix<F> { | ||
/// Returns a Matrix of dimensions n x m given a list of n * m field elements. | ||
#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)] | ||
#[derivative(Default(bound = ""), Clone(bound = ""), Debug(bound = ""))] | ||
pub struct ColumnMajorMatrix<F: Field> { | ||
pub(crate) n: usize, | ||
pub(crate) m: usize, | ||
cols: Vec<Vec<F>>, | ||
} | ||
|
||
impl<F: Field> RowMajorMatrix<F> { | ||
/// Returns a RowMajorMatrix of dimensions n x m given a list of n * m field elements. | ||
/// The list should be ordered row-first, i.e. [a11, ..., a1m, a21, ..., a2m, ...]. | ||
/// | ||
/// # Panics | ||
|
@@ -70,21 +78,37 @@ | |
); | ||
|
||
// TODO more efficient to run linearly? | ||
let entries: Vec<Vec<F>> = (0..n) | ||
let rows: Vec<Vec<F>> = (0..n) | ||
.map(|row| (0..m).map(|col| entry_list[m * row + col]).collect()) | ||
.collect(); | ||
|
||
Self { n, m, entries } | ||
Self { n, m, rows } | ||
} | ||
|
||
/// Returns self as a list of rows | ||
pub(crate) fn rows(&self) -> &Vec<Vec<F>> { | ||
&self.rows | ||
} | ||
|
||
/// Returns the entry in position (i, j). **Indexing starts at 0 in both coordinates**, | ||
/// i.e. the first element is in position (0, 0) and the last one in (n - 1, j - 1), | ||
/// where n and m are the number of rows and columns, respectively. | ||
/// | ||
/// Index bound checks are waived for efficiency and behaviour under invalid indexing is undefined | ||
#[cfg(test)] | ||
pub(crate) fn entry(&self, i: usize, j: usize) -> F { | ||
self.rows[i][j] | ||
} | ||
|
||
/// Returns a Matrix given a list of its rows, each in turn represented as a list of field elements. | ||
/// Returns a RowMajorMatrix given a list of its rows, each in turn represented as a list of field elements. | ||
/// | ||
/// # Panics | ||
/// Panics if the sub-lists do not all have the same length. | ||
pub(crate) fn new_from_rows(row_list: Vec<Vec<F>>) -> Self { | ||
let m = row_list[0].len(); | ||
#[cfg(test)] | ||
pub(crate) fn new_from_rows(row_major: Vec<Vec<F>>) -> Self { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only used in testing now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can be removed / refactor tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that |
||
let m = row_major[0].len(); | ||
|
||
for row in row_list.iter().skip(1) { | ||
for row in row_major.iter().skip(1) { | ||
assert_eq!( | ||
row.len(), | ||
m, | ||
|
@@ -93,34 +117,12 @@ | |
} | ||
|
||
Self { | ||
n: row_list.len(), | ||
n: row_major.len(), | ||
m, | ||
entries: row_list, | ||
rows: row_major, | ||
} | ||
} | ||
|
||
/// Returns the entry in position (i, j). **Indexing starts at 0 in both coordinates**, | ||
/// i.e. the first element is in position (0, 0) and the last one in (n - 1, j - 1), | ||
/// where n and m are the number of rows and columns, respectively. | ||
/// | ||
/// Index bound checks are waived for efficiency and behaviour under invalid indexing is undefined | ||
#[cfg(test)] | ||
pub(crate) fn entry(&self, i: usize, j: usize) -> F { | ||
self.entries[i][j] | ||
} | ||
|
||
/// Returns self as a list of rows | ||
pub(crate) fn rows(&self) -> Vec<Vec<F>> { | ||
self.entries.clone() | ||
} | ||
|
||
/// Returns self as a list of columns | ||
pub(crate) fn cols(&self) -> Vec<Vec<F>> { | ||
(0..self.m) | ||
.map(|col| (0..self.n).map(|row| self.entries[row][col]).collect()) | ||
.collect() | ||
} | ||
|
||
/// Returns the product v * self, where v is interpreted as a row vector. In other words, | ||
/// it returns a linear combination of the rows of self with coefficients given by v. | ||
/// | ||
|
@@ -139,14 +141,50 @@ | |
inner_product( | ||
v, | ||
&(0..self.n) | ||
.map(|row| self.entries[row][col]) | ||
.map(|row| self.rows[row][col]) | ||
.collect::<Vec<F>>(), | ||
) | ||
}) | ||
.collect() | ||
} | ||
} | ||
|
||
impl<F: Field> ColumnMajorMatrix<F> { | ||
/// Returns a ColumnMajorMatrix given a list of its rows, each in turn represented as a list of field elements. | ||
/// | ||
/// # Panics | ||
/// Panics if the sub-lists do not all have the same length. | ||
pub(crate) fn new_from_rows(row_major: Vec<Vec<F>>) -> Self { | ||
let m = row_major[0].len(); | ||
|
||
for row in row_major.iter().skip(1) { | ||
assert_eq!( | ||
row.len(), | ||
m, | ||
"Invalid matrix construction: not all rows have the same length" | ||
); | ||
} | ||
let cols = (0..m) | ||
.map(|col| { | ||
(0..row_major.len()) | ||
.map(|row| row_major[row][col]) | ||
.collect() | ||
}) | ||
.collect(); | ||
|
||
Self { | ||
n: row_major.len(), | ||
m, | ||
cols, | ||
} | ||
} | ||
|
||
/// Returns self as a list of columns | ||
pub(crate) fn cols(&self) -> &Vec<Vec<F>> { | ||
&self.cols | ||
} | ||
} | ||
|
||
#[inline] | ||
pub(crate) fn inner_product<F: Field>(v1: &[F], v2: &[F]) -> F { | ||
ark_std::cfg_iter!(v1) | ||
|
@@ -207,22 +245,22 @@ | |
#[test] | ||
fn test_matrix_constructor_flat() { | ||
let entries: Vec<Fr> = to_field(vec![10, 100, 4, 67, 44, 50]); | ||
let mat = Matrix::new_from_flat(2, 3, &entries); | ||
let mat = RowMajorMatrix::new_from_flat(2, 3, &entries); | ||
assert_eq!(mat.entry(1, 2), Fr::from(50)); | ||
} | ||
|
||
#[test] | ||
fn test_matrix_constructor_flat_square() { | ||
let entries: Vec<Fr> = to_field(vec![10, 100, 4, 67]); | ||
let mat = Matrix::new_from_flat(2, 2, &entries); | ||
let mat = RowMajorMatrix::new_from_flat(2, 2, &entries); | ||
assert_eq!(mat.entry(1, 1), Fr::from(67)); | ||
} | ||
|
||
#[test] | ||
#[should_panic(expected = "dimensions are 2 x 3 but entry vector has 5 entries")] | ||
fn test_matrix_constructor_flat_panic() { | ||
let entries: Vec<Fr> = to_field(vec![10, 100, 4, 67, 44]); | ||
Matrix::new_from_flat(2, 3, &entries); | ||
RowMajorMatrix::new_from_flat(2, 3, &entries); | ||
} | ||
|
||
#[test] | ||
|
@@ -232,7 +270,7 @@ | |
to_field(vec![23, 1, 0]), | ||
to_field(vec![55, 58, 9]), | ||
]; | ||
let mat = Matrix::new_from_rows(rows); | ||
let mat = RowMajorMatrix::new_from_rows(rows); | ||
assert_eq!(mat.entry(2, 0), Fr::from(55)); | ||
} | ||
|
||
|
@@ -244,7 +282,7 @@ | |
to_field(vec![23, 1, 0]), | ||
to_field(vec![55, 58]), | ||
]; | ||
Matrix::new_from_rows(rows); | ||
ColumnMajorMatrix::new_from_rows(rows); | ||
} | ||
|
||
#[test] | ||
|
@@ -255,7 +293,7 @@ | |
to_field(vec![17, 89]), | ||
]; | ||
|
||
let mat = Matrix::new_from_rows(rows); | ||
let mat = ColumnMajorMatrix::new_from_rows(rows); | ||
|
||
assert_eq!(mat.cols()[1], to_field(vec![76, 92, 89])); | ||
} | ||
|
@@ -268,7 +306,7 @@ | |
to_field(vec![55, 58, 9]), | ||
]; | ||
|
||
let mat = Matrix::new_from_rows(rows); | ||
let mat = RowMajorMatrix::new_from_rows(rows); | ||
let v: Vec<Fr> = to_field(vec![12, 41, 55]); | ||
// by giving the result in the integers and then converting to Fr | ||
// we ensure the test will still pass even if Fr changes | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this doesn't seem to be used
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is not used in Hyrax and in Ligero, it is only for testing. We can remove it.