Skip to content

Commit

Permalink
Merge branch 'master' of github.com:nano-interactive/go-fasttext
Browse files Browse the repository at this point in the history
Signed-off-by: Dusan Malusev <dusan@dusanmalusev.dev>
  • Loading branch information
CodeLieutenant committed Jul 13, 2023
2 parents 50fcf9f + 6a636d7 commit bd4f91f
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 51 deletions.
13 changes: 12 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@
"streambuf": "cpp",
"thread": "cpp",
"typeinfo": "cpp",
"valarray": "cpp"
"valarray": "cpp",
"bitset": "cpp",
"charconv": "cpp",
"cinttypes": "cpp",
"condition_variable": "cpp",
"list": "cpp",
"source_location": "cpp",
"format": "cpp",
"future": "cpp",
"mutex": "cpp",
"stdfloat": "cpp",
"variant": "cpp"
}
}
33 changes: 16 additions & 17 deletions cbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ struct membuf : std::streambuf
}
};

FastText_Result_t FastText_NewHandle(const char *path)
FastText_Result_t FastText_NewHandle(FastText_String_t path)
{
auto model = new fasttext::FastText();

try
{
model->loadModel(std::string(path));
model->loadModel(std::string(path.data, path.size));
return FastText_Result_t{
FastText_Result_t::SUCCESS,
(FastText_Handle_t)model,
Expand All @@ -78,21 +78,23 @@ void FastText_DeleteHandle(const FastText_Handle_t handle)
delete model;
}

FastText_Predict_t FastText_PredictOne(const FastText_Handle_t handle, FastText_String_t query, float threshold)
{
return FastText_Predict(handle, query, 1, threshold);
}

FastText_Predict_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, int k, float threshold)
{
const auto model = reinterpret_cast<fasttext::FastText *>(handle);

membuf sbuf(query);
std::istream in(&sbuf);

auto predictions = new Predictions((size_t)k);
model->predictLine(in, *predictions, k, threshold);
FREE_STRING(query);
auto predictions = new Predictions();
if (!model->predictLine(in, *predictions, k, threshold))
{
delete predictions;

return FastText_Predict_t{
0,
nullptr,
};
}

return FastText_Predict_t{
predictions->size(),
Expand All @@ -104,11 +106,9 @@ FastText_Predict_t FastText_Analogy(const FastText_Handle_t handle, FastText_Str
FastText_String_t word3, int32_t k)
{
const auto model = reinterpret_cast<fasttext::FastText *>(handle);
Predictions predictions = model->getAnalogies(k, word1.data, word2.data, word3.data);

FREE_STRING(word1);
FREE_STRING(word2);
FREE_STRING(word3);
Predictions predictions =
model->getAnalogies(k, std::string(word1.data, word1.size), std::string(word2.data, word2.size),
std::string(word3.data, word3.size));

auto vec = new Predictions(std::move(predictions));

Expand All @@ -124,8 +124,7 @@ FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText
int64_t dimensions = model->getDimension();

auto vec = new fasttext::Vector(dimensions);
model->getWordVector(*vec, word.data);
FREE_STRING(word);
model->getWordVector(*vec, std::string(word.data, word.size));

return FastText_FloatVector_t{
vec->data(),
Expand Down
3 changes: 1 addition & 2 deletions cbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ extern "C"
};
} FastText_Result_t;

FastText_Result_t FastText_NewHandle(const char *path);
FastText_Result_t FastText_NewHandle(FastText_String_t path);
void FastText_DeleteHandle(const FastText_Handle_t handle);
FastText_Predict_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, int k,
float threshold);
FastText_Predict_t FastText_PredictOne(const FastText_Handle_t handle, FastText_String_t query, float threshold);
FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word);
FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentance);
FastText_Predict_t FastText_Analogy(const FastText_Handle_t handle, FastText_String_t word1,
Expand Down
8 changes: 6 additions & 2 deletions cmd/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ var predictCmd = &cobra.Command{
// close the model at the end
defer model.Close()
// perform the prediction
preds := model.Predict(args[0], 1, 0.0)
pp.Println(preds)
preds, err := model.Predict(args[0], 1, 0.0)
if err != nil {
pp.Fatalln(err)
return
}
pp.Println(preds)
},
}

Expand Down
86 changes: 64 additions & 22 deletions fasttext.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ package fasttext
import "C"

import (
"strings"
"errors"
"unsafe"
)

var (
ErrPredictionFailed = errors.New("prediction failed")
ErrNoPredictions = errors.New("no predictions")
)

