From 5cec47dbaa3c88f9b3265fc061709024f36693c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kuba=20Podg=C3=B3rski?= Date: Sat, 28 Oct 2023 20:44:46 +0200 Subject: [PATCH] Fixing tests (#6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixing tests * Fix src/bpe_model_test.cc (#7) Co-authored-by: Kuba Podgórski --------- Co-authored-by: rbehjati --- src/bpe_model_test.cc | 23 ++++++++----- src/bpe_model_trainer_test.cc | 6 ++-- src/char_model_test.cc | 6 ++-- src/char_model_trainer_test.cc | 6 ++-- src/model_factory_test.cc | 8 ++--- src/model_interface_test.cc | 51 +++++++++++++++++++++-------- src/sentencepiece_processor_test.cc | 8 ++--- src/sentencepiece_trainer_test.cc | 10 +++--- src/testharness.h | 2 +- src/trainer_interface_test.cc | 15 +++++++-- src/unigram_model_test.cc | 47 ++++++++++++++++++-------- src/unigram_model_trainer_test.cc | 6 ++-- src/util_test.cc | 2 +- src/word_model_test.cc | 7 ++-- src/word_model_trainer_test.cc | 6 ++-- 15 files changed, 132 insertions(+), 71 deletions(-) diff --git a/src/bpe_model_test.cc b/src/bpe_model_test.cc index 42d40625..67f1b110 100644 --- a/src/bpe_model_test.cc +++ b/src/bpe_model_test.cc @@ -69,7 +69,7 @@ TEST(BPEModelTest, EncodeTest) { model_proto.mutable_pieces(12)->set_type( // r ModelProto::SentencePiece::USER_DEFINED); - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); EncodeResult result; @@ -149,7 +149,7 @@ TEST(BPEModelTest, EncodeAmbiguousTest) { AddPiece(&model_proto, "a", -0.4); AddPiece(&model_proto, "b", -0.5); - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); EncodeResult result; @@ -188,7 +188,8 @@ TEST(BPEModelTest, EncodeAmbiguousTest) { TEST(BPEModelTest, NotSupportedTest) { ModelProto model_proto = MakeBaseModelProto(); - const Model model(model_proto); + + const Model model(std::make_unique(model_proto)); EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10)); } @@ -206,7 +207,7 @@ TEST(BPEModelTest, EncodeWithUnusedTest) { // No unused. { - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); const auto result = model.Encode("abcd"); EXPECT_EQ(1, result.size()); EXPECT_EQ("abcd", result[0].first); @@ -214,7 +215,8 @@ TEST(BPEModelTest, EncodeWithUnusedTest) { { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); - const Model model(model_proto); + + const Model model(std::make_unique(model_proto)); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); EXPECT_EQ("abc", result[0].first); @@ -225,7 +227,8 @@ TEST(BPEModelTest, EncodeWithUnusedTest) { // The parent rule "abc" is still alive even if the child "ab" is unused. model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED); - const Model model(model_proto); + + const Model model(std::make_unique(model_proto)); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); EXPECT_EQ("abc", result[0].first); @@ -240,7 +243,8 @@ TEST(BPEModelTest, EncodeWithUnusedTest) { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL); - const Model model(model_proto); + + const Model model(std::make_unique(model_proto)); const auto result = model.Encode("abcd"); EXPECT_EQ(3, result.size()); EXPECT_EQ("ab", result[0].first); @@ -259,7 +263,7 @@ TEST(SampleModelTest, EncodeTest) { // No regularization { - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); const auto result = model.Encode("abcd"); EXPECT_EQ(1, result.size()); EXPECT_EQ("abcd", result[0].first); @@ -275,7 +279,8 @@ TEST(SampleModelTest, EncodeTest) { return out; }; - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); + const std::vector kAlpha = {0.0, 0.1, 0.5, 0.7, 0.9}; for (const auto alpha : kAlpha) { constexpr int kTrial = 100000; diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc index 2a43c3ac..b23d70cb 100644 --- a/src/bpe_model_trainer_test.cc +++ b/src/bpe_model_trainer_test.cc @@ -67,12 +67,12 @@ std::string RunTrainer( SentencePieceProcessor processor; EXPECT_TRUE(processor.Load(model_prefix + ".model").ok()); - const auto &model = processor.model_proto(); + const auto model = processor.model_proto(); std::vector pieces; // remove , , - for (int i = 3; i < model.pieces_size(); ++i) { - pieces.emplace_back(model.pieces(i).piece()); + for (int i = 3; i < model->pieces_size(); ++i) { + pieces.emplace_back(model->pieces(i).piece()); } return absl::StrJoin(pieces, " "); diff --git a/src/char_model_test.cc b/src/char_model_test.cc index 7e082802..1399912f 100644 --- a/src/char_model_test.cc +++ b/src/char_model_test.cc @@ -49,8 +49,8 @@ void AddPiece(ModelProto *model_proto, const std::string &piece, } TEST(ModelTest, EncodeTest) { - ModelProto model_proto = MakeBaseModelProto(); + auto model_proto = MakeBaseModelProto(); AddPiece(&model_proto, WS, 0.0); AddPiece(&model_proto, "a", 0.1); AddPiece(&model_proto, "b", 0.2); @@ -60,7 +60,7 @@ TEST(ModelTest, EncodeTest) { model_proto.mutable_pieces(8)->set_type( ModelProto::SentencePiece::USER_DEFINED); - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); EncodeResult result; @@ -108,7 +108,7 @@ TEST(ModelTest, EncodeTest) { TEST(CharModelTest, NotSupportedTest) { ModelProto model_proto = MakeBaseModelProto(); - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10)); EXPECT_EQ(EncodeResult(), model.SampleEncode("test", 0.1)); } diff --git a/src/char_model_trainer_test.cc b/src/char_model_trainer_test.cc index e8b49796..90266c3a 100644 --- a/src/char_model_trainer_test.cc +++ b/src/char_model_trainer_test.cc @@ -59,12 +59,12 @@ std::string RunTrainer(const std::vector &input, int size) { SentencePieceProcessor processor; EXPECT_TRUE(processor.Load(model_prefix + ".model").ok()); - const auto &model = processor.model_proto(); + const auto model = processor.model_proto(); std::vector pieces; // remove , , - for (int i = 3; i < model.pieces_size(); ++i) { - pieces.emplace_back(model.pieces(i).piece()); + for (int i = 3; i < model->pieces_size(); ++i) { + pieces.emplace_back(model->pieces(i).piece()); } return absl::StrJoin(pieces, " "); diff --git a/src/model_factory_test.cc b/src/model_factory_test.cc index 04e97fcd..73d47d89 100644 --- a/src/model_factory_test.cc +++ b/src/model_factory_test.cc @@ -37,22 +37,22 @@ TEST(ModelFactoryTest, BasicTest) { { model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::UNIGRAM); - auto m = ModelFactory::Create(model_proto); + auto m = ModelFactory::Create(std::make_unique(model_proto)); } { model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::BPE); - auto m = ModelFactory::Create(model_proto); + auto m = ModelFactory::Create(std::make_unique(model_proto)); } { model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::WORD); - auto m = ModelFactory::Create(model_proto); + auto m = ModelFactory::Create(std::make_unique(model_proto)); } { model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::CHAR); - auto m = ModelFactory::Create(model_proto); + auto m = ModelFactory::Create(std::make_unique(model_proto)); } } } // namespace sentencepiece diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index f209b3aa..ca044d46 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -72,7 +72,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) { { ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); AddPiece(&model_proto, "a"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_EQ("", model->unk_piece()); EXPECT_EQ("", model->bos_piece()); EXPECT_EQ("", model->eos_piece()); @@ -86,7 +88,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) { model_proto.mutable_trainer_spec()->clear_bos_piece(); model_proto.mutable_trainer_spec()->clear_eos_piece(); model_proto.mutable_trainer_spec()->clear_pad_piece(); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_EQ("", model->unk_piece()); EXPECT_EQ("", model->bos_piece()); EXPECT_EQ("", model->eos_piece()); @@ -100,7 +104,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) { model_proto.mutable_trainer_spec()->set_bos_piece("BOS"); model_proto.mutable_trainer_spec()->set_eos_piece("EOS"); model_proto.mutable_trainer_spec()->set_pad_piece("PAD"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_EQ("UNK", model->unk_piece()); EXPECT_EQ("BOS", model->bos_piece()); EXPECT_EQ("EOS", model->eos_piece()); @@ -116,7 +122,8 @@ TEST(ModelInterfaceTest, SetModelInterfaceTest) { AddPiece(&model_proto, "c"); AddPiece(&model_proto, "d"); - auto model = ModelFactory::Create(model_proto); + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_EQ(model_proto.SerializeAsString(), model->model_proto().SerializeAsString()); } @@ -135,7 +142,7 @@ TEST(ModelInterfaceTest, PieceToIdTest) { model_proto.mutable_pieces(7)->set_type( ModelProto::SentencePiece::USER_DEFINED); - auto model = ModelFactory::Create(model_proto); + auto model = ModelFactory::Create(std::make_unique(model_proto)); EXPECT_EQ(model_proto.SerializeAsString(), model->model_proto().SerializeAsString()); @@ -212,7 +219,9 @@ TEST(ModelInterfaceTest, InvalidModelTest) { { ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); AddPiece(&model_proto, ""); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } @@ -221,7 +230,9 @@ TEST(ModelInterfaceTest, InvalidModelTest) { ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); AddPiece(&model_proto, "a"); AddPiece(&model_proto, "a"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } @@ -229,7 +240,9 @@ TEST(ModelInterfaceTest, InvalidModelTest) { { ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); model_proto.mutable_pieces(1)->set_type(ModelProto::SentencePiece::UNKNOWN); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } @@ -237,7 +250,9 @@ TEST(ModelInterfaceTest, InvalidModelTest) { { ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); model_proto.mutable_pieces(0)->set_type(ModelProto::SentencePiece::CONTROL); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } } @@ -249,7 +264,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) { AddBytePiece(&model_proto, i); } AddPiece(&model_proto, "a"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_TRUE(model->status().ok()); } @@ -260,7 +277,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) { AddBytePiece(&model_proto, i); } AddPiece(&model_proto, "a"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } @@ -271,7 +290,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) { AddBytePiece(&model_proto, i); } AddPiece(&model_proto, "a"); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); + EXPECT_FALSE(model->status().ok()); } } @@ -307,7 +328,8 @@ TEST(ModelInterfaceTest, PieceToIdStressTest) { AddPiece(&model_proto, piece); } - auto model = ModelFactory::Create(model_proto); + auto model = ModelFactory::Create(std::make_unique(model_proto)); + for (const auto &it : expected_p2i) { EXPECT_EQ(it.second, model->PieceToId(it.first)); } @@ -505,7 +527,8 @@ TEST(ModelInterfaceTest, VerifyOutputsEquivalent) { ModelProto model_proto = MakeBaseModelProto(type); AddPiece(&model_proto, "a", 1.0); AddPiece(&model_proto, "b", 2.0); - auto model = ModelFactory::Create(model_proto); + + auto model = ModelFactory::Create(std::make_unique(model_proto)); // Equivalent outputs. EXPECT_TRUE(model->VerifyOutputsEquivalent("", "")); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 4f2f33be..392b7c98 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -969,7 +969,7 @@ TEST(SentencePieceProcessorTest, LoadSerializedProtoTest) { EXPECT_FALSE(sp.LoadFromSerializedProto("__NOT_A_PROTO__").ok()); EXPECT_TRUE(sp.LoadFromSerializedProto(model_proto.SerializeAsString()).ok()); EXPECT_EQ(model_proto.SerializeAsString(), - sp.model_proto().SerializeAsString()); + sp.model_proto()->SerializeAsString()); } TEST(SentencePieceProcessorTest, EndToEndTest) { @@ -1004,7 +1004,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model")).ok()); EXPECT_EQ(model_proto.SerializeAsString(), - sp.model_proto().SerializeAsString()); + sp.model_proto()->SerializeAsString()); EXPECT_EQ(8, sp.GetPieceSize()); EXPECT_EQ(0, sp.PieceToId("")); @@ -1273,7 +1273,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { auto RunTest = [&model_proto](const SentencePieceProcessor &sp) { EXPECT_EQ(model_proto.SerializeAsString(), - sp.model_proto().SerializeAsString()); + sp.model_proto()->SerializeAsString()); EXPECT_EQ(8, sp.GetPieceSize()); EXPECT_EQ(0, sp.PieceToId("")); @@ -1351,7 +1351,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { const ModelProto *moved_ptr = moved.get(); *moved = model_proto; EXPECT_TRUE(sp.Load(std::move(moved)).ok()); - EXPECT_EQ(moved_ptr, &sp.model_proto()); + EXPECT_EQ(moved_ptr, sp.model_proto()); RunTest(sp); } diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc index dd32f28f..e3801b8d 100644 --- a/src/sentencepiece_trainer_test.cc +++ b/src/sentencepiece_trainer_test.cc @@ -31,17 +31,17 @@ static constexpr char kIdsDenormTsv[] = "ids_denorm.tsv"; void CheckVocab(absl::string_view filename, int expected_vocab_size) { SentencePieceProcessor sp; ASSERT_TRUE(sp.Load(filename.data()).ok()); - EXPECT_EQ(expected_vocab_size, sp.model_proto().trainer_spec().vocab_size()); - EXPECT_EQ(sp.model_proto().pieces_size(), - sp.model_proto().trainer_spec().vocab_size()); + EXPECT_EQ(expected_vocab_size, sp.model_proto()->trainer_spec().vocab_size()); + EXPECT_EQ(sp.model_proto()->pieces_size(), + sp.model_proto()->trainer_spec().vocab_size()); } void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer, bool expected_has_denormalizer) { SentencePieceProcessor sp; ASSERT_TRUE(sp.Load(filename.data()).ok()); - const auto &normalizer_spec = sp.model_proto().normalizer_spec(); - const auto &denormalizer_spec = sp.model_proto().denormalizer_spec(); + const auto &normalizer_spec = sp.model_proto()->normalizer_spec(); + const auto &denormalizer_spec = sp.model_proto()->denormalizer_spec(); EXPECT_EQ(!normalizer_spec.precompiled_charsmap().empty(), expected_has_normalizer); EXPECT_EQ(!denormalizer_spec.precompiled_charsmap().empty(), diff --git a/src/testharness.h b/src/testharness.h index e98d4eec..5bd67164 100644 --- a/src/testharness.h +++ b/src/testharness.h @@ -198,7 +198,7 @@ std::vector ValuesIn(const std::vector &v) { } \ ParamType param_; \ void SetParam(const ParamType ¶m) { param_ = param; } \ - const ParamType GetParam() { return param_; } \ + ParamType GetParam() const { return param_; } \ void _Run(); \ static void _RunIt() { \ TCONCAT(base, _Test_p_, name) t; \ diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index 54e62d79..33522354 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -456,7 +456,10 @@ TEST(TrainerInterfaceTest, SerializeTest) { { trainer_spec.set_vocab_size(10); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); - trainer.final_pieces_ = final_pieces; + + trainer.final_pieces_.resize(final_pieces.size()); + copy(final_pieces.begin(), final_pieces.end(), trainer.final_pieces_.begin()); + ModelProto model_proto; EXPECT_FALSE(trainer.Serialize(&model_proto).ok()); } @@ -465,7 +468,10 @@ TEST(TrainerInterfaceTest, SerializeTest) { trainer_spec.set_vocab_size(10); trainer_spec.set_hard_vocab_limit(false); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); - trainer.final_pieces_ = final_pieces; + + trainer.final_pieces_.resize(final_pieces.size()); + copy(final_pieces.begin(), final_pieces.end(), trainer.final_pieces_.begin()); + ModelProto model_proto; EXPECT_TRUE(trainer.Serialize(&model_proto).ok()); EXPECT_EQ(6, model_proto.trainer_spec().vocab_size()); @@ -480,7 +486,10 @@ TEST(TrainerInterfaceTest, SerializeTest) { trainer_spec.set_model_type(TrainerSpec::CHAR); trainer_spec.set_hard_vocab_limit(true); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); - trainer.final_pieces_ = final_pieces; + + trainer.final_pieces_.resize(final_pieces.size()); + copy(final_pieces.begin(), final_pieces.end(), trainer.final_pieces_.begin()); + ModelProto model_proto; EXPECT_TRUE(trainer.Serialize(&model_proto).ok()); EXPECT_EQ(6, model_proto.trainer_spec().vocab_size()); diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index 21cbec34..c6a721ad 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -511,7 +511,8 @@ TEST(UnigramModelTest, SetUnigramModelTest) { AddPiece(&model_proto, "c"); AddPiece(&model_proto, "d"); - const Model model(model_proto); + Model model(std::make_unique(model_proto)); + EXPECT_EQ(model_proto.SerializeAsString(), model.model_proto().SerializeAsString()); } @@ -526,7 +527,8 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) { AddPiece(&model_proto, "BC", 0.5); // 7 AddPiece(&model_proto, "ABC", 1.0); // 8 - Model model(model_proto); + Model model(std::make_unique(model_proto)); + Lattice lattice; lattice.SetSentence("ABC"); @@ -620,7 +622,8 @@ TEST_P(UnigramModelTest, PieceToIdTest) { AddPiece(&model_proto, "c", 0.3); AddPiece(&model_proto, "d", 0.4); - Model model(model_proto); + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); EXPECT_EQ(model_proto.SerializeAsString(), @@ -678,7 +681,9 @@ TEST_P(UnigramModelTest, PieceToIdTest) { TEST_P(UnigramModelTest, PopulateNodesAllUnknownsTest) { ModelProto model_proto = MakeBaseModelProto(); AddPiece(&model_proto, "x"); - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); Lattice lattice; @@ -702,7 +707,9 @@ TEST_P(UnigramModelTest, PopulateNodesTest) { AddPiece(&model_proto, "ab", 0.3); // 5 AddPiece(&model_proto, "bc", 0.4); // 6 - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); Lattice lattice; @@ -737,7 +744,9 @@ TEST_P(UnigramModelTest, PopulateNodesWithUnusedTest) { model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED); - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); Lattice lattice; @@ -762,7 +771,8 @@ TEST_P(UnigramModelTest, ModelNBestTest) { AddPiece(&model_proto, "bc", 5.0); // 7 AddPiece(&model_proto, "abc", 10.0); // 8 - Model model(model_proto); + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); auto nbest = model.NBestEncode("", 10); @@ -801,7 +811,8 @@ TEST_P(UnigramModelTest, EncodeTest) { model_proto.mutable_pieces(12)->set_type( // r ModelProto::SentencePiece::USER_DEFINED); - Model model(model_proto); + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); EncodeResult result; @@ -884,7 +895,9 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { // No unused. { - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(1, result.size()); @@ -893,7 +906,9 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); @@ -904,7 +919,9 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED); - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); @@ -918,7 +935,9 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL); - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); @@ -938,7 +957,9 @@ TEST_P(UnigramModelTest, VerifyOutputsEquivalent) { AddPiece(&model_proto, "b", 1.9); // 8 AddPiece(&model_proto, "c", 2.0); // 9 AddPiece(&model_proto, "d", 1.0); // 10 - Model model(model_proto); + + Model model(std::make_unique(model_proto)); + model.SetEncoderVersion(encoder_version_); // Equivalent outputs. EXPECT_TRUE(model.VerifyOutputsEquivalent("", "")); diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc index ba1d57e3..3f0878e4 100644 --- a/src/unigram_model_trainer_test.cc +++ b/src/unigram_model_trainer_test.cc @@ -96,11 +96,11 @@ TrainerResult RunTrainer(const std::vector& input, int size, SentencePieceProcessor processor; EXPECT_TRUE(processor.Load(model_prefix + ".model").ok()); - const auto& model = processor.model_proto(); + const auto model = processor.model_proto(); // remove , , - for (int i = 3; i < model.pieces_size(); ++i) { - pieces.emplace_back(model.pieces(i).piece()); + for (int i = 3; i < model->pieces_size(); ++i) { + pieces.emplace_back(model->pieces(i).piece()); } } diff --git a/src/util_test.cc b/src/util_test.cc index 407b038b..07d3004b 100644 --- a/src/util_test.cc +++ b/src/util_test.cc @@ -416,7 +416,7 @@ TEST(UtilTest, JoinPathTest) { TEST(UtilTest, ReservoirSamplerTest) { std::vector sampled; - random::ReservoirSampler sampler(&sampled, 100); + random::ReservoirSampler sampler(&sampled, uint64(100)); for (int i = 0; i < 10000; ++i) { sampler.Add(i); } diff --git a/src/word_model_test.cc b/src/word_model_test.cc index aefb1748..8295888f 100644 --- a/src/word_model_test.cc +++ b/src/word_model_test.cc @@ -60,7 +60,8 @@ TEST(WordModelTest, EncodeTest) { AddPiece(&model_proto, WS "c", 0.3); AddPiece(&model_proto, WS "d", 0.4); - const Model model(model_proto); + const Model model(std::make_unique(model_proto)); + EncodeResult result; @@ -82,7 +83,9 @@ TEST(WordModelTest, EncodeTest) { TEST(WordModelTest, NotSupportedTest) { ModelProto model_proto = MakeBaseModelProto(); - const Model model(model_proto); + + const Model model(std::make_unique(model_proto)); + EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10)); EXPECT_EQ(EncodeResult(), model.SampleEncode("test", 0.1)); } diff --git a/src/word_model_trainer_test.cc b/src/word_model_trainer_test.cc index 366810f4..9211a14d 100644 --- a/src/word_model_trainer_test.cc +++ b/src/word_model_trainer_test.cc @@ -60,12 +60,12 @@ std::string RunTrainer(const std::vector &input, int size) { SentencePieceProcessor processor; EXPECT_TRUE(processor.Load(model_prefix + ".model").ok()); - const auto &model = processor.model_proto(); + const auto model = processor.model_proto(); std::vector pieces; // remove , , - for (int i = 3; i < model.pieces_size(); ++i) { - pieces.emplace_back(model.pieces(i).piece()); + for (int i = 3; i < model->pieces_size(); ++i) { + pieces.emplace_back(model->pieces(i).piece()); } return absl::StrJoin(pieces, " ");