Skip to content

Commit

Permalink
adding threshold and multi-line strings support
Browse files Browse the repository at this point in the history
Signed-off-by: Dusan Malusev <dusan@dusanmalusev.dev>
  • Loading branch information
CodeLieutenant committed Jul 7, 2023
1 parent e13f776 commit bb716e8
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 49 deletions.
8 changes: 7 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
"memory": "cpp",
"utility": "cpp",
"iomanip": "cpp",
"cmath": "cpp"
"cmath": "cpp",
"array": "cpp",
"string_view": "cpp",
"initializer_list": "cpp",
"ranges": "cpp",
"span": "cpp",
"string": "cpp"
}
}
41 changes: 28 additions & 13 deletions cbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <memory>
#include <streambuf>

#include <args.cc>
#include <autotune.cc>
#include <dictionary.cc>
#include <fasttext.cc>
Expand Down Expand Up @@ -47,36 +46,45 @@ void FastText_DeleteHandle(const FastTextHandle handle)
delete model;
}

FastText_Predict_t FastText_Predict(const FastTextHandle handle, FastText_String_t query)
/// Convenience wrapper around FastText_Predict that requests exactly one label.
///
/// @param handle    Model handle, forwarded unchanged to FastText_Predict.
/// @param query     Text to classify, forwarded unchanged to FastText_Predict.
/// @param threshold Probability threshold, forwarded unchanged to FastText_Predict.
/// @return The result of FastText_Predict called with k = 1.
FastText_Predict_t FastText_PredictOne(const FastTextHandle handle, FastText_String_t query, float threshold)
{
    constexpr int kSinglePrediction = 1;

    return FastText_Predict(handle, query, kSinglePrediction, threshold);
}

/// Runs the model's predictLine on `query` and returns the collected predictions.
///
/// Ownership: this function releases `query.data` with free(), so the caller
/// must hand over a malloc-compatible buffer and must not free it again.
/// The returned `FastText_Predict_t` carries an opaque pointer to a
/// heap-allocated std::vector<std::pair<fasttext::real, std::string>>.
///
/// @param handle    Opaque handle; reinterpreted as fasttext::FastText *.
/// @param query     Text to classify (buffer + size); consumed by this call.
/// @param k         Forwarded to predictLine as the number of predictions requested.
/// @param threshold Forwarded to predictLine as the probability threshold.
/// @return The number of predictions and the opaque pointer to the vector holding them.
FastText_Predict_t FastText_Predict(const FastTextHandle handle, FastText_String_t query, int k, float threshold)
{
    using PredictionList = std::vector<std::pair<fasttext::real, std::string>>;

    const auto model = reinterpret_cast<fasttext::FastText *>(handle);

    membuf sbuf(query);
    std::istream in(&sbuf);

    // Owned by a unique_ptr until we hand it to the caller, so the vector is
    // not leaked if predictLine throws.
    auto predictions = std::make_unique<PredictionList>();

    // Pass the vector object itself. The previous code reinterpret_cast the
    // *pointer variable* to a vector reference, which made predictLine write
    // into the pointer's own storage instead of the allocated vector.
    model->predictLine(in, *predictions, k, threshold);

    // `query` is a by-value copy, so only the buffer needs releasing; resetting
    // its fields here would not be visible to the caller.
    free(query.data);

    // Read the size before releasing ownership of the vector.
    const auto count = predictions->size();

    return FastText_Predict_t{
        count,
        static_cast<void *>(predictions.release()),
    };
}

char *FastText_Analogy(const FastTextHandle handle, const char *query, size_t length)
{
return "";
// char *FastText_Analogy(const FastTextHandle handle, const char *query, size_t length)
// {
// return "";

// auto model = reinterpret_cast<fasttext::FastText *>(handle);
// // auto model = reinterpret_cast<fasttext::FastText *>(handle);

// model->getAnalogies(1, query, 10);
// // model->getAnalogies(1, query, 10);

// size_t ii = 0;
// auto res = json::array();
// // size_t ii = 0;
// // auto res = json::array();

// return strdup(res.dump().c_str());
}
// // return strdup(res.dump().c_str());
// }

