
Commit

Update examples
matteo-grella committed Oct 30, 2023
1 parent dea27c8 commit 2c148e0
Showing 4 changed files with 87 additions and 29 deletions.
57 changes: 33 additions & 24 deletions README.md
@@ -75,27 +75,33 @@ Here is an example of how to calculate the sum of two variables:
 package main
 
 import (
 	"fmt"
+	"log"
 
 	"github.com/nlpodyssey/spago/ag"
 	"github.com/nlpodyssey/spago/mat"
 )
 
-type T = float32
-
 func main() {
+	// define the type of the elements in the tensors
+	type T = float32
+
 	// create a new node of type variable with a scalar
 	a := mat.Scalar(T(2.0), mat.WithGrad(true)) // create another node of type variable with a scalar
 	b := mat.Scalar(T(5.0), mat.WithGrad(true)) // create an addition operator (the calculation is actually performed here)
 	c := ag.Add(a, b)
 
 	// print the result
-	fmt.Printf("c = %v (float%d)\n", c.Value(), c.Value().Scalar().BitSize())
+	fmt.Printf("c = %v (float%d)\n", c.Value(), c.Value().Item().BitSize())
 
 	c.AccGrad(mat.Scalar(T(0.5)))
-	ag.Backward(c)
-	fmt.Printf("ga = %v\n", a.Grad())
-	fmt.Printf("gb = %v\n", b.Grad())
+
+	if err := ag.Backward(c); err != nil {
+		log.Fatalf("error during Backward(): %v", err)
+	}
+
+	fmt.Printf("ga = %v\n", a.Grad())
+	fmt.Printf("gb = %v\n", b.Grad())
 }
 ```

@@ -115,17 +121,20 @@ Here is a simple implementation of the perceptron formula:
 package main
 
 import (
+	"fmt"
+
 	. "github.com/nlpodyssey/spago/ag"
 	"github.com/nlpodyssey/spago/mat"
 )
 
 func main() {
 	x := mat.Scalar(-0.8)
 	w := mat.Scalar(0.4)
 	b := mat.Scalar(-0.2)
 
 	y := Sigmoid(Add(Mul(w, x), b))
-	_ = y
+
+	fmt.Printf("y = %0.3f\n", y.Value().Item())
 }
 ```

31 changes: 31 additions & 0 deletions examples/addition/main.go
@@ -0,0 +1,31 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/nlpodyssey/spago/ag"
+	"github.com/nlpodyssey/spago/mat"
+)
+
+func main() {
+	// define the type of the elements in the tensors
+	type T = float32
+
+	// create a new node of type variable with a scalar
+	a := mat.Scalar(T(2.0), mat.WithGrad(true)) // create another node of type variable with a scalar
+	b := mat.Scalar(T(5.0), mat.WithGrad(true)) // create an addition operator (the calculation is actually performed here)
+	c := ag.Add(a, b)
+
+	// print the result
+	fmt.Printf("c = %v (float%d)\n", c.Value(), c.Value().Item().BitSize())
+
+	c.AccGrad(mat.Scalar(T(0.5)))
+
+	if err := ag.Backward(c); err != nil {
+		log.Fatalf("error during Backward(): %v", err)
+	}
+
+	fmt.Printf("ga = %v\n", a.Grad())
+	fmt.Printf("gb = %v\n", b.Grad())
+}
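
As a quick sanity check on what this new example should print, here is a minimal plain-Go sketch (not part of the commit, and not using spago) that reproduces the same numbers by hand: c = a + b = 7, and because the partial derivatives of an addition are both 1, the 0.5 gradient accumulated on c reaches a and b unchanged.

```go
package main

import "fmt"

func main() {
	a, b := 2.0, 5.0
	c := a + b // 7

	seed := 0.5              // gradient accumulated on c, as in c.AccGrad(mat.Scalar(T(0.5)))
	ga, gb := seed*1, seed*1 // dc/da = dc/db = 1 for addition

	fmt.Println(c, ga, gb) // 7 0.5 0.5
}
```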
17 changes: 17 additions & 0 deletions examples/perceptron/main.go
@@ -0,0 +1,17 @@
+package main
+
+import (
+	"fmt"
+	. "github.com/nlpodyssey/spago/ag"
+	"github.com/nlpodyssey/spago/mat"
+)
+
+func main() {
+	x := mat.Scalar(-0.8)
+	w := mat.Scalar(0.4)
+	b := mat.Scalar(-0.2)
+
+	y := Sigmoid(Add(Mul(w, x), b))
+
+	fmt.Printf("y = %0.3f\n", y.Value().Item())
+}

Check failure on line 5 in examples/perceptron/main.go (GitHub Actions / staticcheck): should not use dot imports (ST1001)
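
The staticcheck failure above flags the dot import. A possible variant (a sketch, not part of this commit) keeps the same computation but qualifies the calls with the ag package instead, which silences ST1001 without changing behavior:

```go
package main

import (
	"fmt"

	"github.com/nlpodyssey/spago/ag"
	"github.com/nlpodyssey/spago/mat"
)

func main() {
	x := mat.Scalar(-0.8)
	w := mat.Scalar(0.4)
	b := mat.Scalar(-0.2)

	// same perceptron formula, with explicit package qualifiers
	y := ag.Sigmoid(ag.Add(ag.Mul(w, x), b))

	fmt.Printf("y = %0.3f\n", y.Value().Item())
}
```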
11 changes: 6 additions & 5 deletions examples/regression.go → examples/regression/main.go
@@ -89,9 +89,10 @@ func main() {
 	fmt.Printf("Model parameters:\n")
 	fmt.Printf("W: %.2f | B: %.2f\n\n", m.W.Value().Item().F64(), m.B.Value().Item().F64())
 
-	fmt.Printf("Saving the trained model to the file...\n")
-	err := nn.DumpToFile(m, "model.bin")
-	if err != nil {
-		log.Fatal(err)
-	}
+	// -- Enable this code to save the trained model to a file --
+	// fmt.Printf("Saving the trained model to the file...\n")
+	// err := nn.DumpToFile(m, "model.bin")
+	// if err != nil {
+	// 	log.Fatal(err)
+	// }
 }
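
If a reader wants to keep persisting the trained model, re-enabling the commented block amounts to the fragment below (same calls as the code this commit comments out, folded into an idiomatic error check; it assumes the surrounding regression example's m, nn, and log are still in scope):

```go
// Re-enabled save step, using only the calls already present in the old code.
fmt.Printf("Saving the trained model to the file...\n")
if err := nn.DumpToFile(m, "model.bin"); err != nil {
	log.Fatal(err)
}
```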
