Skip to content

Commit

Permalink
Auto-detect whether to always label the root node
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakit committed Feb 5, 2021
1 parent 28a62e6 commit 547e548
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
12 changes: 11 additions & 1 deletion src/benepar/decode_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,18 @@ def build_vocab(trees):
label_set = set()
for tree in trees:
for _, _, label in get_labeled_spans(tree):
label_set.add(label)
if label:
label_set.add(label)
label_set = [""] + sorted(label_set)
return {label: i for i, label in enumerate(label_set)}

@staticmethod
def infer_force_root_constituent(trees):
    """Report whether every labeled span in ``trees`` carries a non-empty label.

    Each tree is passed to ``get_labeled_spans`` and the third element of
    every yielded triple is treated as the span's label. A single falsy
    label (e.g. the empty string) anywhere makes the result ``False``;
    otherwise, including when ``trees`` is empty, the result is ``True``.

    Args:
        trees: An iterable of trees accepted by ``get_labeled_spans``.

    Returns:
        bool: ``True`` if no span has an empty label, ``False`` otherwise.
    """
    # all() short-circuits on the first empty label, matching an early return.
    return all(
        span_label
        for one_tree in trees
        for _start, _end, span_label in get_labeled_spans(one_tree)
    )

def chart_from_tree(self, tree):
spans = get_labeled_spans(tree)
Expand Down Expand Up @@ -168,6 +177,7 @@ def charts_from_pytorch_scores_batched(self, scores, lengths):
def compressed_output_from_chart(self, chart):
chart_with_filled_diagonal = chart.copy()
np.fill_diagonal(chart_with_filled_diagonal, 1)
chart_with_filled_diagonal[0, -1] = 1
starts, inclusive_ends = np.where(chart_with_filled_diagonal)
preorder_sort = np.lexsort((-inclusive_ends, starts))
starts = starts[preorder_sort]
Expand Down
12 changes: 10 additions & 2 deletions src/benepar/parse_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ def __init__(
self.f_tag = None
self.tag_from_index = None

self.decoder = decode_chart.ChartDecoder(label_vocab=self.label_vocab)
self.criterion = decode_chart.SpanClassificationMarginLoss(reduction="sum")
self.decoder = decode_chart.ChartDecoder(
label_vocab=self.label_vocab,
force_root_constituent=hparams.force_root_constituent,
)
self.criterion = decode_chart.SpanClassificationMarginLoss(
reduction="sum", force_root_constituent=hparams.force_root_constituent
)

self.parallelized_devices = None

Expand Down Expand Up @@ -153,6 +158,9 @@ def from_trained(cls, model_path, config=None, state_dict=None):
config = config.copy()
hparams = config["hparams"]

if "force_root_constituent" not in hparams:
hparams["force_root_constituent"] = True

config["hparams"] = nkutil.HParams(**hparams)
parser = cls(**config)
parser.load_state_dict(state_dict)
Expand Down
11 changes: 11 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def make_hparams():
relu_dropout=0.1,
residual_dropout=0.2,
# Output heads and losses
force_root_constituent="auto",
predict_tags=False,
d_label_hidden=256,
d_tag_hidden=256,
Expand Down Expand Up @@ -119,6 +120,16 @@ def run_train(args, hparams):
tag_vocab = ["UNK"] + sorted(tag_vocab)
tag_vocab = {label: i for i, label in enumerate(tag_vocab)}

if hparams.force_root_constituent.lower() in ("true", "yes", "1"):
hparams.force_root_constituent = True
elif hparams.force_root_constituent.lower() in ("false", "no", "0"):
hparams.force_root_constituent = False
elif hparams.force_root_constituent.lower() == "auto":
hparams.force_root_constituent = (
decode_chart.ChartDecoder.infer_force_root_constituent(train_treebank.trees)
)
print("Set hparams.force_root_constituent to", hparams.force_root_constituent)

print("Initializing model...")
parser = parse_chart.ChartParser(
tag_vocab=tag_vocab,
Expand Down

0 comments on commit 547e548

Please sign in to comment.