From 96aabaef967e1b84c3a47dd6c7d54f72e13f5093 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sat, 23 Dec 2023 09:28:40 +0000 Subject: [PATCH] add set_min_log_level function to python to change the loglevel from python wrapper. --- python/setup.py | 10 +- python/src/sentencepiece/__init__.py | 4 + python/src/sentencepiece/sentencepiece.i | 1 + .../src/sentencepiece/sentencepiece_wrap.cxx | 31 +++ python/test/sentencepiece_test.py | 192 ++++++++++++------ src/common.h | 8 - src/sentencepiece_processor.h | 10 +- src/util.cc | 2 + 8 files changed, 177 insertions(+), 81 deletions(-) diff --git a/python/setup.py b/python/setup.py index 54112313..d600321c 100755 --- a/python/setup.py +++ b/python/setup.py @@ -14,14 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License.! -from setuptools import setup, Extension -from setuptools.command.build_ext import build_ext as _build_ext -from setuptools.command.build_py import build_py as _build_py import codecs +import os import string import subprocess import sys -import os +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext +from setuptools.command.build_py import build_py as _build_py sys.path.append(os.path.join('.', 'test')) @@ -94,6 +94,8 @@ def build_extension(self, ext): else: cflags.append('-Wl,-strip-all') libs.append('-Wl,-strip-all') + if sys.platform == 'linux': + libs.append('-Wl,-Bsymbolic') print('## cflags={}'.format(' '.join(cflags))) print('## libs={}'.format(' '.join(libs))) ext.extra_compile_args = cflags diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 6040e7bb..2bfd645e 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -904,6 +904,9 @@ def Load(self, model_file=None, model_proto=None): def SetRandomGeneratorSeed(seed): return _sentencepiece.SetRandomGeneratorSeed(seed) + +def SetMinLogLevel(v): + return _sentencepiece.SetMinLogLevel(v) class SentencePieceTrainer(object): thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") @@ -1039,6 +1042,7 @@ def _batched_func(self, arg): _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) set_random_generator_seed = SetRandomGeneratorSeed +set_min_log_level = SetMinLogLevel from ._version import __version__ diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index bef8298e..5b28abc5 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -1771,6 +1771,7 @@ for m in [ _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) set_random_generator_seed = SetRandomGeneratorSeed +set_min_log_level = SetMinLogLevel from ._version import __version__ diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 8e831d67..753b2e2c 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -8429,6 +8429,36 @@ SWIGINTERN PyObject *_wrap_SetRandomGeneratorSeed(PyObject *self, PyObject *args } +SWIGINTERN PyObject *_wrap_SetMinLogLevel(PyObject *self, PyObject *args) { + PyObject *resultobj = 0; + int arg1 ; + int val1 ; + int ecode1 = 0 ; + PyObject *swig_obj[1] ; + + if (!args) SWIG_fail; + swig_obj[0] = args; + ecode1 = SWIG_AsVal_int(swig_obj[0], &val1); + if (!SWIG_IsOK(ecode1)) { + 
SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetMinLogLevel" "', argument " "1"" of type '" "int""'"); + } + arg1 = static_cast< int >(val1); + { + try { + sentencepiece::SetMinLogLevel(arg1); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + resultobj = SWIG_Py_Void(); + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *self, PyObject *args) { PyObject *resultobj = 0; absl::string_view arg1 ; @@ -8800,6 +8830,7 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL}, + { "SetMinLogLevel", _wrap_SetMinLogLevel, METH_O, NULL}, { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 2b9ad282..adbc607e 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -15,14 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License.! +from collections import defaultdict import io -import sentencepiece as spm -import unittest -import sys import os import pickle - -from collections import defaultdict +import sys +import unittest +import sentencepiece as spm print('VERSION={}'.format(spm.__version__)) @@ -39,7 +38,8 @@ def setUp(self): self.jasp_ = spm.SentencePieceProcessor() self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model'))) self.assertTrue( - self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))) + self.jasp_.Load(os.path.join('test', 'test_ja_model.model')) + ) with open(os.path.join('test', 'test_model.model'), 'rb') as f: self.assertTrue(self.sp_.LoadFromSerializedProto(f.read())) with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f: @@ -83,14 +83,18 @@ def test_roundtrip(self): for n in range(100): self.assertEqual( text, - self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5))) + self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5)), + ) self.assertEqual( text, - self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5))) + self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5)), + ) self.assertEqual( - text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5))) + text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)) + ) self.assertEqual( - text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5))) + text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)) + ) ids2 = self.sp_.encode_as_ids(text) pieces3 = self.sp_.encode_as_pieces(text) @@ -104,21 +108,28 @@ def test_roundtrip(self): self.assertEqual( text, self.sp_.decode_pieces( - self.sp_.sample_encode_as_pieces(text, 64, 0.5))) + self.sp_.sample_encode_as_pieces(text, 64, 0.5) + ), + ) self.assertEqual( text, self.sp_.decode_pieces( - self.sp_.sample_encode_as_pieces(text, -1, 0.5))) + self.sp_.sample_encode_as_pieces(text, -1, 0.5) + ), + ) 
self.assertEqual( text, - self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5))) + self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)), + ) self.assertEqual( text, - self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5))) + self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)), + ) self.assertEqual( self.sp_.calculate_entropy(text, 0.1), - self.sp_.CalculateEntropy(text, 0.1)) + self.sp_.CalculateEntropy(text, 0.1), + ) def test_ja_load(self): self.assertEqual(8000, self.jasp_.GetPieceSize()) @@ -155,11 +166,15 @@ def test_ja_roundtrip(self): self.assertEqual( text, self.jasp_.DecodePieces( - self.jasp_.SampleEncodeAsPieces(text, 64, 0.5))) + self.jasp_.SampleEncodeAsPieces(text, 64, 0.5) + ), + ) self.assertEqual( text, self.jasp_.DecodePieces( - self.jasp_.SampleEncodeAsPieces(text, -1, 0.5))) + self.jasp_.SampleEncodeAsPieces(text, -1, 0.5) + ), + ) ids2 = self.jasp_.encode_as_ids(text) pieces3 = self.jasp_.encode_as_pieces(text) @@ -173,20 +188,27 @@ def test_ja_roundtrip(self): self.assertEqual( text, self.jasp_.decode_pieces( - self.jasp_.sample_encode_as_pieces(text, 64, 0.5))) + self.jasp_.sample_encode_as_pieces(text, 64, 0.5) + ), + ) self.assertEqual( text, self.jasp_.decode_pieces( - self.jasp_.sample_encode_as_pieces(text, -1, 0.5))) + self.jasp_.sample_encode_as_pieces(text, -1, 0.5) + ), + ) self.assertEqual( self.jasp_.calculate_entropy(text, 0.1), - self.jasp_.CalculateEntropy(text, 0.1)) + self.jasp_.CalculateEntropy(text, 0.1), + ) def test_train(self): - spm.SentencePieceTrainer.Train('--input=' + - os.path.join(data_dir, 'botchan.txt') + - ' --model_prefix=m --vocab_size=1000') + spm.SentencePieceTrainer.Train( + '--input=' + + os.path.join(data_dir, 'botchan.txt') + + ' --model_prefix=m --vocab_size=1000' + ) sp = spm.SentencePieceProcessor() sp.Load('m.model') with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file: @@ -195,9 +217,11 @@ def test_train(self): sp.DecodeIds(sp.EncodeAsIds(line)) def test_train_iterator(self): - spm.SentencePieceTrainer.Train('--input=' + - os.path.join(data_dir, 'botchan.txt') + - ' --model_prefix=m --vocab_size=1000') + spm.SentencePieceTrainer.Train( + '--input=' + + os.path.join(data_dir, 'botchan.txt') + + ' --model_prefix=m --vocab_size=1000' + ) # Load as 'rb' for Python3.5/2.7. 
os1 = io.BytesIO() os2 = io.BytesIO() @@ -207,32 +231,38 @@ def test_train_iterator(self): input=os.path.join(data_dir, 'botchan.txt'), model_prefix='m', vocab_size=1000, - logstream=open(os.devnull, 'w')) + logstream=open(os.devnull, 'w'), + ) with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is1: spm.SentencePieceTrainer.train( sentence_iterator=is1, model_prefix='m', vocab_size=1000, - logstream=open(os.devnull, 'w')) + logstream=open(os.devnull, 'w'), + ) spm.SentencePieceTrainer.train( input=os.path.join(data_dir, 'botchan.txt'), model_writer=os1, vocab_size=1000, - logstream=open(os.devnull, 'w')) + logstream=open(os.devnull, 'w'), + ) with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is2: spm.SentencePieceTrainer.train( sentence_iterator=is2, model_writer=os2, vocab_size=1000, - logstream=open(os.devnull, 'w')) + logstream=open(os.devnull, 'w'), + ) sp1 = spm.SentencePieceProcessor(model_proto=os1.getvalue()) sp2 = spm.SentencePieceProcessor(model_proto=os2.getvalue()) - self.assertEqual([sp1.id_to_piece(i) for i in range(sp1.get_piece_size())], - [sp2.id_to_piece(i) for i in range(sp2.get_piece_size())]) + self.assertEqual( + [sp1.id_to_piece(i) for i in range(sp1.get_piece_size())], + [sp2.id_to_piece(i) for i in range(sp2.get_piece_size())], + ) def test_train_kwargs(self): # suppress logging (redirect to /dev/null) @@ -241,7 +271,8 @@ def test_train_kwargs(self): model_prefix='m', vocab_size=1002, user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'], - logstream=open(os.devnull, 'w')) + logstream=open(os.devnull, 'w'), + ) sp = spm.SentencePieceProcessor() sp.Load('m.model') with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file: @@ -268,7 +299,8 @@ def test_serialized_proto(self): y1 = self.sp_.encode(text, out_type='serialized_proto') y2 = self.sp_.encode( - text, enable_sampling=True, out_type='serialized_proto') + text, enable_sampling=True, out_type='serialized_proto' + ) y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10) y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto') y5 = self.sp_.decode([20, 30], out_type='serialized_proto') @@ -372,7 +404,7 @@ def test_immutable_proto(self): self.assertEqual([x.piece for x in s1.pieces], v2) self.assertEqual(text, s1.text) - surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces] + surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces] surfaces2 = [x.surface for x in s1.pieces] self.assertEqual(surfaces1, surfaces2) @@ -393,15 +425,18 @@ def test_immutable_proto(self): for i in range(len(s3.nbests)): self.assertEqual(text, s3.nbests[i].text) self.assertEqual( - self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text) + self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text + ) # slice self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces))) self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests))) # Japanese offset - s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123') - surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces] + s1 = self.jasp_.EncodeAsImmutableProto( + '吾輩は猫である。Hello world. 
ABC 123' + ) + surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces] surfaces2 = [x.surface for x in s1.pieces] self.assertEqual(surfaces1, surfaces2) @@ -415,7 +450,8 @@ def test_immutable_proto(self): def test_new_api(self): sp = spm.SentencePieceProcessor( - model_file=os.path.join('test', 'test_model.model')) + model_file=os.path.join('test', 'test_model.model') + ) text = 'hello world' text2 = 'Tokyo' ids = self.sp_.EncodeAsIds(text) @@ -512,7 +548,8 @@ def test_new_api_init(self): model_file=os.path.join('test', 'test_model.model'), add_bos=True, add_eos=True, - out_type=str) + out_type=str, + ) text = 'hello world' pieces = [''] + self.sp_.EncodeAsPieces(text) + [''] self.assertEqual(pieces, sp.encode(text)) @@ -540,13 +577,17 @@ def test_sampling(self): ++ids2[out] self.assertEqual(len(ids2), 1) - out = sp.encode(['hello world', 'this is a test'], - out_type=out_type, - enable_sampling=True) + out = sp.encode( + ['hello world', 'this is a test'], + out_type=out_type, + enable_sampling=True, + ) self.assertEqual(len(out), 2) - out = sp.encode(['hello world', 'this is a test'], - out_type=out_type, - enable_sampling=False) + out = sp.encode( + ['hello world', 'this is a test'], + out_type=out_type, + enable_sampling=False, + ) self.assertEqual(len(out), 2) def test_nbest(self): @@ -556,8 +597,9 @@ def test_nbest(self): for out_type in [str, int, 'serialized_proto', 'immutable_proto']: results = sp.nbest_encode(text, nbest_size=10, out_type=out_type) - self.assertEqual(results, - sp.NBestEncode(text, nbest_size=10, out_type=out_type)) + self.assertEqual( + results, sp.NBestEncode(text, nbest_size=10, out_type=out_type) + ) if out_type in [str, int]: for n in results: @@ -570,7 +612,8 @@ def test_nbest(self): results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type) self.assertEqual( results, - sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type)) + sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type), + ) self.assertEqual(len(results), 2) if out_type in [str, int]: @@ -591,16 +634,20 @@ def test_nbest(self): self.assertEqual( sp.nbest_encode(text, nbest_size=10, out_type=str), - sp.nbest_encode_as_pieces(text, nbest_size=10)) + sp.nbest_encode_as_pieces(text, nbest_size=10), + ) self.assertEqual( sp.nbest_encode(text, nbest_size=10, out_type=int), - sp.nbest_encode_as_ids(text, nbest_size=10)) + sp.nbest_encode_as_ids(text, nbest_size=10), + ) self.assertEqual( sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'), - sp.nbest_encode_as_serialized_proto(text, nbest_size=10)) + sp.nbest_encode_as_serialized_proto(text, nbest_size=10), + ) self.assertEqual( sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'), - sp.nbest_encode_as_immutable_proto(text, nbest_size=10)) + sp.nbest_encode_as_immutable_proto(text, nbest_size=10), + ) def test_sample_and_score(self): sp = self.sp_ @@ -608,22 +655,22 @@ def test_sample_and_score(self): text2 = 'I have a pen.' 
for out_type in [str, int, 'serialized_proto', 'immutable_proto']: results = sp.sample_encode_and_score( - text, wor=True, num_samples=10, out_type=out_type) + text, wor=True, num_samples=10, out_type=out_type + ) results = sp.SampleEncodeAndScore( - text, wor=False, num_samples=10, out_type=out_type) + text, wor=False, num_samples=10, out_type=out_type + ) if out_type in [str, int]: for n in results: self.assertEqual(sp.decode(n[0]), text) - results = sp.sample_encode_and_score([text, text2], - wor=True, - num_samples=10, - out_type=out_type) - results = sp.SampleEncodeAndScore([text, text2], - wor=True, - num_samples=10, - out_type=out_type) + results = sp.sample_encode_and_score( + [text, text2], wor=True, num_samples=10, out_type=out_type + ) + results = sp.SampleEncodeAndScore( + [text, text2], wor=True, num_samples=10, out_type=out_type + ) if out_type in [str, int]: for n in results[0]: @@ -639,8 +686,14 @@ def test_sample_and_score(self): def test_valid_range(self): size = self.sp_.piece_size() funcs = [ - 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', 'IsByte', - 'DecodeIds', 'DecodeIdsAsSerializedProto' + 'IdToPiece', + 'GetScore', + 'IsUnknown', + 'IsControl', + 'IsUnused', + 'IsByte', + 'DecodeIds', + 'DecodeIdsAsSerializedProto', ] for m in funcs: getattr(self.sp_, m)([10, 20, 30]) @@ -654,7 +707,8 @@ def test_valid_range(self): def test_batch(self): sp = spm.SentencePieceProcessor( - model_file=os.path.join('test', 'test_model.model')) + model_file=os.path.join('test', 'test_model.model') + ) with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file: texts = file.readlines() @@ -700,6 +754,12 @@ def test_pickle(self): self.assertEqual(id1, id2) + def test_global_params(self): + spm.SetRandomGeneratorSeed(0) + spm.SetMinLogLevel(2) + spm.set_random_generator_seed(1) + spm.set_min_log_level(3) + def suite(): suite = unittest.TestSuite() diff --git a/src/common.h b/src/common.h index 119a1c29..a7b9871a 100644 --- a/src/common.h +++ b/src/common.h @@ -74,17 +74,9 @@ char (&ArraySizeHelper(const T (&array)[N]))[N]; #endif namespace sentencepiece { -#ifdef OS_WIN -namespace win32 { -std::wstring Utf8ToWide(const absl::string_view input); -} // namespace win32 -#endif - -#ifdef IS_BIG_ENDIAN namespace util { inline uint32 Swap32(uint32 x) { return __builtin_bswap32(x); } } // namespace util -#endif namespace error { diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 14b1e8cd..7a155175 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -431,19 +431,19 @@ class SentencePieceProcessor { #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, &output); \ - SPP_SWIG_CHECK_AND_THROW; \ + SPP_SWIG_CHECK_AND_THROW; \ return output; #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - SPP_SWIG_CHECK_AND_THROW; \ + SPP_SWIG_CHECK_AND_THROW; \ return output.SerializeAsString(); #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - SPP_SWIG_CHECK_AND_THROW; \ + SPP_SWIG_CHECK_AND_THROW; \ return output; ////////////////////////////////////////////////////////////// @@ -709,6 +709,10 @@ class SentencePieceProcessor { // std::random_device. void SetRandomGeneratorSeed(unsigned int seed); +// Set the global log level. The default loglevel is 0. 
+// The log is emitted only when output_log_level >= min_log_level.
+void SetMinLogLevel(int v);
+
 // IO related functions to absorb model formats.
 namespace io {
 // Loads `model_proto` from `filename`.
diff --git a/src/util.cc b/src/util.cc
index 70c45489..61c4e5df 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -43,6 +43,8 @@ int GetMinLogLevel() { return g_minloglevel.load(); }
 void SetMinLogLevel(int v) { g_minloglevel.store(v); }
 }  // namespace logging
 
+void SetMinLogLevel(int v) { logging::SetMinLogLevel(v); }
+
 namespace string_util {
 
 // mblen stores the number of bytes consumed after decoding.
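
Usage sketch for the new switch: the patch exposes it as SetMinLogLevel and the snake_case alias set_min_log_level, alongside the existing set_random_generator_seed. The severity numbering in the comments (0=INFO, 1=WARNING, 2=ERROR, 3=FATAL) follows the library's glog-style convention and is an assumption not spelled out in this patch; 'botchan.txt' is a stand-in input file.

    import sentencepiece as spm

    # Assumed glog-style levels: 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL.
    # Raising the minimum level should suppress the trainer's INFO/WARNING output.
    spm.set_min_log_level(2)             # alias of spm.SetMinLogLevel(2)
    spm.set_random_generator_seed(42)    # make subword sampling reproducible

    spm.SentencePieceTrainer.train(
        input='botchan.txt', model_prefix='m', vocab_size=1000
    )

Because the setting is a global (g_minloglevel in src/util.cc), it affects all SentencePiece logging in the process, both training and inference.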