Skip to content

Commit

Permalink
fix tensorflow.go, option.go (#261)
Browse files Browse the repository at this point in the history
* fix tensorflow.go, option.go

Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>

* fix TF interface

Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>

* fix magic number, nil check

Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>

* add ml tasks to Makefile

Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>

* fix const value name

Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>
  • Loading branch information
datelier authored Apr 23, 2020
1 parent f60e83f commit e46d63f
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 26 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,4 @@ include Makefile.d/proto.mk
include Makefile.d/k8s.mk
include Makefile.d/kind.mk
include Makefile.d/client.mk
include Makefile.d/ml.mk
42 changes: 42 additions & 0 deletions Makefile.d/ml.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
.PHONY: ml/models/clean
ml/models/clean:
rm -rf hack/ml/models

.PHONY: ml/models/tensorflow/init
ml/models/tensorflow/init:
mkdir -p hack/ml/models/tensorflow

.PHONY: ml/models/tensorflow/download
## download tensorflow model
ml/models/tensorflow/download: \
ml/models/clean \
ml/models/tensorflow/init \
ml/models/tensorflow/download/bert \
ml/models/tensorflow/download/insightface

.PHONY: ml/models/tensorflow/download/bert
ml/models/tensorflow/download/bert:
curl -LO https://github.com/vdaas/ml/raw/master/tensorflow/bert.tar.gz
tar -xvf bert.tar.gz -C hack/ml/models/tensorflow
rm bert.tar.gz

.PHONY: ml/models/tensorflow/download/insightface
ml/models/tensorflow/download/insightface:
curl -LO https://github.com/vdaas/ml/raw/master/tensorflow/insightface.tar.gz
tar -xvf insightface.tar.gz -C hack/ml/models/tensorflow
rm insightface.tar.gz
2 changes: 2 additions & 0 deletions hack/ml/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
39 changes: 39 additions & 0 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
defaultOpts = []Option{
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
}
)

Expand Down Expand Up @@ -79,3 +80,41 @@ func WithTags(tags ...string) Option {
}
}
}

func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex})
}
}

func WithFeeds(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
for i := range operationNames {
t.feeds = append(t.feeds, OutputSpec{operationNames[i], outputIndexes[i]})
}
}
}
}

func WithFetch(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex})
}
}

func WithFetches(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
for i := range operationNames {
t.fetches = append(t.fetches, OutputSpec{operationNames[i], outputIndexes[i]})
}
}
}
}

func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
t.ndim = ndim
}
}
105 changes: 79 additions & 26 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,35 @@ type SessionOptions = tf.SessionOptions
type Operation = tf.Operation

type TF interface {
GetVector(feeds []Feed, fetches []Fetch, targets ...*Operation) (values [][][]float64, err error)
GetVector(inputs ...string) ([]float64, error)
GetValue(inputs ...string) (interface{}, error)
GetValues(inputs ...string) (values []interface{}, err error)
Close() error
}

type tensorflow struct {
exportDir string
tags []string
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session *tf.Session
ndim uint8
}

type Feed struct {
InputBytes []byte
OperationName string
OutputIndex int
type OutputSpec struct {
operationName string
outputIndex int
}

type Fetch struct {
OperationName string
OutputIndex int
}
const (
TwoDim uint8 = iota + 2
ThreeDim
)

func New(opts ...Option) (TF, error) {
t := new(tensorflow)
Expand Down Expand Up @@ -78,38 +82,87 @@ func (t *tensorflow) Close() error {
return t.session.Close()
}

func (t *tensorflow) GetVector(feeds []Feed, fetches []Fetch, targets ...*Operation) (values [][][]float64, err error) {
input := make(map[tf.Output]*tf.Tensor, len(feeds))
for _, feed := range feeds {
inputTensor, err := tf.NewTensor([]string{string(feed.InputBytes)})
func (t *tensorflow) run(inputs ...string) ([]*tf.Tensor, error) {
if len(inputs) != len(t.feeds) {
return nil, errors.ErrInputLength(len(inputs), len(t.feeds))
}

feeds := make(map[tf.Output]*tf.Tensor, len(inputs))
for i, val := range inputs {
inputTensor, err := tf.NewTensor(val)
if err != nil {
return nil, err
}
input[t.graph.Operation(feed.OperationName).Output(feed.OutputIndex)] = inputTensor
feeds[t.graph.Operation(t.feeds[i].operationName).Output(t.feeds[i].outputIndex)] = inputTensor
}

output := make([]tf.Output, 0, len(fetches))
for _, fetch := range fetches {
output = append(output, t.graph.Operation(fetch.OperationName).Output(fetch.OutputIndex))
fetches := make([]tf.Output, 0, len(t.fetches))
for _, fetch := range t.fetches {
fetches = append(fetches, t.graph.Operation(fetch.operationName).Output(fetch.outputIndex))
}

if targets == nil {
targets = t.operations
}
return t.session.Run(feeds, fetches, t.operations)
}

results, err := t.session.Run(input, output, targets)
func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
if tensors == nil || tensors[0] == nil || tensors[0].Value() == nil {
return nil, errors.ErrNilTensorTF(tensors)
}

values = make([][][]float64, 0, len(results))
for _, result := range results {
value, ok := result.Value().([][]float64)
switch t.ndim {
case TwoDim:
value, ok := tensors[0].Value().([][]float64)
if ok {
if value == nil {
return nil, errors.ErrNilTensorValueTF(value)
}
return value[0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
case ThreeDim:
value, ok := tensors[0].Value().([][][]float64)
if ok {
if value == nil || value[0] == nil {
return nil, errors.ErrNilTensorValueTF(value)
}
return value[0][0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
default:
value, ok := tensors[0].Value().([]float64)
if ok {
values = append(values, value)
return value, nil
} else {
return nil, errors.ErrFailedToCastTF(result.Value())
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
}
}

func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
if tensors == nil || tensors[0] == nil {
return nil, errors.ErrNilTensorTF(tensors)
}
return tensors[0].Value(), nil
}

func (t *tensorflow) GetValues(inputs ...string) (values []interface{}, err error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
values = make([]interface{}, 0, len(tensors))
for _, tensor := range tensors {
values = append(values, tensor.Value())
}
return values, nil
}
9 changes: 9 additions & 0 deletions internal/errors/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,13 @@ var (
ErrFailedToCastTF = func(v interface{}) error {
return Errorf("failed to cast tensorflow result %+v", v)
}
ErrInputLength = func(i int, f int) error {
return Errorf("inputs length %d does not match feeds length %d", i, f)
}
ErrNilTensorTF = func(v interface{}) error {
return Errorf("nil tensorflow tensor %+v", v)
}
ErrNilTensorValueTF = func(v interface{}) error {
return Errorf("nil tensorflow tensor value %+v", v)
}
)

0 comments on commit e46d63f

Please sign in to comment.