Skip to content

Commit

Permalink
Merge pull request #230 from michaelfeil/update-kwargs-repl
Browse files Browse the repository at this point in the history
add infinity server kwargs for device
  • Loading branch information
michaelfeil authored May 26, 2024
2 parents 8870ac7 + ab6125d commit 9364b2a
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,10 @@ def v1(
batch_size=[batch_size],
revision=[revision], # type: ignore
trust_remote_code=[trust_remote_code],
engine=engine,
dtype=dtype,
pooling_method=pooling_method,
device=device,
engine=[engine],
dtype=[dtype],
pooling_method=[pooling_method],
device=[device],
model_warmup=[model_warmup],
vector_disk_cache=[vector_disk_cache],
lengths_via_tokenize=[lengths_via_tokenize],
Expand All @@ -369,13 +369,13 @@ def v2(
batch_size: list[int] = [32],
revision: list[str] = [""],
trust_remote_code: list[bool] = [True],
engine: InferenceEngine = InferenceEngine.default_value(), # type: ignore # noqa
engine: list[InferenceEngine] = [InferenceEngine.default_value()], # type: ignore # noqa
model_warmup: list[bool] = [True],
vector_disk_cache: list[bool] = [INFINITY_CACHE_VECTORS],
device: Device = Device.default_value(), # type: ignore
device: list[Device] = [Device.default_value()], # type: ignore
lengths_via_tokenize: list[bool] = [False],
dtype: Dtype = Dtype.default_value(), # type: ignore
pooling_method: PoolingMethod = PoolingMethod.default_value(), # type: ignore
dtype: list[Dtype] = [Dtype.default_value()], # type: ignore
pooling_method: list[PoolingMethod] = [PoolingMethod.default_value()], # type: ignore
compile: list[bool] = [False],
bettertransformer: list[bool] = [True],
# arguments for uvicorn / server
Expand Down Expand Up @@ -452,10 +452,13 @@ def v2(

def cli():
if len(sys.argv) == 1 or sys.argv[1] not in ["v1", "v2", "help", "--help"]:
print(
"WARNING: No command given. Defaulting to `v1`."
"Make sure to upgrade to the latest version of `typer`."
)
for _ in range(3):
logger.error(
"WARNING: No command given. Defaulting to `v1`."
"This will be deprecated in the future, and will require usage of a `v1` or `v2`"
"Specify the version of the CLI you want to use."
)
time.sleep(1)
sys.argv.insert(1, "v1")
print(sys.argv)
tp()
Expand Down

0 comments on commit 9364b2a

Please sign in to comment.