Skip to content

Commit

Permalink
add sentence prob handle
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Bernstein committed Aug 28, 2018
1 parent 6967fb1 commit 0a6885e
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "fastText"]
path = fastText
url = https://github.com/facebookresearch/fastText
25 changes: 25 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Go parameters
GOCMD=go
GOBUILD=$(GOCMD) build
GOCLEAN=$(GOCMD) clean
GOTEST=$(GOCMD) test
GOGET=$(GOCMD) get
BINARY_NAME=fast_bind
BINARY_UNIX=$(BINARY_NAME)_unix

# None of these targets create a file named after themselves. Declaring them
# phony keeps make from skipping a target when a file or directory with the
# same name (for example "test" or "build") exists in the working tree.
.PHONY: all build test clean run

# Default target: wipe build artifacts, then run the comparison commands.
all: clean test

# Build the package in the current directory into $(BINARY_NAME).
build:
	$(GOBUILD) -o $(BINARY_NAME) -v

# Run the Go CLI and the ./fastText/fasttext binary on the same sentence, first
# for prediction and then for sentence vectors. Presumably the outputs are meant
# to be compared by eye; nothing here checks them automatically.
# NOTE(review): the model path points outside this repository
# (../service/chat-quality/...) and must exist for these commands to work.
test:
	go run cli/main.go predict -m ../service/chat-quality/tools/ml/fasttext/model.bin "i accidentally shrunk my shrinky dink lol </s>"
	echo "i accidentally shrunk my shrinky dink lol" | ./fastText/fasttext predict-prob ../service/chat-quality/tools/ml/fasttext/model.bin -
	go run cli/main.go sentence -m ../service/chat-quality/tools/ml/fasttext/model.bin "i accidentally shrunk my shrinky dink lol </s>"
	echo "i accidentally shrunk my shrinky dink lol" | ./fastText/fasttext print-sentence-vectors ../service/chat-quality/tools/ml/fasttext/model.bin

# Remove the Go build cache directory under ~/Library/Caches, run `go clean`,
# and delete the built binaries.
clean:
	rm -rf ~/Library/Caches/go-build/
	$(GOCLEAN)
	rm -f $(BINARY_NAME)
	rm -f $(BINARY_UNIX)

# Build and then execute the binary.
run:
	$(GOBUILD) -o $(BINARY_NAME) -v ./...
	./$(BINARY_NAME)
25 changes: 21 additions & 4 deletions cbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,27 @@ char *Wordvec(FastTextHandle handle, char *query) {
model->getWordVector(vec, query);

auto res = json::array();
for (int i = 0; i < vec.data_.size(); i++) {
res.push_back({
{"probability",vec.data_[i]},
});
for (int i = 0; i < vec.size(); i++) {
res.push_back(vec[i]);
}

return strdup(res.dump().c_str());
}

char *Sentencevec(FastTextHandle handle, char *query) {
auto model = bit_cast<fasttext::FastText *>(handle);

membuf sbuf(query, query + strlen(query));
std::istream in(&sbuf);

fasttext::Vector vec(model->getDimension());
// fasttext::Matrix wordVectors(model->dict_->nwords(), model->getDimension());
// model->precomputeWordVectors(wordVectors);
model->getSentenceVector(in, vec);

auto res = json::array();
for (int i = 0; i < vec.size(); i++) {
res.push_back(vec[i]);
}

return strdup(res.dump().c_str());
Expand Down
1 change: 1 addition & 0 deletions cbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ void DeleteHandle(FastTextHandle handle);
char *Predict(FastTextHandle handle, char *query);
char *Analogy(FastTextHandle handle, char *query);
char *Wordvec(FastTextHandle handle, char *query);
char *Sentencevec(FastTextHandle handle, char *query);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions cli/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package main

import "github.com/bountylabs/go-fasttext/cmd"

// main is the entry point of the command-line tool. It does nothing itself
// beyond handing control to cmd.Execute from the cmd package.
func main() {
cmd.Execute()
}
40 changes: 40 additions & 0 deletions cmd/sentence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package cmd

import (
"fmt"

"github.com/Unknwon/com"
"github.com/k0kubun/pp"
fasttext "github.com/bountylabs/go-fasttext"
"github.com/spf13/cobra"
)

// sentenceCmd represents the sentence command. It takes exactly one query
// argument, opens the model at modelPath, and pretty-prints the vector that
// Model.Sentencevec returns for that query.
var sentenceCmd = &cobra.Command{
	Use:   "sentence -m [path_to_model] [query]",
	Short: "get a sentence vector",
	Args:  cobra.ExactArgs(1), // make sure that there is only one argument being passed in
	Run: func(cmd *cobra.Command, args []string) {
		if !com.IsFile(modelPath) {
			// Printf rather than Println: Println does not interpret
			// format verbs and would print the "%s" literally.
			fmt.Printf("the file %s does not exist\n", modelPath)
			return
		}

		// create a model object
		model := fasttext.Open(modelPath)
		// close the model at the end
		defer model.Close()
		// compute the sentence vector for the query
		preds, err := model.Sentencevec(args[0])
		if err != nil {
			fmt.Println(err)
			return
		}
		pp.Println(preds)
	},
}

// init wires the sentence command into the CLI: it binds the -m/--model flag
// to the package-level modelPath variable and adds the command to rootCmd.
func init() {
	flags := sentenceCmd.Flags()
	flags.StringVarP(&modelPath, "model", "m", "", "path to the fasttext model")

	rootCmd.AddCommand(sentenceCmd)
}
1 change: 1 addition & 0 deletions fastText
Submodule fastText added at 9fbc03
25 changes: 21 additions & 4 deletions fasttext.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ func (handle *Model) Predict(query string) (Predictions, error) {
defer C.free(unsafe.Pointer(r))
js := C.GoString(r)

panic("here")

// unmarshal the json results into the predictions
// object. See https://blog.golang.org/json-and-go
predictions := []Prediction{}
Expand Down Expand Up @@ -87,19 +85,38 @@ func (handle *Model) Analogy(query string) (Analogs, error) {
return analogies, nil
}

func (handle *Model) Wordvec(query string) (Vectors, error) {
func (handle *Model) Wordvec(query string) ([]float32, error) {
cquery := C.CString(query)
defer C.free(unsafe.Pointer(cquery))

r := C.Wordvec(handle.handle, cquery)
defer C.free(unsafe.Pointer(r))
js := C.GoString(r)

vectors := []Vector{}
vectors := []float32{}
err := json.Unmarshal([]byte(js), &vectors)
if err != nil {
return nil, err
}

return vectors, nil
}

// Sentencevec returns the sentence vector for query as a slice of float32.
// The query is required to end with </s>.
//
// The query is copied into a C string and passed to the C Sentencevec binding;
// the text that comes back is decoded as JSON into the result. Both C
// allocations are released when the method returns. A JSON decoding failure is
// returned as the error, with a nil slice.
func (handle *Model) Sentencevec(query string) ([]float32, error) {
	cstr := C.CString(query)
	defer C.free(unsafe.Pointer(cstr))

	raw := C.Sentencevec(handle.handle, cstr)
	defer C.free(unsafe.Pointer(raw))

	result := []float32{}
	if err := json.Unmarshal([]byte(C.GoString(raw)), &result); err != nil {
		return nil, err
	}
	return result, nil
}

11 changes: 0 additions & 11 deletions main.go

This file was deleted.

34 changes: 0 additions & 34 deletions word2vec.go

This file was deleted.

0 comments on commit 0a6885e

Please sign in to comment.