diff --git a/gen/_template/client.tmpl b/gen/_template/client.tmpl index 7908b2388..2c9314c23 100644 --- a/gen/_template/client.tmpl +++ b/gen/_template/client.tmpl @@ -2,6 +2,59 @@ {{- /*gotype: github.com/ogen-go/ogen/gen.TemplateConfig*/ -}} {{ template "header" $ }} +{{- if $.RequestOptionsEnabled }} +type requestConfig struct { + Client ht.Client + ServerURL *url.URL + EditRequest func(req *http.Request) error + EditResponse func(resp *http.Response) error +} + +func (cfg *requestConfig) setDefaults(c baseClient) { + if cfg.Client == nil { + cfg.Client = c.cfg.Client + } +} + +func (cfg *requestConfig) onRequest(req *http.Request) error { + if fn := cfg.EditRequest; fn != nil { + return fn(req) + } + return nil +} + +func (cfg *requestConfig) onResponse(resp *http.Response) error { + if fn := cfg.EditResponse; fn != nil { + return fn(resp) + } + return nil +} + +// RequestOption defines options for request. +type RequestOption func(cfg *requestConfig) + +// WithRequestClient sets client for request. +func WithRequestClient(client ht.Client) RequestOption { + return func(cfg *requestConfig) { + cfg.Client = client + } +} + +// WithEditRequest sets function to edit request. +func WithEditRequest(fn func(req *http.Request) error) RequestOption { + return func(cfg *requestConfig) { + cfg.EditRequest = fn + } +} + +// WithEditResponse sets function to edit response. +func WithEditResponse(fn func(resp *http.Response) error) RequestOption { + return func(cfg *requestConfig) { + cfg.EditResponse = fn + } +} +{{- end }} + {{- if $.PathsClientEnabled }} // Invoker invokes operations described by OpenAPI v3 specification. @@ -16,7 +69,8 @@ type Invoker interface { {{ $op.Name }}(ctx context.Context {{- if $op.WebhookInfo }}, targetURL string{{ end }} {{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }} - {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }} + {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }} + {{- if $.RequestOptionsEnabled }}, options ...RequestOption{{ end }}) {{ $op.Responses.ResultTuple "" "" }} {{- end }} } @@ -32,7 +86,8 @@ type {{ $group.Name }}Invoker interface { {{ $op.Name }}(ctx context.Context {{- if $op.WebhookInfo }}, targetURL string{{ end }} {{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }} - {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }} + {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }} + {{- if $.RequestOptionsEnabled }}, options ...RequestOption {{ end }}) {{ $op.Responses.ResultTuple "" "" }} {{- end }} } {{- end }} @@ -46,7 +101,7 @@ type Client struct { baseClient } -{{- if $.PathsServerEnabled }} +{{- if and $.PathsServerEnabled (not $.RequestOptionsEnabled) }} {{- if $.Error }} type errorHandler interface { NewError(ctx context.Context, err error) {{ $.ErrorGoType }} @@ -84,6 +139,7 @@ func NewClient(serverURL string, {{- if $.Securities }}sec SecuritySource,{{- en }, nil } +{{- if not $.RequestOptionsEnabled }} type serverURLKey struct{} // WithServerURL sets context key to override server URL. @@ -98,6 +154,7 @@ func (c *Client) requestURL(ctx context.Context) *url.URL { } return u } +{{- end }} {{- range $op := $.Operations }} {{ template "client/operation" op_elem $op $ }} @@ -142,7 +199,8 @@ func NewWebhookClient(opts ...ClientOption) (*WebhookClient, error) { func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) {{ $op.Name }}(ctx context.Context {{- if $op.WebhookInfo }}, targetURL string{{ end }} {{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }} - {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }} { + {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }} + {{- if $cfg.RequestOptionsEnabled }}, options ...RequestOption {{ end }}) {{ $op.Responses.ResultTuple "" "" }} { {{ if $op.Responses.DoPass }}res{{ else }}_{{ end }}, err := c.send{{ $op.Name }}(ctx {{- if $op.WebhookInfo }},targetURL{{ end -}} {{- if $op.Request }},request{{ end -}} @@ -154,7 +212,8 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) {{ $op.Name }}(ctx cont func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx context.Context {{- if $op.WebhookInfo }}, targetURL string{{ end }} {{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }} - {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) (res {{ $op.Responses.GoType }}, err error) { + {{- if $op.Params }}, params {{ $op.Name }}Params {{ end }} + {{- if $cfg.RequestOptionsEnabled }}, requestOptions ...RequestOption {{ end }}) (res {{ $op.Responses.GoType }}, err error) { {{- if and $op.Request $cfg.RequestValidationEnabled }}{{/* Request validation */}} {{- if $op.Request.Type.IsInterface }} @@ -238,14 +297,36 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx }() {{- end }} + {{ if $cfg.RequestOptionsEnabled -}} + var reqCfg requestConfig + reqCfg.setDefaults(c.baseClient) + for _, o := range requestOptions { + o(&reqCfg) + } + {{- end }} + {{ if $otel }}stage = "BuildURL"{{ end }} {{- if $op.WebhookInfo }} - u, err := url.Parse(targetURL) - if err != nil { - return res, errors.Wrap(err, "parse target URL") - } - trimTrailingSlashes(u) + u, err := url.Parse(targetURL) + if err != nil { + return res, errors.Wrap(err, "parse target URL") + } + {{- if $cfg.RequestOptionsEnabled }} + if override := reqCfg.ServerURL; override != nil { + u = uri.Clone(override) + } + {{- end }} + trimTrailingSlashes(u) {{- else }} + {{- if $cfg.RequestOptionsEnabled }} + u := c.serverURL + if override := reqCfg.ServerURL; override != nil { + u = override + } + u = uri.Clone(u) + {{- else }} + u := uri.Clone(c.requestURL(ctx)) + {{- end }} {{- template "encode_path_parameters" $op }} {{- end }} @@ -317,13 +398,29 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx } {{- end }} + {{ if $cfg.RequestOptionsEnabled -}} + if err := reqCfg.onRequest(r); err != nil { + return res, errors.Wrap(err, "edit request") + } + {{- end }} + {{ if $otel }}stage = "SendRequest"{{ end }} + {{- if $cfg.RequestOptionsEnabled }} + resp, err := reqCfg.Client.Do(r) + {{- else }} resp, err := c.cfg.Client.Do(r) + {{- end }} if err != nil { return res, errors.Wrap(err, "do request") } defer resp.Body.Close() + {{ if $cfg.RequestOptionsEnabled -}} + if err := reqCfg.onResponse(resp); err != nil { + return res, errors.Wrap(err, "edit response") + } + {{- end }} + {{ if $otel }}stage = "DecodeResponse"{{ end }} result, err := decode{{ $op.Name }}Response(resp) if err != nil { diff --git a/gen/_template/parameter_encode.tmpl b/gen/_template/parameter_encode.tmpl index f6b3a46e0..46c016fe0 100644 --- a/gen/_template/parameter_encode.tmpl +++ b/gen/_template/parameter_encode.tmpl @@ -1,5 +1,4 @@ {{ define "encode_path_parameters" }}{{/*gotype: github.com/ogen-go/ogen/gen/ir.Operation*/}} -u := uri.Clone(c.requestURL(ctx)) var pathParts [{{ len $.PathParts }}]string {{- range $idx, $part := $.PathParts }}{{/* Range over path parts */}} {{- if $part.Raw }} diff --git a/gen/features.go b/gen/features.go index faadf882d..4d0a62455 100644 --- a/gen/features.go +++ b/gen/features.go @@ -113,6 +113,10 @@ var ( "client/security/reentrant", `Enables client usage in security source implementations`, } + ClientRequestOptions = Feature{ + "client/request/options", + `Enables function options for client requests`, + } ClientRequestValidation = Feature{ "client/request/validation", `Enables validation of client requests`, @@ -152,6 +156,7 @@ var AllFeatures = []Feature{ WebhooksClient, WebhooksServer, ClientSecurityReentrant, + ClientRequestOptions, ClientRequestValidation, ServerResponseValidation, OgenOtel, diff --git a/gen/write.go b/gen/write.go index b0341f7e2..65c518387 100644 --- a/gen/write.go +++ b/gen/write.go @@ -42,6 +42,7 @@ type TemplateConfig struct { WebhookServerEnabled bool OpenTelemetryEnabled bool SecurityReentrantEnabled bool + RequestOptionsEnabled bool RequestValidationEnabled bool ResponseValidationEnabled bool @@ -270,6 +271,7 @@ func (g *Generator) WriteSource(fs FileSystem, pkgName string) error { WebhookServerEnabled: features.Has(WebhooksServer) && len(g.webhooks) > 0, OpenTelemetryEnabled: features.Has(OgenOtel), SecurityReentrantEnabled: features.Has(ClientSecurityReentrant), + RequestOptionsEnabled: features.Has(ClientRequestOptions), RequestValidationEnabled: features.Has(ClientRequestValidation), ResponseValidationEnabled: features.Has(ServerResponseValidation), // Unused for now.