diff --git a/README.md b/README.md index 1a83150..e3538cc 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,11 @@ func (r *MyResponse) SetStatusCode(code int) error { func (r *MyResponse) AcceptContentType() string { // Return the accepted content type of the response } + +// Optional. Implement this method if you want to stream the response body. +func (r *MyResponse) StreamCallback() StreamCallback { + // Return the stream callback if any. +} ``` ## Usage diff --git a/examples/cmd/stream/main.go b/examples/cmd/stream/main.go index 5badf01..f519df1 100644 --- a/examples/cmd/stream/main.go +++ b/examples/cmd/stream/main.go @@ -35,11 +35,12 @@ func (r *createPostRequest) ContentType() string { } type CreatePostResponse struct { - HTTPStatusCode int `json:"-"` - Model string `json:"model"` - CreatedAt string `json:"created_at"` - Response string `json:"response"` - Done bool `json:"done"` + HTTPStatusCode int `json:"-"` + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` + StreamCallbackFn restclientgo.StreamCallback `json:"-"` } func (r *CreatePostResponse) Decode(body io.Reader) error { @@ -61,34 +62,34 @@ func (r *CreatePostResponse) SetStatusCode(code int) error { } func (r *CreatePostResponse) SetHeaders(headers restclientgo.Headers) error { return nil } +func (r *CreatePostResponse) StreamCallback() restclientgo.StreamCallback { + return r.StreamCallbackFn +} func main() { var response string restClient := restclientgo.New("http://localhost:11434/api") - restClient.SetStreamCallback( - func(data []byte) error { - var createPostResponse CreatePostResponse - - err := json.Unmarshal(data, &createPostResponse) - if err != nil { - return err - } - - response += createPostResponse.Response - fmt.Printf(createPostResponse.Response) - - return nil - }, - ) - restClient.SetRequestModifier(func(req *http.Request) *http.Request { req.Header.Set("Accept", "application/json") return req }) var createPostResponse CreatePostResponse + createPostResponse.StreamCallbackFn = func(data []byte) error { + var createPostResponse CreatePostResponse + + err := json.Unmarshal(data, &createPostResponse) + if err != nil { + return err + } + + response += createPostResponse.Response + fmt.Printf(createPostResponse.Response) + + return nil + } err := restClient.Post( context.Background(), diff --git a/restclientgo.go b/restclientgo.go index 8f03d3e..4048e98 100644 --- a/restclientgo.go +++ b/restclientgo.go @@ -18,7 +18,6 @@ type RestClient struct { endpoint string requestModifier func(*http.Request) *http.Request forceDecodeOnError bool - streamCallback StreamCallback } type Error string @@ -69,6 +68,11 @@ type Response interface { SetHeaders(headers Headers) error } +type Streamable interface { + // StreamCallback get the stream callback if any. + StreamCallback() StreamCallback +} + // New creates a new RestClient. func New(endpoint string) *RestClient { return &RestClient{ @@ -105,10 +109,6 @@ func (r *RestClient) WithDecodeOnError(decodeOnError bool) *RestClient { return r } -func (r *RestClient) SetStreamCallback(streamCallback StreamCallback) { - r.streamCallback = streamCallback -} - func (r *RestClient) SetEndpoint(endpoint string) { r.endpoint = endpoint } @@ -142,6 +142,7 @@ func (r *RestClient) Patch(ctx context.Context, request Request, response Respon return r.do(ctx, methodPatch, request, response) } +//nolint:gocognit func (r *RestClient) do(ctx context.Context, method httpMethod, request Request, response Response) error { requestPath, err := request.Path() if err != nil { @@ -206,16 +207,17 @@ func (r *RestClient) do(ctx context.Context, method httpMethod, request Request, return nil } - err = r.matchContentType(httpResponse, response) + err = matchContentType(httpResponse, response) if err != nil { return err } - if r.streamCallback == nil { - err = response.Decode(httpResponse.Body) + if streamable, isStreamable := response.(Streamable); isStreamable && streamable.StreamCallback() != nil { + err = stream(streamable.StreamCallback(), httpResponse.Body) } else { - err = r.decodeBody(httpResponse.Body) + err = response.Decode(httpResponse.Body) } + if err != nil { return fmt.Errorf("%w: %w", ErrResponseDecode, err) } @@ -223,23 +225,7 @@ func (r *RestClient) do(ctx context.Context, method httpMethod, request Request, return nil } -func (r *RestClient) decodeBody(body io.Reader) error { - scanner := bufio.NewScanner(body) - - scanBuf := make([]byte, 0, maxStreamBufferSize) - scanner.Buffer(scanBuf, maxStreamBufferSize) - - for scanner.Scan() { - err := r.streamCallback(scanner.Bytes()) - if err != nil { - return err - } - } - - return nil -} - -func (r *RestClient) matchContentType(httpResponse *http.Response, response Response) error { +func matchContentType(httpResponse *http.Response, response Response) error { contentTypeToMatch := response.AcceptContentType() contentType := httpResponse.Header.Get("Content-Type") @@ -255,3 +241,19 @@ func (r *RestClient) matchContentType(httpResponse *http.Response, response Resp return ErrNoContentType } + +func stream(streamCallback StreamCallback, body io.Reader) error { + scanner := bufio.NewScanner(body) + + scanBuf := make([]byte, 0, maxStreamBufferSize) + scanner.Buffer(scanBuf, maxStreamBufferSize) + + for scanner.Scan() { + err := streamCallback(scanner.Bytes()) + if err != nil { + return err + } + } + + return nil +}