Skip to content

Commit

Permalink
fix textcat init functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
svlandeg committed May 14, 2024
1 parent c27679f commit 5992e92
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions spacy/ml/models/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,23 +169,6 @@ def build_text_classifier_v2(
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.attrs["multi_label"] = not exclusive_classes

model.init = init_ensemble_textcat # type: ignore[assignment]
return model


def init_ensemble_textcat(model, X, Y) -> Model:
# When tok2vec is lazily initialized, we need to initialize it before
# the rest of the chain to ensure that we can get its width.
tok2vec = model.get_ref("tok2vec")
tok2vec.initialize(X)

tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
init_chain(model, X, Y)
return model


Expand Down Expand Up @@ -273,8 +256,10 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:

tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
if model.get_ref("key_transform").has_dim("nI"):
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
if model.get_ref("key_transform").has_dim("nO"):
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
Expand Down

0 comments on commit 5992e92

Please sign in to comment.