Skip to content

Commit

Permalink
feat: feat add support for information http code
Browse files Browse the repository at this point in the history
  • Loading branch information
tigerwill90 committed Oct 28, 2024
1 parent 3bea8fb commit 540c883
Showing 1 changed file with 54 additions and 12 deletions.
66 changes: 54 additions & 12 deletions response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ const notWritten = -1

type recorder struct {
http.ResponseWriter
size int
status int
size int
status int
hijacked bool
}

func (r *recorder) reset(w http.ResponseWriter) {
r.ResponseWriter = w
r.size = notWritten
r.status = http.StatusOK
r.hijacked = false
}

// Status recorded after Write or WriteHeader.
Expand Down Expand Up @@ -105,12 +107,25 @@ func (r *recorder) Unwrap() http.ResponseWriter {
// WriteHeader sends an HTTP response header with the provided
// status code. See http.ResponseWriter for more details.
func (r *recorder) WriteHeader(code int) {
if r.Written() {
if r.hijacked {
caller := relevantCaller()
log.Printf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
return
}
if r.size != notWritten {
caller := relevantCaller()
log.Printf("http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
return
}

// Handle informational headers.
// We shouldn't send any further headers after 101 Switching Protocols,
// so it takes the non-informational path.
if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
r.ResponseWriter.WriteHeader(code)
return
}

r.size = 0
r.status = code
r.ResponseWriter.WriteHeader(code)
Expand All @@ -119,10 +134,19 @@ func (r *recorder) WriteHeader(code int) {
// Write writes the data to the connection as part of an HTTP reply.
// See http.ResponseWriter for more details.
func (r *recorder) Write(buf []byte) (n int, err error) {
if !r.Written() {
if r.hijacked {
if len(buf) > 0 {
caller := relevantCaller()
log.Printf("http: response.Write on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
}
return 0, http.ErrHijacked
}

if r.size == notWritten {
r.size = 0
r.ResponseWriter.WriteHeader(r.status)
}

n, err = r.ResponseWriter.Write(buf)
r.size += n
return
Expand All @@ -132,7 +156,15 @@ func (r *recorder) Write(buf []byte) (n int, err error) {
// as part of an HTTP reply. The method returns the number of bytes written
// and an error, if any.
func (r *recorder) WriteString(s string) (n int, err error) {
if !r.Written() {
if r.hijacked {
if len(s) > 0 {
caller := relevantCaller()
log.Printf("http: response.Write on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
}
return 0, http.ErrHijacked
}

if r.size == notWritten {
r.size = 0
r.ResponseWriter.WriteHeader(r.status)
}
Expand All @@ -145,14 +177,15 @@ func (r *recorder) WriteString(s string) (n int, err error) {
// ReadFrom reads data from src until EOF or error. The return value n is the number of bytes read.
// Any error except EOF encountered during the read is also returned.
func (r *recorder) ReadFrom(src io.Reader) (n int64, err error) {
if !r.Written() {
r.size = 0
}

if rf, ok := r.ResponseWriter.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(src)
r.size += int(n)
return
if err == nil {
if r.size == notWritten {
r.size = 0
}
r.size += int(n)
}
return n, err
}

// Fallback in compatibility mode.
Expand All @@ -168,8 +201,14 @@ func (r *recorder) ReadFrom(src io.Reader) (n int64, err error) {
func (r *recorder) FlushError() error {
switch flusher := r.ResponseWriter.(type) {
case interface{ FlushError() error }:
if r.size == notWritten {
r.WriteHeader(r.status)
}
return flusher.FlushError()
case http.Flusher:
if r.size == notWritten {
r.WriteHeader(r.status)
}
flusher.Flush()
return nil
default:
Expand All @@ -190,6 +229,7 @@ func (r *recorder) Push(target string, opts *http.PushOptions) error {
// an error matching http.ErrNotSupported. See http.Hijacker for more details.
func (r *recorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := r.ResponseWriter.(http.Hijacker); ok {
r.hijacked = true
return hijacker.Hijack()
}
return nil, nil, ErrNotSupported()
Expand Down Expand Up @@ -258,7 +298,9 @@ func relevantCaller() runtime.Frame {
return frame
}

var errHttpNotSupported = fmt.Errorf("%w", http.ErrNotSupported)

// ErrNotSupported returns an error that Is ErrNotSupported, but is not == to it.
func ErrNotSupported() error {
return fmt.Errorf("%w", http.ErrNotSupported)
return errHttpNotSupported
}

0 comments on commit 540c883

Please sign in to comment.