Skip to content

Commit

Permalink
csc <-> csr conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
konovod committed Dec 19, 2023
1 parent fa221bf commit d57ea9c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Most operations - matrix addition, multiplication, inversion, transposition and
- [x] Sparse matrices (perhaps out of scope/deserves separate shard)
- [x] COO Matrix
- [x] CSR Matrix
- [x] CSC Matrix (except CSC\CSR conversion)
- [x] CSC Matrix
- [ ] non-sorted mode
- [ ] Separate shard for SparseSuite or other libs

Expand Down
23 changes: 23 additions & 0 deletions spec/sparce_csr_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ describe CSRMatrix do
ms.nonzeros.should eq 6
end

it "can be created from CSC matrix" do
m_dense = GMat[
[-4.0, -3.0, -2.0, 0.0],
[-1.0, 0.0, 11.0, 0.0],
[20.0, 30.0, 4.0, 1.0],
]
m_csc = CSCMatrix(Float64).new m_dense
m_csr = CSRMatrix(Float64).new m_csc
m_csc.should eq m_dense
m_csr.should eq m_dense
end

it "can be cleared" do
ms = CSRMatrix(Float64).new(4, 4, raw_rows: [0, 1, 2, 3, 4], raw_columns: [0, 1, 2, 1], raw_values: [5.0, 8.0, 3.0, 6.0])
ms.clear
Expand Down Expand Up @@ -278,4 +290,15 @@ describe CSRMatrix do
m.norm(MatrixNorm::One).should eq 7
m.norm(MatrixNorm::MaxAbs).should eq 4
end

it "can be converted to CSC using as_csc" do
m = CSRMatrix(Float64).new GMat[
[-4.0, -3.0, -2.0],
[-1.0, 0.0, 1.0],
[2.0, 3.0, 4.0],
]
mt = m.as_csc
mt.should be_a CSCMatrix(Float64)
mt.should eq m.t
end
end
23 changes: 23 additions & 0 deletions spec/sparse_csc_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ describe CSCMatrix do
ms.nonzeros.should eq 6
end

it "can be created from CSR matrix" do
m_dense = GMat[
[-4.0, -3.0, -2.0, 0.0],
[-1.0, 0.0, 11.0, 0.0],
[20.0, 30.0, 4.0, 1.0],
]
m_csr = CSRMatrix(Float64).new m_dense
m_csc = CSCMatrix(Float64).new m_csr
m_csc.should eq m_dense
m_csr.should eq m_dense
end

it "can be cleared" do
ms = CSCMatrix(Float64).new(4, 4, raw_columns: [0, 1, 2, 3, 4], raw_rows: [0, 1, 2, 1], raw_values: [5.0, 8.0, 3.0, 6.0])
ms.clear
Expand Down Expand Up @@ -278,4 +290,15 @@ describe CSCMatrix do
m.norm(MatrixNorm::One).should eq 7
m.norm(MatrixNorm::MaxAbs).should eq 4
end

it "can be converted to CSR using as_csr" do
m = CSCMatrix(Float64).new GMat[
[-4.0, -3.0, -2.0],
[-1.0, 0.0, 1.0],
[2.0, 3.0, 4.0],
]
mt = m.as_csr
mt.should be_a CSRMatrix(Float64)
mt.should eq m.t
end
end
29 changes: 19 additions & 10 deletions src/matrix/sparse/cscmatrix.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ require "./sparse_matrix.cr"
# TODO - inline docs

module LA::Sparse
macro csxmatrix(name, arows, acolumns, index1, index2, norm)
macro csxmatrix(name, arows, acolumns, index1, index2, norm, other)
class {{name}}Matrix(T) < Matrix(T)
protected getter raw_{{acolumns}} : Array(Int32)
protected getter raw_{{arows}} : Array(Int32)
Expand All @@ -17,21 +17,21 @@ module LA::Sparse
end

def initialize(@nrows, @ncolumns, raw_{{arows}} : Array(Int32), raw_{{acolumns}} : Array(Int32), raw_values : Array(T), @flags = MatrixFlags::None, *, dont_clone : Bool = false)
if raw_{{arows}}.size != @n{{arows}} + 1
if raw_{{arows}}.size != @n{{arows}} + 1
raise ArgumentError.new("Can't construct #{self.class} from arrays of different size: {{arows}}.size(#{raw_{{arows}}.size}) != n{{arows}}+1 (#{@n{{arows}} + 1}")
end
if raw_{{acolumns}}.size != raw_values.size
end
if raw_{{acolumns}}.size != raw_values.size
raise ArgumentError.new("Can't construct #{self.class} from arrays of different size: {{acolumns}}.size(#{raw_{{acolumns}}.size}) != values.size(#{raw_values.size})")
end
if dont_clone
end
if dont_clone
@raw_{{arows}} = raw_{{arows}}
@raw_{{acolumns}} = raw_{{acolumns}}
@raw_values = raw_values
else
else
@raw_{{arows}} = raw_{{arows}}.dup
@raw_{{acolumns}} = raw_{{acolumns}}.dup
@raw_values = raw_values.dup
end
end
end

def nonzeros : Int32
Expand All @@ -46,6 +46,11 @@ module LA::Sparse
new(matrix.nrows, matrix.ncolumns, matrix.raw_{{arows}}.dup, matrix.raw_{{acolumns}}.dup, matrix.raw_values.map { |v| T.new(v) }, dont_clone: true, flags: matrix.flags)
end

def self.new(matrix : {{other}}Matrix)
tr = matrix.transpose
self.new(matrix.nrows, matrix.ncolumns, raw_{{arows}}: tr.raw_{{acolumns}},raw_{{acolumns}}: tr.raw_{{arows}}, raw_values: tr.raw_values, flags: matrix.flags, dont_clone: true)
end

def self.new(matrix : LA::Matrix)
if matrix.is_a? Sparse::Matrix
nonzeros = matrix.nonzeros
Expand Down Expand Up @@ -413,9 +418,13 @@ module LA::Sparse
super(kind)
end
end

def as_{{other.stringify.downcase.id}}
{{other}}Matrix(T).new(@ncolumns, @nrows, raw_{{arows}}: raw_{{acolumns}},raw_{{acolumns}}: raw_{{arows}}, raw_values: raw_values, flags: flags.transpose)
end
end
end

csxmatrix(CSR, rows, columns, i, j, inf)
csxmatrix(CSC, columns, rows, j, i, one)
csxmatrix(CSR, rows, columns, i, j, inf, CSC)
csxmatrix(CSC, columns, rows, j, i, one, CSR)
end

0 comments on commit d57ea9c

Please sign in to comment.