diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 9a361956..faeb3813 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -340,10 +340,10 @@ def v1( batch_size=[batch_size], revision=[revision], # type: ignore trust_remote_code=[trust_remote_code], - engine=engine, - dtype=dtype, - pooling_method=pooling_method, - device=device, + engine=[engine], + dtype=[dtype], + pooling_method=[pooling_method], + device=[device], model_warmup=[model_warmup], vector_disk_cache=[vector_disk_cache], lengths_via_tokenize=[lengths_via_tokenize], @@ -369,13 +369,13 @@ def v2( batch_size: list[int] = [32], revision: list[str] = [""], trust_remote_code: list[bool] = [True], - engine: InferenceEngine = InferenceEngine.default_value(), # type: ignore # noqa + engine: list[InferenceEngine] = [InferenceEngine.default_value()], # type: ignore # noqa model_warmup: list[bool] = [True], vector_disk_cache: list[bool] = [INFINITY_CACHE_VECTORS], - device: Device = Device.default_value(), # type: ignore + device: list[Device] = [Device.default_value()], # type: ignore lengths_via_tokenize: list[bool] = [False], - dtype: Dtype = Dtype.default_value(), # type: ignore - pooling_method: PoolingMethod = PoolingMethod.default_value(), # type: ignore + dtype: list[Dtype] = [Dtype.default_value()], # type: ignore + pooling_method: list[PoolingMethod] = [PoolingMethod.default_value()], # type: ignore compile: list[bool] = [False], bettertransformer: list[bool] = [True], # arguments for uvicorn / server @@ -452,10 +452,13 @@ def v2( def cli(): if len(sys.argv) == 1 or sys.argv[1] not in ["v1", "v2", "help", "--help"]: - print( - "WARNING: No command given. Defaulting to `v1`." - "Make sure to upgrade to the latest version of `typer`." - ) + for _ in range(3): + logger.error( + "WARNING: No command given. Defaulting to `v1`." 
+ " This will be deprecated in the future, and will require usage of a `v1` or `v2`." + " Specify the version of the CLI you want to use." + ) + time.sleep(1) sys.argv.insert(1, "v1") print(sys.argv) tp()