Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the KServe wrapper configuration loading more resiliant #2995

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 21 additions & 31 deletions kubernetes/kserve/kserve_wrapper/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,42 @@ def parse_config():
model_store: the path in which the .mar file resides
"""
separator = "="
keys = {}
ts_configuration = {}
config_path = os.environ.get("CONFIG_PATH", DEFAULT_CONFIG_PATH)

logging.info(f"Wrapper: loading configuration from {config_path}")

with open(config_path) as f:
for line in f:
if separator in line:
# Find the name and value by splitting the string
name, value = line.split(separator, 1)

# Assign key value pair to dict
# strip() removes white space from the ends of strings
keys[name.strip()] = value.strip()

keys["model_snapshot"] = json.loads(keys["model_snapshot"])
inference_address, management_address, grpc_inference_port, model_store = (
keys["inference_address"],
keys["management_address"],
keys["grpc_inference_port"],
keys["model_store"],
if not line.startswith("#"):
if separator in line:
name, value = line.split(separator, 1)
ts_configuration[name.strip()] = value.strip()

ts_configuration["model_snapshot"] = json.loads(
ts_configuration.get("model_snapshot", "{}")
)

models = keys["model_snapshot"]["models"]
model_names = []
inference_address = ts_configuration.get(
"inference_address", DEFAULT_INFERENCE_ADDRESS
)
management_address = ts_configuration.get(
"management_address", DEFAULT_MANAGEMENT_ADDRESS
)
grpc_inference_port = ts_configuration.get(
"grpc_inference_port", DEFAULT_GRPC_INFERENCE_PORT
)
model_store = ts_configuration.get("model_store", DEFAULT_MODEL_STORE)

# Get all the model_names
for model, value in models.items():
model_names.append(model)
model_names = ts_configuration["model_snapshot"].get("models", {}).keys()

if not inference_address:
inference_address = DEFAULT_INFERENCE_ADDRESS
if not model_names:
model_names = [DEFAULT_MODEL_NAME]
if not inference_address:
inference_address = DEFAULT_INFERENCE_ADDRESS
if not management_address:
management_address = DEFAULT_MANAGEMENT_ADDRESS

inf_splits = inference_address.split(":")
if not grpc_inference_port:
grpc_inference_address = inf_splits[1] + ":" + DEFAULT_GRPC_INFERENCE_PORT
else:
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
grpc_inference_address = grpc_inference_address.replace("/", "")
if not model_store:
model_store = DEFAULT_MODEL_STORE

logging.info(
"Wrapper : Model names %s, inference address %s, management address %s, grpc_inference_address, %s, model store %s",
Expand Down
Loading