Skip to content

Commit

Permalink
Merge pull request #60 from jalammar/hotfix-jan-2022-1
Browse files Browse the repository at this point in the history
Hotfix jan 2022 1
  • Loading branch information
jalammar authored Jan 4, 2022
2 parents 9aa8639 + 5e10240 commit 431fc5c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def read(*names, **kwargs):

setup(
name='ecco',
version='0.1.0',
version='0.1.1',
license='BSD-3-Clause',
description='Visualization tools for NLP machine learning models.',
long_description='%s\n%s' % (
Expand Down
2 changes: 1 addition & 1 deletion src/ecco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""


__version__ = '0.0.15'
__version__ = '0.1.1'
from ecco.lm import LM
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM
from typing import Any, Dict, Optional, List
Expand Down
5 changes: 4 additions & 1 deletion src/ecco/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def generate(self, input_str: str,
# Get decoder input ids
if self.model_type == 'enc-dec': # FIXME: only done because causal LMs like GPT-2 have the _prepare_decoder_input_ids_for_generation method but do not use it
assert len(input_ids.size()) == 2 # will break otherwise
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None)
if transformers.__version__ >= '4.13': # ALSO FIXME: awful hack. But seems to work?
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids.shape[0], None, None)
else:
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None)
else:
decoder_input_ids = None

Expand Down

0 comments on commit 431fc5c

Please sign in to comment.