Skip to content

Commit

Permalink
fixed memory read issue in http adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
phin1x committed Mar 25, 2023
1 parent 95e159f commit e47a6ab
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 31 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
idea
cmd
.idea
cmd
tests
22 changes: 16 additions & 6 deletions adapter-http.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ func (h *HttpAdapter) SendRequest(url string, req *Request, additionalResponseDa
return nil, err
}

var body io.Reader
size := len(payload)

var body io.Reader
if req.File != nil && req.FileSize != -1 {
size += req.FileSize

body = io.MultiReader(bytes.NewBuffer(payload), req.File)
} else {
body = bytes.NewBuffer(payload)
Expand Down Expand Up @@ -79,13 +77,25 @@ func (h *HttpAdapter) SendRequest(url string, req *Request, additionalResponseDa
}
}

resp, err := NewResponseDecoder(httpResp.Body).Decode(additionalResponseData)
// buffer response to avoid read issues
buf := new(bytes.Buffer)
if httpResp.ContentLength > 0 {
buf.Grow(int(httpResp.ContentLength))
}
if _, err := io.Copy(buf, httpResp.Body); err != nil {
return nil, fmt.Errorf("unable to buffer response: %w", err)
}

ippResp, err := NewResponseDecoder(buf).Decode(additionalResponseData)
if err != nil {
return nil, err
}

err = resp.CheckForErrors()
return resp, err
if err = ippResp.CheckForErrors(); err != nil {
return nil, fmt.Errorf("received error IPP response: %w", err)
}

return ippResp, nil
}

func (h *HttpAdapter) GetHttpUri(namespace string, object interface{}) string {
Expand Down
49 changes: 26 additions & 23 deletions adapter-socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ func NewSocketAdapter(host string, useTLS bool) *SocketAdapter {
}
}

//DoRequest performs the given IPP request to the given URL, returning the IPP response or an error if one occurred.
//Additional data will be written to an io.Writer if additionalData is not nil
// SendRequest performs the given IPP request to the given URL, returning the IPP response or an error if one occurred.
// Additional data will be written to an io.Writer if additionalData is not nil
func (h *SocketAdapter) SendRequest(url string, r *Request, additionalData io.Writer) (*Response, error) {
for i := 0; i < h.RequestRetryLimit; i++ {
// encode request
payload, err := r.Encode()
if err != nil {
return nil, fmt.Errorf("unable to encode IPP request: %v", err)
return nil, fmt.Errorf("unable to encode IPP request: %w", err)
}

var body io.Reader
Expand All @@ -64,7 +64,7 @@ func (h *SocketAdapter) SendRequest(url string, r *Request, additionalData io.Wr

req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, fmt.Errorf("unable to create HTTP request: %v", err)
return nil, fmt.Errorf("unable to create HTTP request: %w", err)
}

sock, err := h.GetSocket()
Expand All @@ -91,39 +91,42 @@ func (h *SocketAdapter) SendRequest(url string, r *Request, additionalData io.Wr
}

// send request
resp, err := unixClient.Do(req)
httpResp, err := unixClient.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to perform HTTP request: %v", err)
return nil, fmt.Errorf("unable to perform HTTP request: %w", err)
}

if resp.StatusCode == http.StatusUnauthorized {
if httpResp.StatusCode == http.StatusUnauthorized {
// retry with newly generated cert
resp.Body.Close()
httpResp.Body.Close()
continue
}

if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("server did not return Status OK: %d", resp.StatusCode)
if httpResp.StatusCode != http.StatusOK {
httpResp.Body.Close()
return nil, fmt.Errorf("server did not return Status OK: %d", httpResp.StatusCode)
}

// buffer response to avoid read issues
buf := new(bytes.Buffer)
if _, err := io.Copy(buf, resp.Body); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("unable to buffer response: %v", err)
if httpResp.ContentLength > 0 {
buf.Grow(int(httpResp.ContentLength))
}
if _, err := io.Copy(buf, httpResp.Body); err != nil {
httpResp.Body.Close()
return nil, fmt.Errorf("unable to buffer response: %w", err)
}

resp.Body.Close()
httpResp.Body.Close()

// decode reply
ippResp, err := NewResponseDecoder(bytes.NewReader(buf.Bytes())).Decode(additionalData)
ippResp, err := NewResponseDecoder(buf).Decode(additionalData)
if err != nil {
return nil, fmt.Errorf("unable to decode IPP response: %v", err)
return nil, fmt.Errorf("unable to decode IPP response: %w", err)
}

if err = ippResp.CheckForErrors(); err != nil {
return nil, fmt.Errorf("received error IPP response: %v", err)
return nil, fmt.Errorf("received error IPP response: %w", err)
}

return ippResp, nil
Expand All @@ -132,7 +135,7 @@ func (h *SocketAdapter) SendRequest(url string, r *Request, additionalData io.Wr
return nil, errors.New("request retry limit exceeded")
}

//GetSocket returns the path to the cupsd socket by searching SocketSearchPaths
// GetSocket returns the path to the cupsd socket by searching SocketSearchPaths
func (h *SocketAdapter) GetSocket() (string, error) {
for _, path := range h.SocketSearchPaths {
fi, err := os.Stat(path)
Expand All @@ -142,7 +145,7 @@ func (h *SocketAdapter) GetSocket() (string, error) {
} else if os.IsPermission(err) {
return "", errors.New("unable to access socket: Access denied")
}
return "", fmt.Errorf("unable to access socket: %v", err)
return "", fmt.Errorf("unable to access socket: %w", err)
}

if fi.Mode()&os.ModeSocket != 0 {
Expand All @@ -153,7 +156,7 @@ func (h *SocketAdapter) GetSocket() (string, error) {
return "", SocketNotFoundError
}

//GetCert returns the current CUPs authentication certificate by searching CertSearchPaths
// GetCert returns the current CUPs authentication certificate by searching CertSearchPaths
func (h *SocketAdapter) GetCert() (string, error) {
for _, path := range h.CertSearchPaths {
f, err := os.Open(path)
Expand All @@ -163,13 +166,13 @@ func (h *SocketAdapter) GetCert() (string, error) {
} else if os.IsPermission(err) {
return "", errors.New("unable to access certificate: Access denied")
}
return "", fmt.Errorf("unable to access certificate: %v", err)
return "", fmt.Errorf("unable to access certificate: %w", err)
}
defer f.Close()

buf := new(bytes.Buffer)
if _, err := io.Copy(buf, f); err != nil {
return "", fmt.Errorf("unable to access certificate: %v", err)
return "", fmt.Errorf("unable to access certificate: %w", err)
}
return buf.String(), nil
}
Expand Down

0 comments on commit e47a6ab

Please sign in to comment.