forked from improbable-eng/grpc-web
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgrpc_web_response.go
138 lines (122 loc) · 3.67 KB
/
grpc_web_response.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//Copyright 2017 Improbable. All Rights Reserved.
// See LICENSE for licensing terms.
package grpcweb
import (
	"bytes"
	"encoding/binary"
	"net/http"
	"sort"
	"strings"

	"golang.org/x/net/http2"
)
// grpcWebResponse implements http.ResponseWriter (plus Flush and CloseNotify).
//
// Headers set by the handler are staged in a private map and only copied to
// the wrapped writer when the response headers are written; whatever remains
// unsent when the request finishes is treated as trailers.
type grpcWebResponse struct {
// wroteHeaders is set once a status code has been forwarded to wrapped.
wroteHeaders bool
// wroteBody is set once any body bytes have been passed to wrapped.Write.
wroteBody bool
// headers is the staging map returned by Header().
headers http.Header
// wrapped is the underlying writer that actually talks to the client.
wrapped http.ResponseWriter
}
// newGrpcWebResponse wraps resp with an empty header staging map. Headers the
// handler sets are held there until the response headers are written.
func newGrpcWebResponse(resp http.ResponseWriter) *grpcWebResponse {
	w := &grpcWebResponse{wrapped: resp}
	w.headers = make(http.Header)
	return w
}
// Header returns the staging header map the handler writes into. Its contents
// reach the wrapped writer only when the response headers are written (see
// WriteHeader and finishRequest).
func (w *grpcWebResponse) Header() http.Header {
	staged := w.headers
	return staged
}
// Write writes b to the response body via the wrapped writer.
//
// As with net/http, if WriteHeader has not been called yet, Write first calls
// WriteHeader(http.StatusOK). Without this, the wrapped writer would commit an
// implicit 200 on its own, the staged headers (and the CORS expose list) would
// never be copied to it, and they would later be misreported as trailers.
func (w *grpcWebResponse) Write(b []byte) (int, error) {
	if !w.wroteHeaders {
		w.WriteHeader(http.StatusOK)
	}
	w.wroteBody = true
	return w.wrapped.Write(b)
}
// WriteHeader copies the staged headers to the wrapped writer, sets the CORS
// expose-headers list, and forwards the status code.
//
// Only the first call has an effect. A repeated call would otherwise append
// every staged header to the wrapped header map a second time (the copy uses
// Header.Add) and issue a superfluous WriteHeader on the wrapped writer.
func (w *grpcWebResponse) WriteHeader(code int) {
	if w.wroteHeaders {
		return
	}
	w.copyJustHeadersToWrapped()
	w.writeCorsExposedHeaders()
	w.wrapped.WriteHeader(code)
	w.wroteHeaders = true
}
// Flush flushes buffered response data through the wrapped writer.
//
// Nothing is flushed until headers or body have been written. Flushing the
// wrapped writer earlier would make it commit a bare 200 response without the
// staged headers (the case when there is no payload). Deferring the flush
// lets finishRequest send headers and trailers together instead.
//
// If the wrapped writer does not implement http.Flusher, Flush is a no-op
// rather than panicking on an unchecked type assertion.
func (w *grpcWebResponse) Flush() {
	if !w.wroteHeaders && !w.wroteBody {
		return
	}
	if f, ok := w.wrapped.(http.Flusher); ok {
		f.Flush()
	}
}
// CloseNotify delegates to the wrapped writer's http.CloseNotifier.
// The type assertion is unchecked, so this panics if the wrapped writer does
// not implement http.CloseNotifier.
func (w *grpcWebResponse) CloseNotify() <-chan bool {
	notifier := w.wrapped.(http.CloseNotifier)
	return notifier.CloseNotify()
}
// copyJustHeadersToWrapped appends every staged header value to the wrapped
// writer's header map via Header.Add. The "Trailer" pre-announcement key is
// skipped so it is not sent as a response header.
func (w *grpcWebResponse) copyJustHeadersToWrapped() {
	wrappedHeader := w.wrapped.Header()
	for k, vv := range w.headers {
		// Skip the pre-announcement of Trailer headers. EqualFold avoids
		// allocating a lower-cased copy of every key.
		if strings.EqualFold(k, "trailer") {
			continue
		}
		for _, v := range vv {
			wrappedHeader.Add(k, v)
		}
	}
}
// finishRequest emits the trailers once the handler has returned. req is
// currently unused.
//
// If nothing has been sent yet, headers and trailers are merged into a single
// HTTP header block. Otherwise the headers are already on the wire, so the
// trailers are appended to the body as a gRPC-Web trailer frame.
func (w *grpcWebResponse) finishRequest(req *http.Request) {
	if !w.wroteHeaders && !w.wroteBody {
		w.copyTrailersAndHeadersToWrapped()
		return
	}
	w.copyTrailersToPayload()
}
// copyTrailersAndHeadersToWrapped is used when nothing has been written yet.
// It sends every staged header, including ones meant as trailers, as ordinary
// response headers with status 200, then flushes.
//
// The "Trailer" pre-announcement key is dropped. Keys using the
// http2.TrailerPrefix convention lose that prefix so they become plain
// headers. The flush is skipped if the wrapped writer is not an http.Flusher,
// rather than panicking on an unchecked type assertion.
func (w *grpcWebResponse) copyTrailersAndHeadersToWrapped() {
	w.wroteHeaders = true
	wrappedHeader := w.wrapped.Header()
	for k, vv := range w.headers {
		// Skip the pre-announcement of Trailer headers.
		if strings.EqualFold(k, "trailer") {
			continue
		}
		k = strings.TrimPrefix(k, http2.TrailerPrefix)
		for _, v := range vv {
			wrappedHeader.Add(k, v)
		}
	}
	w.writeCorsExposedHeaders()
	w.wrapped.WriteHeader(http.StatusOK)
	if f, ok := w.wrapped.(http.Flusher); ok {
		f.Flush()
	}
}
// writeCorsExposedHeaders sets Access-Control-Expose-Headers on the wrapped
// writer to the canonical names of all headers currently present on it. This
// is added to the actual response, not to a CORS preflight.
//
// The names are sorted so the header value is deterministic. Ranging over the
// header map alone yields a different order on every response.
func (w *grpcWebResponse) writeCorsExposedHeaders() {
	wrappedHeader := w.wrapped.Header()
	knownHeaders := make([]string, 0, len(wrappedHeader))
	for h := range wrappedHeader {
		knownHeaders = append(knownHeaders, http.CanonicalHeaderKey(h))
	}
	sort.Strings(knownHeaders)
	wrappedHeader.Set("Access-Control-Expose-Headers", strings.Join(knownHeaders, ", "))
}
// copyTrailersToPayload writes the not-yet-sent headers (the trailers) to the
// response body as a gRPC-Web trailer frame, then flushes.
//
// The frame is a 5-byte prefix followed by an HTTP/1-style header block. The
// prefix is a flags byte with the MSB set (marking a trailer frame) and the
// big-endian uint32 length of the block. Trailer names are lower-cased, as the
// gRPC-Web wire protocol requires. If writing the prefix fails, the block is
// not written, since it could no longer be framed correctly.
func (w *grpcWebResponse) copyTrailersToPayload() {
	trailers := w.extractTrailerHeaders()

	// http.Header canonicalizes keys on Add/Set ("Grpc-Status"), so build the
	// lower-cased map by direct assignment.
	lowered := make(http.Header, len(trailers))
	for k, vv := range trailers {
		lk := strings.ToLower(k)
		lowered[lk] = append(lowered[lk], vv...)
	}

	trailerBuffer := new(bytes.Buffer)
	// Writing into a bytes.Buffer does not fail, so the error is not checked.
	_ = lowered.Write(trailerBuffer)

	trailerGrpcDataHeader := []byte{1 << 7, 0, 0, 0, 0} // MSB=1 indicates this is a trailer data frame.
	binary.BigEndian.PutUint32(trailerGrpcDataHeader[1:5], uint32(trailerBuffer.Len()))
	if _, err := w.wrapped.Write(trailerGrpcDataHeader); err != nil {
		return
	}
	if _, err := w.wrapped.Write(trailerBuffer.Bytes()); err != nil {
		return
	}
	if f, ok := w.wrapped.(http.Flusher); ok {
		f.Flush()
	}
}
// extractTrailerHeaders returns the staged headers that are not already
// present on the wrapped writer; these are treated as trailers.
//
// The "Trailer" pre-announcement key is dropped, and keys using the
// http2.TrailerPrefix convention have that prefix removed. The "already sent"
// check tries both the raw key and its canonical form. Headers copied to the
// wrapped writer go through Header.Add, which canonicalizes them, while the
// staged map may hold non-canonical keys written to it directly.
func (w *grpcWebResponse) extractTrailerHeaders() http.Header {
	flushedHeaders := w.wrapped.Header()
	trailerHeaders := make(http.Header)
	for k, vv := range w.headers {
		// Skip the pre-announcement of Trailer headers.
		if strings.EqualFold(k, "trailer") {
			continue
		}
		// Skip headers that were already sent.
		if _, exists := flushedHeaders[k]; exists {
			continue
		}
		if _, exists := flushedHeaders[http.CanonicalHeaderKey(k)]; exists {
			continue
		}
		k = strings.TrimPrefix(k, http2.TrailerPrefix)
		for _, v := range vv {
			trailerHeaders.Add(k, v)
		}
	}
	return trailerHeaders
}