Skip to content
This repository has been archived by the owner on Mar 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #163 from WorksApplications/feature/fix_join_katak…
Browse files Browse the repository at this point in the history
…ana_oov

Fix issue #162
  • Loading branch information
kazuma-t authored Sep 10, 2021
2 parents 9eae063 + 8c4a6ef commit 0734f9b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sudachipy/lattice.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ cdef class Lattice:
return self.end_lists[end]

def get_nodes(self, begin: int, end: int) -> List[LatticeNode]:
return [node for node in self.end_lists[end] if node.begin == begin]
return [node for node in self.end_lists[end] if node.get_begin() == begin]

def get_minumum_node(self, begin: int, end: int) -> Optional[LatticeNode]:
def get_minimum_node(self, begin: int, end: int) -> Optional[LatticeNode]:
nodes = self.get_nodes(begin, end)
if not nodes:
return None
min_arg = nodes[0]
for node in nodes[1:]:
if node.cost < min_arg.cost:
if node.get_path_cost() < min_arg.get_path_cost():
min_arg = node
return min_arg

Expand Down
7 changes: 7 additions & 0 deletions sudachipy/plugin/path_rewrite/path_rewrite_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice):
raise IndexError("begin >= end")
b = path[begin].get_begin()
e = path[end - 1].get_end()

n = lattice.get_minimum_node(b, e)
if n is not None:
path[begin:end] = [n]
return n

surface = ""
length = 0
for i in range(begin, end):
Expand All @@ -76,6 +82,7 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice):
node = lattice.create_node()
node.set_range(b, e)
node.set_word_info(wi)
node.set_oov()

path[begin:end] = [node]
return node
Expand Down
2 changes: 2 additions & 0 deletions tests/plugin/test_join_katakana_oov_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def test_pos(self):
path = self.get_path('アイアイウ')
self.assertEqual(1, len(path))
self.assertFalse(path[0].is_oov())
self.assertEqual(['名詞', '固有名詞', '地名', '一般', '*', '*'],
self.dict_.grammar.get_part_of_speech_string(path[0].get_word_info().pos_id))

def test_starts_with_middle(self):
self.plugin._min_length = 3
Expand Down

0 comments on commit 0734f9b

Please sign in to comment.