diff --git a/http2/server.go b/http2/server.go index 4bb0d66..5e2583c 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2645,8 +2645,7 @@ func checkWriteHeaderCode(code int) { // Issue 22880: require valid WriteHeader status codes. // For now we only enforce that it's three digits. // In the future we might block things over 599 (600 and above aren't defined - // at http://httpwg.org/specs/rfc7231.html#status.codes) - // and we might block under 200 (once we have more mature 1xx support). + // at http://httpwg.org/specs/rfc7231.html#status.codes). // But for now any three digits. // // We used to send "HTTP/1.1 000 0" on the wire in responses but there's @@ -2667,13 +2666,33 @@ func (w *responseWriter) WriteHeader(code int) { } func (rws *responseWriterState) writeHeader(code int) { - if !rws.wroteHeader { - checkWriteHeaderCode(code) - rws.wroteHeader = true - rws.status = code - if len(rws.handlerHeader) > 0 { - rws.snapHeader = cloneHeader(rws.handlerHeader) + if rws.wroteHeader { + return + } + + checkWriteHeaderCode(code) + + // Handle informational headers + if code >= 100 && code <= 199 { + // Per RFC 8297 we must not clear the current header map + h := rws.handlerHeader + + if rws.conn.writeHeaders(rws.stream, &writeResHeaders{ + streamID: rws.stream.id, + httpResCode: code, + h: h, + endStream: rws.handlerDone && !rws.hasTrailers(), + }) != nil { + rws.dirty = true } + + return + } + + rws.wroteHeader = true + rws.status = code + if len(rws.handlerHeader) > 0 { + rws.snapHeader = cloneHeader(rws.handlerHeader) } } diff --git a/http2/server_test.go b/http2/server_test.go index 46ac6ee..ef5ad5b 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4356,3 +4356,92 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { t.Error("got protocol error") } } + +func TestServerSendsProcessing(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(http.StatusProcessing) + w.Write([]byte("stuff")) + + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + goth := st.decodeHeader(hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "102"}, + } + + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got = %q; want %q", goth, wanth) + } + + hf = st.wantHeaders() + goth = st.decodeHeader(hf.HeaderBlockFragment()) + wanth = [][2]string{ + {":status", "200"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", "5"}, + } + + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got = %q; want %q", goth, wanth) + } + }) +} + +func TestServerSendsEarlyHints(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + w.Write([]byte("stuff")) + + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + goth := st.decodeHeader(hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "103"}, + {"link", "; rel=preload; as=style"}, + {"link", "; rel=preload; as=script"}, + } + + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got = %q; want %q", goth, wanth) + } + + hf = st.wantHeaders() + goth = st.decodeHeader(hf.HeaderBlockFragment()) + wanth = [][2]string{ + {":status", "103"}, + {"link", "; rel=preload; as=style"}, + {"link", "; rel=preload; as=script"}, + {"link", "; rel=preload; as=script"}, + } + + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got = %q; want %q", goth, wanth) + } + + hf = st.wantHeaders() + goth = st.decodeHeader(hf.HeaderBlockFragment()) + wanth = [][2]string{ + {":status", "200"}, + {"link", "; rel=preload; as=style"}, + {"link", "; rel=preload; as=script"}, + {"link", "; rel=preload; as=script"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", "5"}, + } + + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got = %q; want %q", goth, wanth) + } + }) +}