Skip to content

Commit

Permalink
fix warmup
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>
  • Loading branch information
datelier committed Jul 9, 2020
1 parent b400518 commit 45123c4
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 92 deletions.
20 changes: 17 additions & 3 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

// Option is tensorflow configure.
type Option func(*tensorflow)

var (
defaultOpts = []Option{
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
withLoadFunc(tf.LoadSavedModel), // set to default
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
}
)

Expand Down Expand Up @@ -102,6 +107,15 @@ func WithTags(tags ...string) Option {
}
}

func withLoadFunc(
loadFunc func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error)) Option {
return func(t *tensorflow) {
if loadFunc != nil {
t.loadFunc = loadFunc
}
}
}

// WithFeed returns Option that sets feeds.
func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
Expand Down
113 changes: 101 additions & 12 deletions internal/core/converter/tensorflow/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/vdaas/vald/internal/errors"
"go.uber.org/goleak"
)
Expand Down Expand Up @@ -71,7 +74,7 @@ func TestWithSessionOptions(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -140,7 +143,7 @@ func TestWithSessionTarget(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -209,7 +212,7 @@ func TestWithSessionConfig(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -294,7 +297,7 @@ func TestWithOperations(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -363,7 +366,7 @@ func TestWithExportPath(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -459,7 +462,7 @@ func TestWithTags(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand All @@ -482,6 +485,92 @@ func TestWithTags(t *testing.T) {
}
}

func TestWithLoadFunc(t *testing.T) {
type T = tensorflow
type args struct {
loadFunc func(string, []string, *SessionOptions) (*tf.SavedModel, error)
}
type want struct {
obj *T
}
type test struct {
name string
args args
want want
checkFunc func(want, *T) error
beforeFunc func(args)
afterFunc func(args)
}

defaultCheckFunc := func(w want, obj *T) error {
opts := []cmp.Option{
cmp.AllowUnexported(tensorflow{}),
cmp.AllowUnexported(OutputSpec{}),
cmpopts.IgnoreFields(tensorflow{}, "loadFunc"),
}
if diff := cmp.Diff(w.obj, obj, opts...); len(diff) != 0 {
return errors.Errorf("err: %s", diff)
}
opt := cmp.Comparer(func(want, obj T) bool {
p1 := reflect.ValueOf(want).FieldByName("loadFunc").Pointer()
p2 := reflect.ValueOf(obj).FieldByName("loadFunc").Pointer()
return p1 == p2
})
if !cmp.Equal(w.obj, obj, opt) {
return errors.Errorf("got = %v, want = %v", obj, w.obj)
}
return nil
}

loadFunc := func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) {
return nil, nil
}
tests := []test{
{
name: "set success when loadFunc is not nil",
args: args{
loadFunc: loadFunc,
},
want: want{
obj: &T{
loadFunc: loadFunc,
},
},
},
{
name: "do nothing when loadFunc is nil",
args: args{
loadFunc: nil,
},
want: want{
obj: &T{},
},
},
}

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

if test.checkFunc == nil {
test.checkFunc = defaultCheckFunc
}
got := withLoadFunc(test.args.loadFunc)
obj := new(T)
got(obj)
if err := test.checkFunc(test.want, obj); err != nil {
tt.Errorf("error = %v", err)
}
})
}
}

func TestWithFeed(t *testing.T) {
type T = tensorflow
type args struct {
Expand Down Expand Up @@ -529,7 +618,7 @@ func TestWithFeed(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -635,7 +724,7 @@ func TestWithFeeds(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -703,7 +792,7 @@ func TestWithFetch(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -809,7 +898,7 @@ func TestWithFetches(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -905,7 +994,7 @@ func TestWithWarmupInputs(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -968,7 +1057,7 @@ func TestWithNdim(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down
31 changes: 16 additions & 15 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type session interface {
type tensorflow struct {
exportDir string
tags []string
loadFunc func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error)
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
Expand All @@ -70,18 +71,6 @@ const (
threeDim
)

var loadFunc = func(t *tensorflow) error {
model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options)
if err != nil {
return err
}

t.graph = model.Graph
t.session = model.Session

return nil
}

// New load a tensorlfow model and returns a new tensorflow struct.
func New(opts ...Option) (TF, error) {
t := new(tensorflow)
Expand All @@ -90,19 +79,31 @@ func New(opts ...Option) (TF, error) {
opt(t)
}

err := loadFunc(t)
model, err := t.loadFunc(t.exportDir, t.tags, t.options)
if err != nil {
return nil, err
}

t.graph = model.Graph
t.session = model.Session

err = t.warmup()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tensorflow) warmup() error {
if t.warmupInputs != nil {
_, err := t.run(t.warmupInputs...)
if err != nil {
return nil, err
return err
}
}

return t, nil
return nil
}

func (t *tensorflow) Close() error {
Expand Down
Loading

0 comments on commit 45123c4

Please sign in to comment.