Skip to content

Commit

Permalink
Propagate RST errors (#17)
Browse files Browse the repository at this point in the history
fixes part of #16
  • Loading branch information
nitely authored Jul 13, 2024
1 parent 18954a5 commit 1a70e8c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 37 deletions.
42 changes: 19 additions & 23 deletions src/hyperx/clientserver.nim
Original file line number Diff line number Diff line change
Expand Up @@ -900,15 +900,16 @@ proc read(stream: Stream): Future[Frame] {.async.} =
if frm.typ == frmtRstStream:
for frm2 in stream.msgs:
stream.doTransitionRecv frm2
stream.error = newStrmError(frm.errorCode, hxRemoteErr)
stream.close()
raise newGotRstError(frm.errorCode)
raise newStrmError(frm.errorCode, hxRemoteErr)
if frm.typ == frmtPushPromise:
raise newStrmError(errProtocolError)
raise newStrmError errProtocolError
if frm.typ == frmtWindowUpdate:
check frm.windowSizeInc > 0, newStrmError(errProtocolError)
check frm.windowSizeInc <= stgMaxWindowSize, newStrmError(errProtocolError)
check frm.windowSizeInc > 0, newStrmError errProtocolError
check frm.windowSizeInc <= stgMaxWindowSize, newStrmError errProtocolError
check stream.peerWindow <= stgMaxWindowSize.int32 - frm.windowSizeInc.int32,
newStrmError(errFlowControlError)
newStrmError errFlowControlError
stream.peerWindow += frm.windowSizeInc.int32
if not stream.peerWindowUpdateSig.isClosed:
stream.peerWindowUpdateSig.trigger()
Expand All @@ -923,12 +924,11 @@ proc recvHeadersTaskNaked(strm: ClientStream) {.async.} =
var frm: Frame
while true:
frm = await strm.stream.read()
check frm.typ == frmtHeaders, newStrmError(errProtocolError)
check frm.typ == frmtHeaders, newStrmError errProtocolError
validateHeaders(frm.payload, strm.client.typ)
if strm.client.typ == ctServer:
break
check frm.payload.len >= statusLineLen, newStrmError(errProtocolError)
#check frm.payload.startsWith ":status: ", newStrmError(errProtocolError)
check frm.payload.len >= statusLineLen, newStrmError errProtocolError
if frm.payload[9] == '1'.byte:
check frmfEndStream notin frm.flags, newStrmError(errProtocolError)
else:
Expand Down Expand Up @@ -995,11 +995,6 @@ proc recvTask(strm: ClientStream) {.async.} =
discard await stream.read()
except QueueClosedError:
discard
except GotRstError as err:
debugInfo err.getStackTrace()
debugInfo err.msg
stream.error = err
raise err
except ConnError as err:
debugInfo err.getStackTrace()
debugInfo err.msg
Expand All @@ -1015,9 +1010,10 @@ proc recvTask(strm: ClientStream) {.async.} =
debugInfo err.msg
stream.error = err
strm.close()
await client.sendSilently newRstStreamFrame(
stream.id.FrmSid, err.code.int
)
if err.typ == hxLocalErr:
await client.sendSilently newRstStreamFrame(
stream.id.FrmSid, err.code.int
)
raise err
except CatchableError as err:
debugInfo err.getStackTrace()
Expand Down Expand Up @@ -1045,7 +1041,7 @@ proc recvHeaders*(strm: ClientStream, data: ref string) {.async.} =
if strm.stream.error != nil:
debugInfo strm.stream.error.getStackTrace()
debugInfo strm.stream.error.msg
raise newStrmError(strm.stream.error.code)
raise newError strm.stream.error
raise err

proc recvBodyNaked(strm: ClientStream, data: ref string) {.async.} =
Expand Down Expand Up @@ -1086,7 +1082,7 @@ proc recvBody*(strm: ClientStream, data: ref string) {.async.} =
if strm.stream.error != nil:
debugInfo strm.stream.error.getStackTrace()
debugInfo strm.stream.error.msg
raise newStrmError(strm.stream.error.code)
raise newError strm.stream.error
raise err

func recvTrailers*(strm: ClientStream): string =
Expand Down Expand Up @@ -1130,7 +1126,7 @@ proc sendHeaders*(
if strm.stream.error != nil:
debugInfo strm.stream.error.getStackTrace()
debugInfo strm.stream.error.msg
raise newStrmError(strm.stream.error.code)
raise newError strm.stream.error
raise err

proc sendHeaders*(
Expand Down Expand Up @@ -1161,8 +1157,8 @@ proc sendBodyNaked(
while stream.peerWindow <= 0:
await stream.peerWindowUpdateSig.waitFor()
while client.peerWindow <= 0:
check stream.state != strmClosed,
newStrmError(stream.errCodeOrDefault errStreamClosed)
check stream.state != strmClosed,
newErrorOrDefault(stream.error, newStrmError errStreamClosed)
await client.peerWindowUpdateSig.waitFor()
let peerWindow = min(client.peerWindow, stream.peerWindow)
dataIdxB = min(dataIdxA+min(peerWindow, stgInitialMaxFrameSize.int), L)
Expand All @@ -1177,7 +1173,7 @@ proc sendBodyNaked(
stream.peerWindow -= frm.payloadLen.int32
client.peerWindow -= frm.payloadLen.int32
check stream.state != strmClosed,
newStrmError(stream.errCodeOrDefault errStreamClosed)
newErrorOrDefault(stream.error, newStrmError errStreamClosed)
await client.write frm
dataIdxA = dataIdxB
# allow sending empty data frame
Expand All @@ -1199,7 +1195,7 @@ proc sendBody*(
if strm.stream.error != nil:
debugInfo strm.stream.error.getStackTrace()
debugInfo strm.stream.error.msg
raise newStrmError(strm.stream.error.code)
raise newError strm.stream.error
raise err

template with*(strm: ClientStream, body: untyped): untyped =
Expand Down
28 changes: 19 additions & 9 deletions src/hyperx/errors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,17 @@ func toErrorCode(e: uint32): ErrorCode {.raises: [].} =

# XXX remove ConnError and StrmError; expose code in Hyperx*
type
HyperxErrTyp* = enum
hxLocalErr, hxRemoteErr
HyperxError* = object of CatchableError
HyperxConnError* = object of HyperxError
HyperxStrmError* = object of HyperxError
ConnClosedError* = object of HyperxConnError
ConnError* = object of HyperxConnError
code*: ErrorCode
StrmError* = object of HyperxStrmError
typ*: HyperxErrTyp
code*: ErrorCode
GotRstError* = object of StrmError
QueueError* = object of HyperxError
QueueClosedError* = object of QueueError

Expand All @@ -69,14 +71,22 @@ func newConnClosedError*(): ref ConnClosedError {.raises: [].} =
func newConnError*(errCode: ErrorCode): ref ConnError {.raises: [].} =
result = (ref ConnError)(code: errCode, msg: "Connection Error: " & $errCode)

func newStrmError*(errCode: ErrorCode): ref StrmError {.raises: [].} =
result = (ref StrmError)(code: errCode, msg: "Stream Error: " & $errCode)
func newStrmError*(errCode: ErrorCode, typ = hxLocalErr): ref StrmError {.raises: [].} =
let msg = case typ
of hxLocalErr: "Stream Error: " & $errCode
of hxRemoteErr: "Got Rst Error: " & $errCode
result = (ref StrmError)(typ: typ, code: errCode, msg: msg)

func newStrmError*(errCode: uint32): ref StrmError {.raises: [].} =
result = newStrmError(errCode.toErrorCode)
func newStrmError*(errCode: uint32, typ = hxLocalErr): ref StrmError {.raises: [].} =
result = newStrmError(errCode.toErrorCode, typ)

func newGotRstError*(errCode: ErrorCode): ref GotRstError {.raises: [].} =
result = (ref GotRstError)(code: errCode, msg: "Got Rst Error: " & $errCode)
func newError*(err: ref StrmError): ref StrmError {.raises: [].} =
result = (ref StrmError)(
typ: err.typ, code: err.code, msg: err.msg
)

func newGotRstError*(errCode: uint32): ref GotRstError {.raises: [].} =
result = newGotRstError(errCode.toErrorCode)
func newErrorOrDefault*(err, default: ref StrmError): ref StrmError {.raises: [].} =
if err != nil:
return newError(err)
else:
return default
5 changes: 0 additions & 5 deletions src/hyperx/stream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,6 @@ proc close*(stream: Stream) {.raises: [].} =
stream.msgs.close()
stream.peerWindowUpdateSig.close()

func errCodeOrDefault*(stream: Stream, default: ErrorCode): ErrorCode =
if stream.error != nil:
return stream.error.code
return default

type
StreamsClosedError* = object of HyperxError
Streams* = object
Expand Down
4 changes: 4 additions & 0 deletions src/hyperx/utils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func serverHeadersValidation*(s: openArray[byte]) {.raises: [StrmError].} =
check hasPath, newStrmError(errProtocolError)

func clientHeadersValidation*(s: openArray[byte]) {.raises: [StrmError].} =
var hasStatus = false
var regularFieldCount = 0
for (nn, vv) in headersIt(s):
if s[nn.a].char != ':':
Expand All @@ -181,6 +182,9 @@ func clientHeadersValidation*(s: openArray[byte]) {.raises: [StrmError].} =
check regularFieldCount == 0, newStrmError(errProtocolError)
check toOpenArray(s, nn.a, nn.b) == ":status",
newStrmError(errProtocolError)
check not hasStatus, newStrmError(errProtocolError)
hasStatus = true
check hasStatus, newStrmError(errProtocolError)

func validateTrailers*(s: openArray[byte]) {.raises: [StrmError].} =
for (nn, _) in headersIt(s):
Expand Down

0 comments on commit 1a70e8c

Please sign in to comment.