Skip to content

Commit

Permalink
fix stream methods (#15)
Browse files Browse the repository at this point in the history
* fix

* fix
  • Loading branch information
henomis authored Jan 9, 2024
1 parent 6e4ff8f commit da56d3b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 3 additions & 1 deletion examples/cmd/stream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ func (r *CreatePostResponse) SetHeaders(headers restclientgo.Headers) error { re
func main() {

var response string
restClient := restclientgo.New("http://localhost:11434/api").WithStream(
restClient := restclientgo.New("http://localhost:11434/api")

restClient.SetStreamCallback(
func(data []byte) error {
var createPostResponse CreatePostResponse

Expand Down
13 changes: 6 additions & 7 deletions restclientgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (

const maxStreamBufferSize = 512 * 1024

type StreamDecodeFn func([]byte) error
type StreamCallback func([]byte) error

type RestClient struct {
httpClient *http.Client
endpoint string
requestModifier func(*http.Request) *http.Request
forceDecodeOnError bool
streamDecodeFn StreamDecodeFn
streamCallback StreamCallback
}

type Error string
Expand Down Expand Up @@ -105,9 +105,8 @@ func (r *RestClient) WithDecodeOnError(decodeOnError bool) *RestClient {
return r
}

func (r *RestClient) WithStream(streamDecodeFn StreamDecodeFn) *RestClient {
r.streamDecodeFn = streamDecodeFn
return r
func (r *RestClient) SetStreamCallback(streamCallback StreamCallback) {
r.streamCallback = streamCallback
}

func (r *RestClient) SetEndpoint(endpoint string) {
Expand Down Expand Up @@ -212,7 +211,7 @@ func (r *RestClient) do(ctx context.Context, method httpMethod, request Request,
return err
}

if r.streamDecodeFn == nil {
if r.streamCallback == nil {
err = response.Decode(httpResponse.Body)
} else {
err = r.decodeBody(httpResponse.Body)
Expand All @@ -231,7 +230,7 @@ func (r *RestClient) decodeBody(body io.Reader) error {
scanner.Buffer(scanBuf, maxStreamBufferSize)

for scanner.Scan() {
err := r.streamDecodeFn(scanner.Bytes())
err := r.streamCallback(scanner.Bytes())
if err != nil {
return err
}
Expand Down

0 comments on commit da56d3b

Please sign in to comment.