-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparse_vector.rs
64 lines (51 loc) · 1.62 KB
/
sparse_vector.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use std::borrow::Borrow;
use wide::f32x4;
use hibit_tree::{config, intersection, BitBlock, ReqDefault};
use hibit_tree::RegularHibitTree;
use hibit_tree::Iter;
/// One SIMD chunk of a [`SparseVector`]: four consecutive `f32` elements
/// packed into a single `f32x4`, so whole chunks can be multiplied and summed
/// lane-wise.
#[derive(Default, Clone)]
struct DataBlock(f32x4);
/// Backing storage of a [`SparseVector`]: a `hibit_tree` sparse tree mapping a
/// block index to a [`DataBlock`].
///
/// NOTE(review): `width_64::depth_2` presumably fixes how many block indices
/// the tree can address (64 * 64?) — confirm against the `hibit_tree` docs.
type SparseArray = hibit_tree::SparseTree<
    config::width_64::depth_2,
    DataBlock,
    ReqDefault,
>;
/// A sparse vector of `f32` values stored as 4-element [`DataBlock`]s.
///
/// Blocks are created on demand by [`SparseVector::set`]; [`mul`] and [`dot`]
/// only visit blocks stored in both operands.
#[derive(Default)]
struct SparseVector {
    /// Block index -> block holding four consecutive elements.
    sparse_array: SparseArray,
}
impl SparseVector {
    /// Sets element `index` of the vector to `value`.
    ///
    /// The element lives in block `index / 4`, at lane `index % 4`. The block
    /// is obtained with `get_or_insert`, so it is created if not stored yet.
    /// A new block presumably starts from `DataBlock::default()`, i.e. all
    /// lanes zero — confirm against `hibit_tree`.
    ///
    /// NOTE(review): what happens for a block index beyond the tree's
    /// configured range is decided by `hibit_tree`; confirm whether it panics.
    pub fn set(&mut self, index: usize, value: f32) {
        // Must equal the lane count of the `f32x4` inside `DataBlock`.
        const BLOCK_SIZE: usize = 4;
        let block_index = index / BLOCK_SIZE;
        let in_block_index = index % BLOCK_SIZE;
        let block = self.sparse_array.get_or_insert(block_index);
        // `in_block_index < BLOCK_SIZE` always holds, so plain indexing never
        // panics and the compiler can elide the bounds check; no `unsafe`
        // is needed.
        block.0.as_array_mut()[in_block_index] = value;
    }
}
/// Per-element multiplication.
///
/// Pairs up the blocks stored in both `v1` and `v2` (via `intersection`) and
/// multiplies each pair lane-wise. Blocks stored in only one operand do not
/// appear in the result.
pub fn mul<'a>(
    v1: &'a SparseVector,
    v2: &'a SparseVector,
) -> impl RegularHibitTree<Data = DataBlock> + 'a {
    let lhs_tree = &v1.sparse_array;
    let rhs_tree = &v2.sparse_array;
    intersection(lhs_tree, rhs_tree).map(|(lhs, rhs): (&DataBlock, &DataBlock)| {
        let product = lhs.0 * rhs.0;
        DataBlock(product)
    })
}
/// Dot product of two sparse vectors.
///
/// Multiplies the vectors block-wise with [`mul`] and accumulates the product
/// blocks lane-wise into one `f32x4`. Finally it adds the four lanes together.
/// Only blocks stored in both vectors contribute.
pub fn dot(v1: &SparseVector, v2: &SparseVector) -> f32 {
    let products = mul(v1, v2);
    // Blocks are folded in iteration order, exactly as a running `+=` would;
    // the block index is irrelevant to the sum, hence `_`.
    let lane_sums = Iter::new(&products)
        .fold(f32x4::ZERO, |acc, (_, block)| acc + block.borrow().0);
    lane_sums.reduce_add()
}
fn main() {
    // Scales every index used below. A `const` (instead of a SCREAMING_CASE
    // `let` binding) keeps the `non_snake_case` lint quiet.
    const INDEX_MUL: usize = 1;

    let mut v1 = SparseVector::default();
    let mut v2 = SparseVector::default();

    v1.set(10 * INDEX_MUL, 1.0);
    v1.set(20 * INDEX_MUL, 10.0);
    v1.set(30 * INDEX_MUL, 100.0);

    v2.set(10 * INDEX_MUL, 1.0);
    v2.set(30 * INDEX_MUL, 0.5);

    // Indices 10 and 30 are set in both vectors: 1.0 * 1.0 + 100.0 * 0.5.
    // Every operand and partial sum is exactly representable in f32, so an
    // exact comparison is safe here.
    let d = dot(&v1, &v2);
    assert_eq!(d, 51.0);
}