-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathFuture.scala
119 lines (96 loc) · 2.98 KB
/
Future.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package lantern
import org.scala_lang.virtualized.virtualize
import org.scala_lang.virtualized.SourceContext
import scala.virtualization.lms._
/**
 * Small staged-array test DSL on top of `DslOps`.
 *
 * Provides a loop helper ([[DataLoop]]), a shape descriptor ([[Dimensions]]) and a
 * flat, row-major staged array wrapper ([[Test]]).
 */
trait TestExp extends DslOps {
  type Size = Int

  object Size {
    def zero: Size = 0
  }

  /** A loop over `[0, size)` whose body receives a staged `Rep[Int]` index. */
  abstract class DataLoop {
    def foreach(f: Rep[Int] => Unit): Unit
  }

  @virtualize
  object DataLoop {
    /**
     * Builds a loop over `[0, size)`.
     *
     * For `size <= 1` the loop runs at staging time and calls the body with the
     * constant index `unit(i)`, i.e. it is unrolled. Otherwise the body is passed to
     * `foreach` on a staged `Rep[Range]`.
     */
    def apply(size: Int) = if (size <= 1) {
      new DataLoop {
        def foreach(f: Rep[Int] => Unit) = {
          for (i <- 0 until size: Range) f(unit(i))
        }
      }
    } else {
      new DataLoop {
        def foreach(f: Rep[Int] => Unit) = {
          for (i <- 0 until size: Rep[Range]) f(i)
        }
      }
    }
  }

  /**
   * Shape of a row-major array.
   *
   * `nbElem` is the product of all dimensions. `strides(k)` is the product of the
   * dimensions after position `k`, i.e. the flat-offset step of index `k`.
   */
  class Dimensions(val dims: Seq[Size]) {
    def apply(idx: Int) = dims(idx)

    // scanRight(1)(_ * _) on (a, b, c) yields (a*b*c, b*c, c, 1): the head is the total
    // element count and the tail holds the row-major strides. Empty dims yield (1).
    val (nbElem +: strides) = dims.scanRight(1)(_ * _)

    override def toString = dims mkString " x "

    override def equals(o: Any) = o match {
      case t: Dimensions => this.dims == t.dims
      case _ => false
    }

    // Must agree with `equals`, which compares `dims` only; otherwise equal shapes
    // could hash differently in hash-based collections.
    override def hashCode: Int = dims.hashCode
  }

  object Dimensions {
    def apply(x: Size*) = new Dimensions(x)
  }

  /** Debug switch read by `assertC`; while it is `false`, `assertC` checks never fire. */
  val debug: Rep[Boolean] = false

  /**
   * A staged array of `T` viewed with shape `dims`, stored flat in row-major order.
   *
   * @param data backing staged array; assumed to hold at least `dims.nbElem` elements
   *             (not checked here, TODO confirm at call sites)
   * @param dims logical shape of the array
   */
  class Test[T:Ordering:Numeric:Manifest](val data: Rep[Array[T]], val dims: Dimensions) {
    val nbElem = dims.nbElem

    /** Emits a raw `exit(0)` call into the generated code. */
    def exit() = unchecked[Unit]("exit(0)")

    /** printf conversion used by `printRaw`; only Int, Float and Double are supported. */
    def format = manifest[T] match {
      case t if t == manifest[Int] => "%d "
      case t if t == manifest[Float] => "%.3f "
      case t if t == manifest[Double] => "%.3f "
      case t => throw new IllegalArgumentException(s"No print format for element type $t")
    }

    // `num` is defined first so `zero` and `one` never observe an uninitialized field.
    val num = implicitly[Numeric[T]]
    val zero = num.zero
    val one = num.one

    /** Natural log of a staging-time value, computed via `Double`. */
    def log(x: T) = Math.log(num.toDouble(x))

    /**
     * Emits a runtime check. When `debug` is true and `b` is false, it prints
     * "Assert failed <msg>" formatted with `x`, then exits.
     *
     * `msg` becomes part of a printf format string with the single argument `x`, so it
     * is expected to contain exactly one conversion for `x`.
     * The `\\n` is a literal backslash-n in the Scala string, presumably so the
     * generated source contains the escape sequence. Confirm against the code generator.
     */
    @virtualize
    def assertC(b: Rep[Boolean], msg: String, x: Rep[Any]) = {
      if (debug && !b) { printf(s"Assert failed $msg\\n", x); exit() }
    }

    /**
     * Reads one element.
     *
     * A single index addresses the array flat. Several indices are combined with the
     * row-major `dims.strides`.
     * NOTE(review): the number of indices is not checked against `dims`. `zip` silently
     * drops extra indices (or extra strides). Decide whether this should be enforced.
     */
    @virtualize
    def apply(x: Rep[Size]*) = {
      val strides = if (x.length == 1) Seq(1) else dims.strides
      val idx: Rep[Size] = (x zip strides).foldRight(0: Rep[Int]) { (c, agg) => agg + c._1 * c._2 }
      // The valid range is half-open: 0 <= idx < nbElem.
      assertC(0 <= idx && idx < nbElem, s"Out of bound: %d not in [0, ${nbElem})", idx)
      data(idx)
    }

    /** Writes `v` at flat offset `idx`; unlike `apply`, no bounds check is emitted. */
    @virtualize
    def update(idx: Rep[Int], v: Rep[T]) = this.data(idx) = v

    /**
     * Element-wise sum into a freshly allocated array of the same shape.
     * Shape equality is checked at staging time with Scala's `assert`.
     */
    @virtualize
    def +(that: Test[T]) = {
      assert(this.dims == that.dims, s"Dimension mismatch for +: ${this.dims} != ${that.dims}")
      val arr = NewArray[T](this.nbElem)
      for (x <- DataLoop(this.nbElem))
        arr(x) = this(x) + that(x)
      new Test(arr, this.dims)
    }

    /**
     * Clamps every element in place to `[-bound, bound]`.
     * Assumes `bound` is non-negative (TODO confirm with callers).
     */
    @virtualize
    def clipAt(bound: Rep[T]) = {
      for (i <- DataLoop(this.nbElem)) {
        if (this(i) > bound) this(i) = bound
        // this(i) + bound < 0  <=>  this(i) < -bound
        if (this(i) + bound < zero)
          this(i) = zero - bound
      }
    }

    /**
     * Prints all elements using `format`, with a newline after every `row` elements
     * and a final newline at the end.
     */
    @virtualize
    def printRaw(row: Int = 10) = {
      for (i <- DataLoop(this.nbElem)) {
        printf(format, this(i))
        val imod = i % row
        if (imod == row - 1)
          printf("\\n")
      }
      printf("\\n")
    }
  }
}