Skip to content

Commit

Permalink
Merge pull request #20 from jemc-savi/add/op
Browse files Browse the repository at this point in the history
Delay errors until session compute.
  • Loading branch information
jemc authored Jun 13, 2023
2 parents 3759b9b + b57dce7 commit e081dd9
Show file tree
Hide file tree
Showing 52 changed files with 639 additions and 547 deletions.
12 changes: 6 additions & 6 deletions spec/Tensor.Comp.TensorDot.Outer.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
:it "is equivalent to matrix multiplication for rank-2 tensors"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.tensordot_outer!("example"
g.const!("A", Tensor(F64).from_array([
g.tensordot_outer("example"
g.const("A", Tensor(F64).from_array([
1.0, 2.0
3.0, 4.0
]).try_reshape(Tensor.Shape.new([2, 2])))
g.const!("B", Tensor(F64).from_array([
g.const("B", Tensor(F64).from_array([
5.0, 6.0
7.0, 8.0
]).try_reshape(Tensor.Shape.new([2, 2])))
Expand All @@ -26,15 +26,15 @@
:it "handles larger-rank tensors by applying to the outer axes"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.tensordot_outer!("example"
g.const!("A", Tensor(F64).from_array([
g.tensordot_outer("example"
g.const("A", Tensor(F64).from_array([
1, 2, 3
4, 5, 6

7, 8, 9
10, 11, 12
]).try_reshape(Tensor.Shape.new([2, 2, 3])))
g.const!("B", Tensor(F64).from_array([
g.const("B", Tensor(F64).from_array([
13, 14
15, 16

Expand Down
18 changes: 9 additions & 9 deletions spec/Tensor.Gen.Random.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
:const describes: "Tensor.Gen.Random"

:it "raises its internal counter with each graph node that uses it"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
random = g.gen_random!("random"
g.const!("seed", Tensor(U32).from_array([2, 3]))
_WithGraphHelper.run(@env) -> (g, session |
random = g.gen_random("random"
g.const("seed", Tensor(U32).from_array([2, 3]))
)
shape = Tensor.Shape.scalar
example1 = g.random_uniform!("example1", random, Tensor(F64), shape)
example2 = g.random_uniform!("example2", random, Tensor(F64), shape)
example3 = g.random_uniform!("example3", random, Tensor(F64), shape)
example4 = g.random_uniform!("example4", random, Tensor(F64), shape)
example5 = g.random_uniform!("example5", random, Tensor(F64), shape)
example1 = g.random_uniform("example1", random, Tensor(F64), shape)
example2 = g.random_uniform("example2", random, Tensor(F64), shape)
example3 = g.random_uniform("example3", random, Tensor(F64), shape)
example4 = g.random_uniform("example4", random, Tensor(F64), shape)
example5 = g.random_uniform("example5", random, Tensor(F64), shape)

assert: [
session.compute!(example1).as!(Tensor(F64)).into_array.first!
Expand All @@ -27,4 +27,4 @@
0.913356627721398900
0.007108289495397546
]
))
)
74 changes: 38 additions & 36 deletions spec/Tensor.Graph.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,38 @@

graph = Tensor.Graph.new
session = Tensor.Graph.Session.new(graph)
try (
a = graph.new_operation("Const", "a") -> (builder |
builder
.set_attr_type("dtype", Tensor(F64).element_type_code)
.set_attr_tensor!("value", a_value)
.finish!
)

assert: a.output(0).shape.rank == 2
assert: a.output(0).shape.into_array == [2, 2]
a = graph.new_operation("Const", "a") -> (builder |
builder
.set_attr_type("dtype", Tensor(F64).element_type_code)
.set_attr_tensor("value", a_value)
.finish
)

b = graph.new_operation("Const", "b") -> (builder |
builder
.set_attr_type("dtype", Tensor(F64).element_type_code)
.set_attr_tensor!("value", b_value)
.finish!
)
product1 = graph.new_operation("MatMul", "product1") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.finish!
)
product2 = graph.new_operation("MatMul", "product2") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.set_attr_bool("transpose_a", True)
.finish!
)
assert: a.output(0).shape.rank == 2
assert: a.output(0).shape.into_array == [2, 2]

b = graph.new_operation("Const", "b") -> (builder |
builder
.set_attr_type("dtype", Tensor(F64).element_type_code)
.set_attr_tensor("value", b_value)
.finish
)
product1 = graph.new_operation("MatMul", "product1") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.finish
)
product2 = graph.new_operation("MatMul", "product2") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.set_attr_bool("transpose_a", True)
.finish
)

try (
result = session.compute!(product1.output(0))
assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // row1⋅col1, row1⋅col2
Expand All @@ -61,25 +61,27 @@
assert no_error: error!
)

:it "complains when creating an operation with an invalid type"
:it "complains when evaluating an operation with an invalid type"
g = Tensor.Graph.new
assert error: (
session = Tensor.Graph.Session.new(g)

assert error: session.compute!(
g.new_operation("Bogus", "example") -> (builder |
builder.finish!
)
builder.finish
).output(0)
)
assert: g.errors.first!.code == Tensor.Graph.Error.Code.InvalidArgument
assert: g.errors.first!.message.includes("Op type not registered 'Bogus'")

:it "optimizes to minimize a loss function with gradient descent"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
learning_rate = g.const!("learning_rate", Tensor(F64).scalar(0.25))
learning_rate = g.const("learning_rate", Tensor(F64).scalar(0.25))

x = g.variable!("x", Tensor(F64), [])
loss = g.square!("square", x)
x = g.variable("x", Tensor(F64), [])
loss = g.square("square", x)
grad = g.graph.add_gradients!([loss], [x]).first!

x2 = g.apply_gradient_descent!("apply_grad", grad, x, learning_rate)
x2 = g.apply_gradient_descent("apply_grad", grad, x, learning_rate)

result Tensor.Any = Tensor(F64).scalar(5)
[
Expand Down
12 changes: 6 additions & 6 deletions spec/Tensor.Op.Add.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
:it "computes arithmetic addition"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.add!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3, 4]))
g.const!("y", Tensor(I32).from_array([5, 6, 7, 8]))
g.add("example"
g.const("x", Tensor(I32).from_array([1, 2, 3, 4]))
g.const("y", Tensor(I32).from_array([5, 6, 7, 8]))
)
)

Expand All @@ -17,9 +17,9 @@
:it "can broadcast smaller sizes/shapes across larger sizes/shapes"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.add!("example"
g.const!("x", Tensor(I32).from_array([-1, -3, -5]))
g.const!("y", Tensor(I32).from_array([
g.add("example"
g.const("x", Tensor(I32).from_array([-1, -3, -5]))
g.const("y", Tensor(I32).from_array([
1, 2, 3
4, 5, 6
7, 8, 9
Expand Down
18 changes: 10 additions & 8 deletions spec/Tensor.Op.Bitcast.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
:it "distributes bits into a larger number of narrower elements"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.bitcast!("example"
g.const!("input", Tensor(U16).from_array([0x0246, 0x8ace]))
g.bitcast("example"
g.const("input", Tensor(U16).from_array([0x0246, 0x8ace]))
Tensor(U8)
)
)
Expand All @@ -18,8 +18,8 @@
:it "consolidates bits into a smaller number of wider elements"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.bitcast!("example"
g.const!("input", Tensor(U8).from_array([0x46, 0x02, 0xce, 0x8a])
g.bitcast("example"
g.const("input", Tensor(U8).from_array([0x46, 0x02, 0xce, 0x8a])
.try_reshape(Tensor.Shape.new([2, 2]))
)
Tensor(U16)
Expand All @@ -31,10 +31,12 @@

:it "complains on narrow to wide with more than one wide result per row"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.bitcast!("example"
g.const!("input", Tensor(U8).from_array([0x46, 0x02, 0xce, 0x8a])
// this would work if we did a reshape like: [2, 2]
assert error: session.compute!(
g.bitcast("example"
g.const("input", Tensor(U8).from_array([0x46, 0x02, 0xce, 0x8a])
// this would work if we did a reshape like: [2, 2]
)
Tensor(U16)
)
Tensor(U16)
)
)
16 changes: 8 additions & 8 deletions spec/Tensor.Op.Cast.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
:it "does bounds-wrapping when converting to a narrower integer type"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.cast!("example"
g.const!("input", Tensor(I16).from_array([0, 1, 0xffff, 0x7890]))
g.cast("example"
g.const("input", Tensor(I16).from_array([0, 1, 0xffff, 0x7890]))
Tensor(I8)
)
)
Expand All @@ -17,8 +17,8 @@
:it "rounds a floating-point value to its nearest integer value"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.cast!("example"
g.const!("input", Tensor(F64).from_array([2.4, 2.5, 2.6, -2.5]))
g.cast("example"
g.const("input", Tensor(F64).from_array([2.4, 2.5, 2.6, -2.5]))
Tensor(I32)
)
)
Expand All @@ -29,8 +29,8 @@
:it "rounds to the nearest representable less-precise floating-point value"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.cast!("example"
g.const!("input", Tensor(F64).from_array([-1e26]))
g.cast("example"
g.const("input", Tensor(F64).from_array([-1e26]))
Tensor(F32)
)
)
Expand All @@ -41,8 +41,8 @@
:it "can be set to round with floating-point truncation (toward zero)"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.cast_with_floating_point_truncation!("example"
g.const!("input", Tensor(F64).from_array([-1e26]))
g.cast_with_floating_point_truncation("example"
g.const("input", Tensor(F64).from_array([-1e26]))
Tensor(F32)
)
)
Expand Down
50 changes: 28 additions & 22 deletions spec/Tensor.Op.Concat.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
:it "combines the list of tensors into one new tensor"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.concat!("example"
g.concat("example"
[
g.const!("input_a", @f64_2x2(1, 2, 3, 4))
g.const!("input_b", @f64_2x2(5, 6, 7, 8))
g.const("input_a", @f64_2x2(1, 2, 3, 4))
g.const("input_b", @f64_2x2(5, 6, 7, 8))
]
)
)
Expand All @@ -28,10 +28,10 @@
:it "can combine along a different axis"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.concat!("example"
g.concat("example"
[
g.const!("input_a", @f64_2x2(1, 2, 3, 4))
g.const!("input_b", @f64_2x2(5, 6, 7, 8))
g.const("input_a", @f64_2x2(1, 2, 3, 4))
g.const("input_b", @f64_2x2(5, 6, 7, 8))
]
1 // axis
)
Expand All @@ -46,31 +46,37 @@

:it "complains when the inputs are of different types"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.concat!("example"
[
g.const!("input_a", Tensor(F64).from_array([1, 2, 3, 4]))
g.const!("input_b", Tensor(F32).from_array([5, 6, 7, 8]))
]
assert error: session.compute!(
g.concat("example"
[
g.const("input_a", Tensor(F64).from_array([1, 2, 3, 4]))
g.const("input_b", Tensor(F32).from_array([5, 6, 7, 8]))
]
)
)
)

:it "complains when the inputs are of different shapes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.concat!("example"
[
g.const!("input_a", @f64_2x2(1, 2, 3, 4))
g.const!("input_b", @f64_2x2(5, 6, 7, 8).try_reshape(Tensor.Shape.new([1, 4])))
]
assert error: session.compute!(
g.concat("example"
[
g.const("input_a", @f64_2x2(1, 2, 3, 4))
g.const("input_b", @f64_2x2(5, 6, 7, 8).try_reshape(Tensor.Shape.new([1, 4])))
]
)
)
)

:it "complains when the given axis is greater-or-equal to the inputs' rank"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.concat!("example"
[
g.const!("input_a", @f64_2x2(1, 2, 3, 4))
g.const!("input_b", @f64_2x2(5, 6, 7, 8))
]
2
assert error: session.compute!(
g.concat("example"
[
g.const("input_a", @f64_2x2(1, 2, 3, 4))
g.const("input_b", @f64_2x2(5, 6, 7, 8))
]
2
)
)
)
2 changes: 1 addition & 1 deletion spec/Tensor.Op.Const.Spec.savi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
:it "emits a constant tensor value"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.const!("example"
g.const("example"
Tensor(F64).from_array([1, 2, 3, 4])
)
)
Expand Down
Loading

0 comments on commit e081dd9

Please sign in to comment.