FastText_FloatVector_t FastText_Wordvec(const FastTextHandle handle, FastText_String_t word)
{
Expand All @@ -86,6 +94,10 @@ FastText_FloatVector_t FastText_Wordvec(const FastTextHandle handle, FastText_St
auto vec = new fasttext::Vector(dimensions);
model->getWordVector(reinterpret_cast<fasttext::Vector &>(vec), std::string(word.data, word.size));

free(word.data);
word.data = nullptr;
word.size = 0;

return FastText_FloatVector_t{
vec->data(),
(void *)vec,
Expand All @@ -102,6 +114,9 @@ FastText_FloatVector_t FastText_Sentencevec(const FastTextHandle handle, FastTex

auto vec = new fasttext::Vector(model->getDimension());
model->getSentenceVector(in, reinterpret_cast<fasttext::Vector &>(vec));
free(sentance.data);
sentance.data = nullptr;
sentance.size = 0;

return FastText_FloatVector_t{
vec->data(),
Expand Down
6 changes: 4 additions & 2 deletions cbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ extern "C"

FastTextHandle FastText_NewHandle(const char *path);
void FastText_DeleteHandle(const FastTextHandle handle);
FastText_Predict_t FastText_Predict(const FastTextHandle handle, FastText_String_t query);
FastText_Predict_t FastText_Predict(const FastTextHandle handle, FastText_String_t query, int k, float threshold);
FastText_Predict_t FastText_PredictOne(const FastTextHandle handle, FastText_String_t query, float threshold);

FastText_FloatVector_t FastText_Wordvec(const FastTextHandle handle, FastText_String_t word);
FastText_FloatVector_t FastText_Sentencevec(const FastTextHandle handle, FastText_String_t sentance);

char *FastText_Analogy(const FastTextHandle handle, FastText_String_t query);
// char *FastText_Analogy(const FastTextHandle handle, FastText_String_t query);

void FastText_FreeFloatVector(FastText_FloatVector_t vector);
void FastText_FreePredict(FastText_Predict_t predict);
Expand Down
28 changes: 11 additions & 17 deletions cmd/analogy.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
package cmd

import (
"fmt"

"github.com/k0kubun/pp"
"github.com/spf13/cobra"
"github.com/unknwon/com"

"github.com/nano-interactive/go-fasttext"
)

var (
Expand All @@ -20,18 +14,18 @@ var analogyCmd = &cobra.Command{
Short: "Perform word analogy on a query using an input model",
Args: cobra.ExactArgs(1), // make sure that there is only one argument being passed in
Run: func(cmd *cobra.Command, args []string) {
if !com.IsFile(unsupervisedModelPath) {
fmt.Println("the file %s does not exist", unsupervisedModelPath)
return
}
// if !com.IsFile(unsupervisedModelPath) {
// fmt.Println("the file %s does not exist", unsupervisedModelPath)
// return
// }

// create a model object
model := fasttext.Open(unsupervisedModelPath)
// close the model at the end
defer model.Close()
// perform the prediction
analogies := model.Analogy(args[0])
pp.Println(analogies)
// // create a model object
// model := fasttext.Open(unsupervisedModelPath)
// // close the model at the end
// defer model.Close()
// // perform the prediction
// analogies := model.Analogy(args[0])
// pp.Println(analogies)
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var predictCmd = &cobra.Command{
// close the model at the end
defer model.Close()
// perform the prediction
preds := model.Predict(args[0])
preds := model.Predict(args[0], 1, 0.0)
pp.Println(preds)
},
}
Expand Down
51 changes: 36 additions & 15 deletions fasttext.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package fasttext

// #cgo CXXFLAGS: -I${SRCDIR}/fastText/src -I${SRCDIR} -I${SRCDIR}/include -std=c++17 -O3 -fPIC -pedantic -Wall -Wextra -Wno-sign-compare -Wno-unused-parameter
// #cgo CXXFLAGS: -I${SRCDIR}/fastText/src -I${SRCDIR} -I${SRCDIR}/include -std=c++17 -O3 -fPIC
// #cgo LDFLAGS: -lstdc++
// #include <stdio.h>
// #include <stdlib.h>
// #include "cbits.h"
import "C"

import (
"strings"
"unsafe"
)

Expand Down Expand Up @@ -37,17 +38,43 @@ func (handle *Model) Close() error {
return nil
}

// // Perform model prediction
func (handle Model) Predict(query string) Predictions {
// MultiLinePredict splits query on "\n" and runs Predict on every resulting
// line, in order, using the same k and threshoad for each. The returned slice
// holds one Predictions entry per line, including empty lines.
func (handle Model) MultiLinePredict(query string, k int32, threshoad float32) []Predictions {
	segments := strings.Split(query, "\n")
	results := make([]Predictions, len(segments))

	for idx := range segments {
		results[idx] = handle.Predict(segments[idx], k, threshoad)
	}

	return results
}

func (handle Model) PredictOne(query string, threshoad float32) Prediction {
cquery := C.CString(query)
defer C.free(unsafe.Pointer(cquery))

// Call the Predict function defined in cbits.cpp
// passing in the model handle and the query string
r := C.FastText_Predict(handle.p, C.FastText_String_t{
r := C.FastText_PredictOne(handle.p, C.FastText_String_t{
data: cquery,
size: C.size_t(len(query)),
})
}, C.float(threshoad))
defer C.FastText_FreePredict(r)

cPredic := C.FastText_PredictItemAt(r, C.size_t(0))

return Prediction{
Label: C.GoStringN(cPredic.label.data, C.int(cPredic.label.size)),
Probability: float32(cPredic.probability),
}
}

// Perform model prediction
func (handle Model) Predict(query string, k int32, threshoad float32) Predictions {
r := C.FastText_Predict(handle.p, C.FastText_String_t{
data: C.CString(query),
size: C.size_t(len(query)),
}, C.int(k), C.float(threshoad))
defer C.FastText_FreePredict(r)

predictions := make(Predictions, r.size)
Expand Down Expand Up @@ -82,11 +109,8 @@ func (handle Model) Predict(query string) Predictions {
// }

func (handle Model) Wordvec(word string) []float32 {
cquery := C.CString(word)
defer C.free(unsafe.Pointer(cquery))

r := C.FastText_Wordvec(handle.p, C.FastText_String_t{
data: cquery,
data: C.CString(word),
size: C.size_t(len(word)),
})
defer C.FastText_FreeFloatVector(r)
Expand All @@ -99,11 +123,8 @@ func (handle Model) Wordvec(word string) []float32 {

// Requires that the sentence ends with </s>
func (handle Model) Sentencevec(query string) []float32 {
cquery := C.CString(query)
defer C.free(unsafe.Pointer(cquery))

r := C.FastText_Sentencevec(handle.p, C.FastText_String_t{
data: cquery,
data: C.CString(query),
size: C.size_t(len(query)),
})

Expand Down

0 comments on commit bb716e8

Please sign in to comment.