Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tensorflow.go, option.go #261

Merged
merged 5 commits into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,4 @@ include Makefile.d/proto.mk
include Makefile.d/k8s.mk
include Makefile.d/kind.mk
include Makefile.d/client.mk
include Makefile.d/ml.mk
42 changes: 42 additions & 0 deletions Makefile.d/ml.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
.PHONY: ml/models/clean
ml/models/clean:
rm -rf hack/ml/models

.PHONY: ml/models/tensorflow/init
ml/models/tensorflow/init:
mkdir -p hack/ml/models/tensorflow

.PHONY: ml/models/tensorflow/download
## download tensorflow model
ml/models/tensorflow/download: \
ml/models/clean \
ml/models/tensorflow/init \
ml/models/tensorflow/download/bert \
ml/models/tensorflow/download/insightface

.PHONY: ml/models/tensorflow/download/bert
ml/models/tensorflow/download/bert:
curl -LO https://github.com/vdaas/ml/raw/master/tensorflow/bert.tar.gz
tar -xvf bert.tar.gz -C hack/ml/models/tensorflow
rm bert.tar.gz

.PHONY: ml/models/tensorflow/download/insightface
ml/models/tensorflow/download/insightface:
curl -LO https://github.com/vdaas/ml/raw/master/tensorflow/insightface.tar.gz
tar -xvf insightface.tar.gz -C hack/ml/models/tensorflow
rm insightface.tar.gz
2 changes: 2 additions & 0 deletions hack/ml/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
39 changes: 39 additions & 0 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
defaultOpts = []Option{
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
}
)

Expand Down Expand Up @@ -79,3 +80,41 @@ func WithTags(tags ...string) Option {
}
}
}

func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex})
}
}

func WithFeeds(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
for i := range operationNames {
t.feeds = append(t.feeds, OutputSpec{operationNames[i], outputIndexes[i]})
}
}
}
}

func WithFetch(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex})
}
}

func WithFetches(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
for i := range operationNames {
t.fetches = append(t.fetches, OutputSpec{operationNames[i], outputIndexes[i]})
}
}
}
}

func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
t.ndim = ndim
}
}
105 changes: 79 additions & 26 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,35 @@ type SessionOptions = tf.SessionOptions
type Operation = tf.Operation

type TF interface {
GetVector(feeds []Feed, fetches []Fetch, targets ...*Operation) (values [][][]float64, err error)
GetVector(inputs ...string) ([]float64, error)
GetValue(inputs ...string) (interface{}, error)
GetValues(inputs ...string) (values []interface{}, err error)
Close() error
}

type tensorflow struct {
exportDir string
tags []string
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session *tf.Session
ndim uint8
}

type Feed struct {
InputBytes []byte
OperationName string
OutputIndex int
type OutputSpec struct {
operationName string
outputIndex int
}

type Fetch struct {
OperationName string
OutputIndex int
}
const (
TwoDim uint8 = iota + 2
ThreeDim
)

func New(opts ...Option) (TF, error) {
t := new(tensorflow)
Expand Down Expand Up @@ -78,38 +82,87 @@ func (t *tensorflow) Close() error {
return t.session.Close()
}

func (t *tensorflow) GetVector(feeds []Feed, fetches []Fetch, targets ...*Operation) (values [][][]float64, err error) {
input := make(map[tf.Output]*tf.Tensor, len(feeds))
for _, feed := range feeds {
inputTensor, err := tf.NewTensor([]string{string(feed.InputBytes)})
func (t *tensorflow) run(inputs ...string) ([]*tf.Tensor, error) {
if len(inputs) != len(t.feeds) {
return nil, errors.ErrInputLength(len(inputs), len(t.feeds))
}

feeds := make(map[tf.Output]*tf.Tensor, len(inputs))
for i, val := range inputs {
inputTensor, err := tf.NewTensor(val)
if err != nil {
return nil, err
}
input[t.graph.Operation(feed.OperationName).Output(feed.OutputIndex)] = inputTensor
feeds[t.graph.Operation(t.feeds[i].operationName).Output(t.feeds[i].outputIndex)] = inputTensor
}

output := make([]tf.Output, 0, len(fetches))
for _, fetch := range fetches {
output = append(output, t.graph.Operation(fetch.OperationName).Output(fetch.OutputIndex))
fetches := make([]tf.Output, 0, len(t.fetches))
for _, fetch := range t.fetches {
fetches = append(fetches, t.graph.Operation(fetch.operationName).Output(fetch.outputIndex))
}

if targets == nil {
targets = t.operations
}
return t.session.Run(feeds, fetches, t.operations)
}

results, err := t.session.Run(input, output, targets)
func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
if tensors == nil || tensors[0] == nil || tensors[0].Value() == nil {
return nil, errors.ErrNilTensorTF(tensors)
}

values = make([][][]float64, 0, len(results))
for _, result := range results {
value, ok := result.Value().([][]float64)
switch t.ndim {
case TwoDim:
value, ok := tensors[0].Value().([][]float64)
if ok {
if value == nil {
return nil, errors.ErrNilTensorValueTF(value)
}
return value[0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
case ThreeDim:
value, ok := tensors[0].Value().([][][]float64)
if ok {
if value == nil || value[0] == nil {
return nil, errors.ErrNilTensorValueTF(value)
}
return value[0][0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
default:
value, ok := tensors[0].Value().([]float64)
if ok {
values = append(values, value)
return value, nil
} else {
return nil, errors.ErrFailedToCastTF(result.Value())
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
}
}

func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
if tensors == nil || tensors[0] == nil {
return nil, errors.ErrNilTensorTF(tensors)
}
return tensors[0].Value(), nil
}

func (t *tensorflow) GetValues(inputs ...string) (values []interface{}, err error) {
tensors, err := t.run(inputs...)
if err != nil {
return nil, err
}
values = make([]interface{}, 0, len(tensors))
for _, tensor := range tensors {
values = append(values, tensor.Value())
}
return values, nil
}
9 changes: 9 additions & 0 deletions internal/errors/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,13 @@ var (
ErrFailedToCastTF = func(v interface{}) error {
return Errorf("failed to cast tensorflow result %+v", v)
}
ErrInputLength = func(i int, f int) error {
return Errorf("inputs length %d does not match feeds length %d", i, f)
}
ErrNilTensorTF = func(v interface{}) error {
return Errorf("nil tensorflow tensor %+v", v)
}
ErrNilTensorValueTF = func(v interface{}) error {
return Errorf("nil tensorflow tensor value %+v", v)
}
)