Commit

Merge remote-tracking branch 'upstream/release'
dchourasia authored and heyselbi committed May 16, 2024
2 parents 0882cd9 + 397b736 commit c9d0852
Showing 14 changed files with 184 additions and 77 deletions.
6 changes: 2 additions & 4 deletions Makefile
@@ -85,17 +85,15 @@ check-test-image:
integration-tests: check-test-image ## Run integration tests
mkdir -p /tmp/transformers_cache
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-e TRANSFORMERS_CACHE=/transformers_cache \
-e HF_HUB_CACHE=/transformers_cache \
-w /usr/src/integration_tests \
$(TEST_IMAGE_NAME) make test

.PHONY: python-tests
python-tests: check-test-image ## Run Python tests
mkdir -p /tmp/transformers_cache
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-e TRANSFORMERS_CACHE=/transformers_cache \
-e HF_HUB_CACHE=/transformers_cache \
$(TEST_IMAGE_NAME) pytest -sv --ignore=server/tests/test_utils.py server/tests

.PHONY: clean
4 changes: 2 additions & 2 deletions README.md
@@ -128,15 +128,15 @@ cd deployment

### Model configuration

When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables should be set to the path of a mounted directory containing a local HF hub model cache; see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch for an example.
When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `HF_HUB_CACHE` environment variable should be set to the path of a mounted directory containing a local HF hub model cache; see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch for an example.
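
As a quick illustration (the cache path and model name here are examples, not part of the repository), the mounted cache is expected to follow the standard HF hub layout of `models--<org>--<name>/snapshots/<revision>`, which can be checked with a few lines of Python:

```python
import glob
import os

# Hypothetical check: confirm that the mounted cache directory pointed to by
# HF_HUB_CACHE already contains a snapshot of the model in hub layout.
cache = os.environ.get("HF_HUB_CACHE", "/transformers_cache")
snapshots = glob.glob(os.path.join(cache, "models--google--flan-ul2", "snapshots", "*"))
print(snapshots[0] if snapshots else "model not found in local HF hub cache")
```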

### Downloading model weights

TGIS will not download model data at runtime. To populate the local HF hub cache with models so that it can be used per above, the image can be run with the following command:
```shell
text-generation-server download-weights model_name
```
where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables, and that it has write access to this mounted filesystem.
where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and the `HF_HUB_CACHE` environment variable, and that it has write access to this mounted filesystem.

This will attempt to download weights in `.safetensors` format; if those aren't available in the HF hub, it will download PyTorch `.bin` weights and then convert them to `.safetensors`.
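
As a rough sketch only (this is not the `download-weights` implementation, and it skips the `.bin`-to-`.safetensors` conversion step), the same cache can also be pre-populated directly with `huggingface_hub`, assuming that package is installed:

```python
import os
from huggingface_hub import snapshot_download

# Illustrative only: download safetensors weights plus tokenizer/config files
# into the cache directory that TGIS will later read via HF_HUB_CACHE.
snapshot_download(
    "bigscience/mt0-small",
    cache_dir=os.environ.get("HF_HUB_CACHE", "/transformers_cache"),
    allow_patterns=["*.safetensors", "*.json", "*.model"],
)
```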

47 changes: 38 additions & 9 deletions integration_tests/text_generation_tests/test_server.py
@@ -29,7 +29,7 @@ def start_server(
master_port: int,
timeout=30,
model_path=None,
include_cache_env_vars=True,
env=None,
output_special_tokens=False,
):
# Download weights to the cache first
@@ -66,13 +66,12 @@
if output_special_tokens:
args.append("--output-special-tokens")

env = os.environ.copy()
if env is None:
env = os.environ.copy()

env["RUST_BACKTRACE"] = "full"
env["ESTIMATE_MEMORY"] = "manual"
env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
if not include_cache_env_vars:
env.pop("TRANSFORMERS_CACHE", None)
env.pop("HUGGING_FACE_HUB_CACHE", None)

# Start the process
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
@@ -455,17 +454,21 @@ async def test_time_limit_stopping(server_fixture):

# Test loading when an explicit local path is provided
def test_explicit_path():
# Test with and without providing TRANSFORMERS_CACHE env var
path = glob.glob(f'{os.environ["TRANSFORMERS_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
for include_env_vars in [False, True]:
path = glob.glob(f'{os.environ["HF_HUB_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]

# Test with and without providing HF_HUB_CACHE
env_with = os.environ.copy()
env_without = os.environ.copy()
env_without.pop("HF_HUB_CACHE", None)
for env in [env_with, env_without]:
p = start_server(
"bigscience/mt0-small",
".bin,.json,.model",
1,
3000,
29502,
model_path=path,
include_cache_env_vars=include_env_vars,
env=env,
)
try:
async def test_model_info() -> pb2.ModelInfoResponse:
@@ -481,6 +484,32 @@ async def test_model_info() -> pb2.ModelInfoResponse:

assert p.wait(8.0) == 0

# Test loading with only TRANSFORMERS_CACHE set
def test_transformers_cache():
    env = os.environ.copy()
    env["TRANSFORMERS_CACHE"] = env.pop("HF_HUB_CACHE")
    p = start_server(
        "bigscience/mt0-small",
        ".bin,.json,.model",
        1,
        3000,
        29502,
        env=env,
    )
    try:
        async def test_model_info() -> pb2.ModelInfoResponse:
            async with grpc.aio.insecure_channel('localhost:8033') as channel:
                return await gpb2.GenerationServiceStub(channel).ModelInfo(pb2.ModelInfoRequest(model_id="unused"))

        result = asyncio.get_event_loop().run_until_complete(test_model_info())
        assert result.max_sequence_length == 200
        assert result.max_new_tokens == 169
        assert result.model_kind == pb2.ModelInfoResponse.ModelKind.ENCODER_DECODER
    finally:
        p.terminate()

    assert p.wait(8.0) == 0


# To avoid errors related to event loop shutdown timing
@pytest.fixture(scope="session")
94 changes: 63 additions & 31 deletions launcher/src/main.rs
@@ -139,7 +139,54 @@ fn main() -> ExitCode {
// Determine number of shards based on command line arg and env vars
let num_shard = find_num_shards(args.num_shard);

let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
// Determine the model cache path and resolve from possible env vars:
// - HF_HUB_CACHE
// - TRANSFORMERS_CACHE (deprecated)
// - HUGGINGFACE_HUB_CACHE (deprecated)
//
// We allow multiple to be set for compatibility, but then the values must match.

let mut cache_env_var: String = "".to_string();
let mut cache_env_value: String = "".to_string();

if let Ok(t) = env::var("HF_HUB_CACHE") {
cache_env_var = "HF_HUB_CACHE".into();
cache_env_value = t.into();
}

for deprecated_env_var in vec!["TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"] {
match (
env::var(deprecated_env_var),
!cache_env_var.is_empty(),
) {
(Ok(t), false) => {
cache_env_var = deprecated_env_var.into();
cache_env_value = t.into();
},
(Ok(t), true) if t != cache_env_value => panic!(
"{deprecated_env_var} and {cache_env_var} env vars can't be set to different values"
),
(Ok(_), true) => warn!(
"{deprecated_env_var} is deprecated and should not be used. Use HF_HUB_CACHE instead."
),
_ => (),
}
}

// ensure HF_HUB_CACHE is set for downstream usage
// default value to match huggingface_hub
// REF: https://github.com/huggingface/huggingface_hub/blob/5ff2d150d121d04799b78bc08f2343c21b8f07a9/docs/source/en/package_reference/environment_variables.md?plain=1#L32
let cache_path = if !cache_env_value.is_empty() {
PathBuf::from(cache_env_value)
} else if let Ok(hf_home) = env::var("HF_HOME") {
PathBuf::from(hf_home).join("hub")
} else if let Ok(home) = env::var("HOME") {
PathBuf::from(home).join(".cache").join("huggingface").join("hub")
} else {
PathBuf::new()
};

let config_path: PathBuf = resolve_config_path(cache_path.clone(), &args.model_name, args.revision.as_deref())
.expect("Failed to resolve config path")
.into();
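
For readers skimming the diff, the cache-resolution order implemented above can be summarized with a small Python sketch (the function name `resolve_hub_cache` is illustrative; the real launcher additionally refuses to start when a deprecated variable disagrees with `HF_HUB_CACHE`):

```python
import os
from pathlib import Path

def resolve_hub_cache(env=os.environ) -> Path:
    """Sketch of the precedence: explicit cache vars, then HF_HOME, then ~/.cache."""
    for var in ("HF_HUB_CACHE", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"):
        if var in env:
            return Path(env[var])
    if "HF_HOME" in env:
        return Path(env["HF_HOME"]) / "hub"
    if "HOME" in env:
        return Path(env["HOME"]) / ".cache" / "huggingface" / "hub"
    return Path()
```

The fallback path mirrors the default documented for `huggingface_hub`, so the launcher and the Python server agree on where the cache lives even when no variable is set.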

@@ -223,15 +270,18 @@ fn main() -> ExitCode {
let (status_sender, status_receiver) = mpsc::channel();

// Start shard processes
let cache_path_string = cache_path.into_os_string();
for rank in 0..num_shard {
let args = args.clone();
let cache_path = cache_path_string.clone();
let deployment_framework = deployment_framework.to_string();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
args.model_name,
cache_path,
args.revision,
deployment_framework,
args.dtype.or(args.dtype_str),
@@ -548,6 +598,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
cache_path: OsString,
revision: Option<String>,
deployment_framework: String,
dtype: Option<String>,
@@ -620,19 +671,6 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

// Fix up TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars
match (
env::var("TRANSFORMERS_CACHE"),
env::var("HUGGINGFACE_HUB_CACHE"),
) {
(Ok(t), Err(_)) => env.push(("HUGGINGFACE_HUB_CACHE".into(), t.into())),
(Err(_), Ok(h)) => env.push(("TRANSFORMERS_CACHE".into(), h.into())),
(Ok(t), Ok(h)) if t != h => panic!(
"TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars can't be set to different values"
),
_ => (),
}

if let Some(alloc_conf) = cuda_alloc_conf {
if alloc_conf.is_empty() {
// Remove it from env
@@ -665,6 +703,9 @@ fn shard_manager(
// Ensure offline-only
env.push(("HF_HUB_OFFLINE".into(), "1".into()));

// Ensure that we set the standard cache variable
env.push(("HF_HUB_CACHE".into(), cache_path.into()));

// Start process
info!("Starting shard {rank}");
let mut p = match Command::new("text-generation-server")
@@ -776,18 +817,13 @@ fn write_termination_log(msg: &str) -> Result<(), io::Error> {
Ok(())
}

fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
let cache = env::var("TRANSFORMERS_CACHE")
.or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
.ok();
let mut model_dir = cache
.as_ref()
.map(|c| Path::new(&c).join(format!("models--{}", model_name.replace('/', "--"))));
if let Some(ref d) = model_dir {
if !d.try_exists()? {
model_dir = None;
}
}
fn resolve_config_path(cache_path: PathBuf, model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
let model_hf_cache_dir = cache_path.join(format!("models--{}", model_name.replace('/', "--")));
let model_dir = if model_hf_cache_dir.try_exists()? {
Some(model_hf_cache_dir)
} else {
None
};
if let Some(dir) = model_dir {
let revision = revision.unwrap_or("main");
let ref_path = dir.join("refs").join(revision);
@@ -811,11 +847,7 @@ fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
if try_path.try_exists()? {
Ok(try_path.to_string_lossy().into())
} else {
let message = if cache.is_none() {
format!("Model path {model_name} not found (TRANSFORMERS_CACHE env var not set)")
} else {
format!("Model {model_name} not found in local cache")
};
let message = format!("Model {model_name} not found");
error!(message);
Err(io::Error::new(ErrorKind::NotFound, message))
}
10 changes: 5 additions & 5 deletions server/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion server/pyproject.toml
@@ -30,7 +30,7 @@ onnxruntime-gpu = { version = "^1.17.1", optional = true }
onnx = { version = "^1.16.0", optional = true }
einops = "^0.7.0"
ibm-fms = { version = "^0.0", optional = true }
fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "a010516ff2c938c206b9b342b16bd747ef07d43c", optional = true }
fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "d41f8a34c9841aa3c4c59f17b5e7f3cb365f49de", optional = true }

# Explicitly install some transitive dependencies to avoid CVEs
jinja2 = ">=3.1.3"
10 changes: 10 additions & 0 deletions server/text_generation_server/cli.py
@@ -252,4 +252,14 @@ def convert_to_fast_tokenizer(


if __name__ == "__main__":

    # Use of TRANSFORMERS_CACHE is deprecated
    if (tc := os.getenv("TRANSFORMERS_CACHE")) is not None:
        print("WARNING: Using TRANSFORMERS_CACHE is deprecated. Use HF_HUB_CACHE instead.")
        hc = os.getenv("HF_HUB_CACHE")
        if hc is None:
            # Fall back to the deprecated value; assign via os.environ (not
            # os.putenv) so that in-process readers of HF_HUB_CACHE see it too.
            os.environ["HF_HUB_CACHE"] = tc
        elif tc != hc:
            raise ValueError("Conflicting model cache values between TRANSFORMERS_CACHE and HF_HUB_CACHE")

    app()
5 changes: 5 additions & 0 deletions server/text_generation_server/inference_engine/tgis_native.py
@@ -101,6 +101,11 @@ def __init__(
model_class = FlashRWForCausalLM

elif model_type == "llama":
# See: https://github.com/ibm-granite/vllm_granite/blob/main/vllm/model_executor/models/llama.py#L353-L354
if self._config.tie_word_embeddings:
aliases = {
"lm_head.weight": ["model.embed_tokens.weight"]
}
if PAGED_ATTENTION:
from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
model_class = PagedLlamaForCausalLM
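
The `lm_head.weight` alias added above is standard weight tying: when `tie_word_embeddings` is true, the output projection shares its matrix with the input embedding, so a checkpoint may only serialize `model.embed_tokens.weight`. A minimal PyTorch sketch of the idea (dimensions arbitrary, names illustrative):

```python
import torch.nn as nn

# Tied weights in miniature: the LM head reuses the embedding matrix, so
# loading only the embedding weight is enough to reconstruct the head.
vocab_size, hidden = 256, 64
embed_tokens = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight  # shared Parameter
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```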