Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add backend param for triton serving #1039

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions charts/triton/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@ Create chart name and version as used by the chart label.
{{- define "nvidia-triton-server.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" -}}
{{- end -}}

{{/*
Return tritonserver image
*/}}
{{- define "triton.image" -}}
{{- if .Values.image }}
{{- .Values.image -}}
{{- else }}
{{- if eq .Values.backend "vllm" }}
{{- "nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3" -}}
{{- else if eq .Values.backend "trt-llm" }}
{{- "nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3" -}}
{{- else }}
{{- "nvcr.io/nvidia/tritonserver:24.01-py3" -}}
{{- end }}
{{- end }}
{{- end -}}
4 changes: 1 addition & 3 deletions charts/triton/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ spec:
{{- end }}
containers:
- name: tritonserver
{{- if .Values.image }}
image: "{{ .Values.image }}"
{{- end }}
image: {{ include "triton.image" . }}
{{- if .Values.imagePullPolicy }}
imagePullPolicy: "{{ .Values.imagePullPolicy }}"
{{- end }}
Expand Down
2 changes: 0 additions & 2 deletions charts/triton/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ serviceType: ClusterIP
servingName:
servingVersion:

image: "nvcr.io/nvidia/tritonserver:24.01-py3"

imagePullPolicy: "IfNotPresent"

cpu: 1.0
Expand Down
1 change: 0 additions & 1 deletion pkg/apis/serving/triton_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func NewTritonServingJobBuilder() *TritonServingJobBuilder {
GrpcPort: 8001,
MetricsPort: 8002,
CommonServingArgs: types.CommonServingArgs{
Image: argsbuilder.DefaultTritonServingImage,
ImagePullPolicy: "IfNotPresent",
Replicas: 1,
Namespace: "default",
Expand Down
1 change: 1 addition & 0 deletions pkg/apis/types/serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ type SeldonServingArgs struct {
}

type TritonServingArgs struct {
Backend string `yaml:"backend"` // --backend
ModelRepository string `yaml:"modelRepository"` // --model-repository
MetricsPort int `yaml:"metricsPort"` // --metrics-port
HttpPort int `yaml:"httpPort"` // --http-port
Expand Down
16 changes: 6 additions & 10 deletions pkg/argsbuilder/serving_triton.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import (
"github.com/spf13/cobra"
)

const (
DefaultTritonServingImage = "nvcr.io/nvidia/tritonserver:24.01-py3"
)

type TritonServingArgsBuilder struct {
args *types.TritonServingArgs
argValues map[string]interface{}
Expand All @@ -43,7 +39,6 @@ func NewTritonServingArgsBuilder(args *types.TritonServingArgs) ArgsBuilder {
s.AddSubBuilder(
NewServingArgsBuilder(&s.args.CommonServingArgs),
)
s.AddArgValue("default-image", DefaultTritonServingImage)
return s
}

Expand Down Expand Up @@ -72,6 +67,7 @@ func (s *TritonServingArgsBuilder) AddCommandFlags(command *cobra.Command) {
s.subBuilders[name].AddCommandFlags(command)
}
var loadModels []string
command.Flags().StringVar(&s.args.Backend, "backend", "", "the backend type of triton server. Valid values: [vllm|trt-llm]")
command.Flags().StringVar(&s.args.ModelRepository, "model-repository", "", "the path of triton model path")
command.Flags().IntVar(&s.args.HttpPort, "http-port", 8000, "the port of http serving server")
command.Flags().IntVar(&s.args.GrpcPort, "grpc-port", 8001, "the port of grpc serving server")
Expand Down Expand Up @@ -112,14 +108,14 @@ func (s *TritonServingArgsBuilder) Build() error {
}

func (s *TritonServingArgsBuilder) validate() (err error) {
if s.args.Image == "" {
return fmt.Errorf("image must be specified")
}
/*
if s.args.Backend != "" {
if s.args.Backend != "vllm" && s.args.Backend != "trt-llm" {
return fmt.Errorf("backend %s is Invalid. Triton backend only supports vllm or trt-llm", s.args.Backend)
}
if s.args.GPUCount == 0 {
return fmt.Errorf("--gpus must be specific at least 1 GPU")
}
*/
}
return nil
}

Expand Down
Loading