Skip to content

Commit

Permalink
Merge pull request #3 from huggingface/conversion-fixes
Browse files Browse the repository at this point in the history
Conversion fixes: add 27b-896
  • Loading branch information
pcuenca authored Dec 3, 2024
2 parents 26a59d3 + d919ae2 commit 96180e5
Showing 1 changed file with 2 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
from transformers.utils import logging


device = "cuda" # "cpu"
device = "cpu"

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# TODO add sequence length variations here

PALIGEMMA2_VARIANTS = ["2b-224", "2b-448", "2b-896", "9b-224", "9b-448", "9b-896", "27b-224", "27b-448"]
PALIGEMMA2_VARIANTS = ["2b-224", "2b-448", "2b-896", "9b-224", "9b-448", "9b-896", "27b-224", "27b-448", "27b-896"]
VARIANT_CONFIGS = {
"2b": {
"num_positions": 256,
Expand Down Expand Up @@ -310,7 +310,6 @@ def convert_paligemma2_checkpoint(
Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed.
"""
config = get_paligemma2_config(variant, precision=precision)
device = "cuda" if torch.cuda.is_available() else "cpu"
if do_convert_weights:
tokenizer_id = "google/paligemma-3b-pt-224" # same tokenizer as paligemma 1
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
Expand Down

0 comments on commit 96180e5

Please sign in to comment.