Skip to content

Commit

Permalink
#5386: [Falcon7b] Remove hf reference files and import from transform…
Browse files Browse the repository at this point in the history
…ers instead

Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
  • Loading branch information
skhorasganiTT committed Aug 22, 2024
1 parent 51e39c6 commit 9b2885c
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 1,574 deletions.
8 changes: 3 additions & 5 deletions models/demos/falcon7b_common/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import torch.nn.functional as F
import ttnn
from loguru import logger
from models.demos.falcon7b_common.reference.hf_modeling_falcon import FalconConfig
from models.demos.falcon7b_common.tt.falcon_causallm import TtFalconCausalLM
from models.demos.falcon7b_common.tt.model_config import get_model_config, model_config_entries
from models.demos.falcon7b_common.tt.model_config import get_model_config
from models.demos.falcon7b_common.tests.test_utils import (
initialize_kv_cache,
load_hf_model,
Expand Down Expand Up @@ -165,8 +164,6 @@ def run_falcon_demo_kv(
if perf_mode:
logger.info("Running in performance measurement mode (invalid outputs)!")

configuration = FalconConfig(**model_config_entries)

profiler.start(f"loading_inputs")
if num_devices > 1:
assert len(user_input) == global_batch, "Number of users must be equal to batch size * number of devices!"
Expand Down Expand Up @@ -194,7 +191,8 @@ def run_falcon_demo_kv(
# State dict is needed for embeddings
logger.info("Loading huggingface weights...")
profiler.start(f"loading_weights")
_, state_dict = load_hf_model(model_location_generator, model_version)
hugging_face_reference_model, state_dict = load_hf_model(model_location_generator, model_version)
configuration = hugging_face_reference_model.config
logger.info("Loading weights finished!")
profiler.end(f"loading_weights")

Expand Down
2 changes: 1 addition & 1 deletion models/demos/falcon7b_common/reference/cpu_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from transformers import AutoTokenizer

from models.demos.falcon7b_common.reference.hf_modeling_falcon import FalconForCausalLM
from transformers import FalconForCausalLM
import time

falcon1b = "tiiuae/falcon-rw-1b"
Expand Down
156 changes: 0 additions & 156 deletions models/demos/falcon7b_common/reference/hf_configuration_falcon.py

This file was deleted.

Loading

0 comments on commit 9b2885c

Please sign in to comment.