diff --git a/cbits.cpp b/cbits.cpp index 91da431..8b6b174 100644 --- a/cbits.cpp +++ b/cbits.cpp @@ -79,8 +79,7 @@ FastText_Predict_t FastText_Predict(const FastText_Handle_t handle, FastText_Str std::istream in(&sbuf); auto predictions = new std::vector>(); - model->predictLine(in, reinterpret_cast> &>(predictions), k, - threshold); + model->predictLine(in, *predictions, k, threshold); free(query.data); query.data = nullptr; @@ -112,7 +111,7 @@ FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText int64_t dimensions = model->getDimension(); auto vec = new fasttext::Vector(dimensions); - model->getWordVector(reinterpret_cast(vec), std::string(word.data, word.size)); + model->getWordVector(*vec, std::string(word.data, word.size)); free(word.data); word.data = nullptr; @@ -133,7 +132,7 @@ FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, Fast std::istream in(&sbuf); auto vec = new fasttext::Vector(model->getDimension()); - model->getSentenceVector(in, reinterpret_cast(vec)); + model->getSentenceVector(in, *vec); free(sentence.data); sentence.data = nullptr; sentence.size = 0; @@ -163,12 +162,12 @@ FastText_PredictItem_t FastText_PredictItemAt(FastText_Predict_t predict, size_t const auto &data = vec->at(idx); auto str = FastText_String_t{ - data.second.size(), - (char *)data.second.c_str(), + data.second.size() - sizeof("__label__") + 1, + (char *)(data.second.c_str() + sizeof("__label__") - 1), }; return FastText_PredictItem_t{ - std::exp(data.first), + data.first, str, }; } diff --git a/fasttext.go b/fasttext.go index 2740305..38880de 100644 --- a/fasttext.go +++ b/fasttext.go @@ -72,11 +72,8 @@ func (handle Model) MultiLinePredict(query string, k int32, threshoad float32) [ } func (handle Model) PredictOne(query string, threshoad float32) Prediction { - cquery := C.CString(query) - defer C.free(unsafe.Pointer(cquery)) - r := C.FastText_PredictOne(handle.p, C.FastText_String_t{ - data: cquery, + data: C.CString(query), size: C.size_t(len(query)), }, C.float(threshoad)) defer C.FastText_FreePredict(r) diff --git a/fasttext_test.go b/fasttext_test.go index 721ea2b..d6cd50d 100644 --- a/fasttext_test.go +++ b/fasttext_test.go @@ -9,10 +9,24 @@ import ( func TestOpen(t *testing.T) { t.Parallel() - assert := require.New(t) _, err := fasttext.Open("testdata/lid.176.ftz") assert.NoError(err) } + + +func TestPredictOne(t *testing.T) { + + assert := require.New(t) + + model, err := fasttext.Open("testdata/lid.176.ftz") + + assert.NoError(err) + + prediction := model.PredictOne("hello world from my dear C++", 0.7) + + assert.Equal("en", prediction.Label) + assert.Greater(prediction.Probability, float32(0.7)) +}