Skip to content

Commit

Permalink
Fix serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Jan 2, 2025
1 parent 594d858 commit 8fc8a58
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,8 @@ def _get_state_dict(self) -> dict[str, Any]:
"encoding_strategy": self.encoding_strategy,
"zero_tag_value": self.zero_tag_value,
"allow_unk_tag": self.allow_unk_tag,
"max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities,
"max_encoded_sentence_length": self._max_encoded_sentence_length,
}
return model_state

Expand All @@ -772,6 +774,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
encoding_strategy=state["encoding_strategy"],
zero_tag_value=state["zero_tag_value"],
allow_unk_tag=state["allow_unk_tag"],
max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25),
max_encoded_sentence_length=state.get("max_encoded_sentence_length", 50),
**kwargs,
)

Expand Down

0 comments on commit 8fc8a58

Please sign in to comment.