// A model object. Effectively a wrapper
// around the C fasttext handle
type Model struct {
Expand All @@ -29,10 +34,10 @@ func (e *ModelOpenError) Error() string {
// Opens a model from a path and returns a model
// object
func Open(path string) (Model, error) {
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))

result := C.FastText_NewHandle(cpath)
result := C.FastText_NewHandle(C.FastText_String_t{
data: cStr(path),
size: C.size_t(len(path)),
})

if result.status != 0 {
ch := *(**C.char)(unsafe.Pointer(&result.anon0[0]))
Expand All @@ -58,50 +63,76 @@ func (handle *Model) Close() error {
return nil
}

func (handle Model) MultiLinePredict(query string, k int32, threshoad float32) []Predictions {
lines := strings.Split(query, "\n")

func (handle Model) MultiLinePredict(lines []string, k int32, threshoad float32) ([]Predictions, error) {
predics := make([]Predictions, 0, len(lines))

for _, line := range lines {
predictions := handle.Predict(line, k, threshoad)
predictions, err := handle.Predict(line, k, threshoad)
if err != nil && errors.Is(err, ErrPredictionFailed) {
return nil, err
}

predics = append(predics, predictions)
}

return predics
if len(predics) == 0 {
return nil, ErrNoPredictions
}

return predics, nil
}

func (handle Model) PredictOne(query string, threshoad float32) Prediction {
r := C.FastText_PredictOne(
func (handle Model) PredictOne(query string, threshoad float32) (Prediction, error) {
r := C.FastText_Predict(
handle.p,
C.FastText_String_t{
data: C.CString(query),
data: cStr(query),
size: C.size_t(len(query)),
},
1,
C.float(threshoad),
)

if r.data == nil {
return Prediction{}, ErrPredictionFailed
}

defer C.FastText_FreePredict(r)

if r.size == 0 {
return Prediction{}, ErrNoPredictions
}

cPredic := C.FastText_PredictItemAt(r, C.size_t(0))

return Prediction{
Label: C.GoStringN(cPredic.label.data, C.int(cPredic.label.size)),
Probability: float32(cPredic.probability),
}
}, nil
}

// Perform model prediction
func (handle Model) Predict(query string, k int32, threshoad float32) Predictions {
func (handle Model) Predict(query string, k int32, threshoad float32) (Predictions, error) {
r := C.FastText_Predict(
handle.p,
C.FastText_String_t{
data: C.CString(query),
data: cStr(query),
size: C.size_t(len(query)),
},
C.int(k),
C.float(threshoad),
)

if r.data == nil {
return nil, ErrPredictionFailed
}

defer C.FastText_FreePredict(r)

if r.size == 0 {
return nil, ErrNoPredictions
}

predictions := make(Predictions, r.size)

for i := 0; i < int(r.size); i++ {
Expand All @@ -113,20 +144,25 @@ func (handle Model) Predict(query string, k int32, threshoad float32) Prediction
}
}

return predictions
return predictions, nil
}

func (handle Model) Analogy(word1, word2, word3 string, k int32) Analogs {
// cWord1 := ((*C.char) unsafe.Pointer(unsafe.StringData(word1)))

r := C.FastText_Analogy(
handle.p,
C.FastText_String_t{
data: C.CString(word1),
data: cStr(word1),
size: C.size_t(len(word1)),
},
C.FastText_String_t{
data: C.CString(word2),
data: cStr(word2),
size: C.size_t(len(word2)),
},
C.FastText_String_t{
data: C.CString(word3),
data: cStr(word3),
size: C.size_t(len(word3)),
},
C.int32_t(k),
)
Expand All @@ -151,7 +187,8 @@ func (handle Model) Wordvec(word string) []float32 {
r := C.FastText_Wordvec(
handle.p,
C.FastText_String_t{
data: C.CString(word),
data: cStr(word),
size: C.size_t(len(word)),
},
)
defer C.FastText_FreeFloatVector(r)
Expand All @@ -165,7 +202,7 @@ func (handle Model) Wordvec(word string) []float32 {
// Requires sentence ends with </s>
func (handle Model) Sentencevec(query string) []float32 {
r := C.FastText_Sentencevec(handle.p, C.FastText_String_t{
data: C.CString(query),
data: cStr(query),
size: C.size_t(len(query)),
})

Expand All @@ -176,3 +213,8 @@ func (handle Model) Sentencevec(query string) []float32 {

return vectors
}

//go:inline
func cStr(str string) *C.char {
return ((*C.char)(unsafe.Pointer(unsafe.StringData(str))))
}
Loading

0 comments on commit bd4f91f

Please sign in to comment.