diff --git a/sudachipy/lattice.pyx b/sudachipy/lattice.pyx index 86f68d7..fdb120e 100644 --- a/sudachipy/lattice.pyx +++ b/sudachipy/lattice.pyx @@ -57,15 +57,15 @@ cdef class Lattice: return self.end_lists[end] def get_nodes(self, begin: int, end: int) -> List[LatticeNode]: - return [node for node in self.end_lists[end] if node.begin == begin] + return [node for node in self.end_lists[end] if node.get_begin() == begin] - def get_minumum_node(self, begin: int, end: int) -> Optional[LatticeNode]: + def get_minimum_node(self, begin: int, end: int) -> Optional[LatticeNode]: nodes = self.get_nodes(begin, end) if not nodes: return None min_arg = nodes[0] for node in nodes[1:]: - if node.cost < min_arg.cost: + if node.get_path_cost() < min_arg.get_path_cost(): min_arg = node return min_arg diff --git a/sudachipy/plugin/path_rewrite/path_rewrite_plugin.py b/sudachipy/plugin/path_rewrite/path_rewrite_plugin.py index 3265c4e..7ccd744 100644 --- a/sudachipy/plugin/path_rewrite/path_rewrite_plugin.py +++ b/sudachipy/plugin/path_rewrite/path_rewrite_plugin.py @@ -62,6 +62,12 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice): raise IndexError("begin >= end") b = path[begin].get_begin() e = path[end - 1].get_end() + + n = lattice.get_minimum_node(b, e) + if n is not None: + path[begin:end] = [n] + return n + surface = "" length = 0 for i in range(begin, end): @@ -76,6 +82,7 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice): node = lattice.create_node() node.set_range(b, e) node.set_word_info(wi) + node.set_oov() path[begin:end] = [node] return node diff --git a/tests/plugin/test_join_katakana_oov_plugin.py b/tests/plugin/test_join_katakana_oov_plugin.py index d1fc70b..d025b4a 100644 --- a/tests/plugin/test_join_katakana_oov_plugin.py +++ b/tests/plugin/test_join_katakana_oov_plugin.py @@ -53,6 +53,8 @@ def test_pos(self): path = self.get_path('アイアイウ') self.assertEqual(1, len(path)) self.assertFalse(path[0].is_oov()) + self.assertEqual(['名詞', '固有名詞', '地名', '一般', '*', '*'], + self.dict_.grammar.get_part_of_speech_string(path[0].get_word_info().pos_id)) def test_starts_with_middle(self): self.plugin._min_length = 3