Skip to content

Commit

Permalink
optional max payload size for ByteToMessageDecoder (#957)
Browse files Browse the repository at this point in the history
* optional max payload size for ByteToMessageDecoder

Motivation:

ByteToMessageDecoder aggregate data in memory as part of their normal operation. the ability to limit how much they aggregate is critical in many real-life applications

Modifications:

* add optional maximumBufferSize argument to ByteToMessageDecoder initializer
* test for buffer size when maximumBufferSize is set and throw ByteToMessageDecoderError.payloadTooLarge error

Result:

users can limit how much memory ByteToMessageDecoder takes and handle the exception on their end
  • Loading branch information
tomerd authored and Lukasa committed Apr 11, 2019
1 parent c90b159 commit e32a436
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 3 deletions.
22 changes: 19 additions & 3 deletions Sources/NIO/Codec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public enum ByteToMessageDecoderError: Error {
case leftoverDataWhenDone(ByteBuffer)
}

// TODO (tomer): Merge into ByteToMessageDecoderError next major version
/// This error can be thrown by `ByteToMessageDecoder`s if the incoming payload is larger than the max specified.
public struct ByteToMessageDecoderPayloadTooLargeError: Error {}

/// `ByteToMessageDecoder`s decode bytes in a stream-like fashion from `ByteBuffer` to another message type.
///
/// To add a `ByteToMessageDecoder` to the `ChannelPipeline` use
Expand Down Expand Up @@ -346,6 +350,7 @@ public final class ByteToMessageHandler<Decoder: ByteToMessageDecoder> {
}

internal private(set) var decoder: Decoder? // only `nil` if we're already decoding (ie. we're re-entered)
private let maximumBufferSize: Int?
private var queuedWrites = CircularBuffer<NIOAny>(initialCapacity: 1) // queues writes received whilst we're already decoding (re-entrant write)
private var state: State = .active {
willSet {
Expand All @@ -358,8 +363,14 @@ public final class ByteToMessageHandler<Decoder: ByteToMessageDecoder> {
private var seenEOF: Bool = false
private var selfAsCanDequeueWrites: CanDequeueWrites? = nil

public init(_ decoder: Decoder) {
/// Initialize a `ByteToMessageHandler`.
///
/// - parameters:
/// - decoder: The `ByteToMessageDecoder` to decode the bytes into message.
/// - maximumBufferSize: The maximum number of bytes to aggregate in-memory.
public init(_ decoder: Decoder, maximumBufferSize: Int? = nil) {
self.decoder = decoder
self.maximumBufferSize = maximumBufferSize
}

deinit {
Expand Down Expand Up @@ -448,13 +459,18 @@ extension ByteToMessageHandler {
var allowEmptyBuffer = decodeMode == .last
while (self.state.isActive && self.removalState == .notBeingRemoved) || decodeMode == .last {
let result = try self.withNextBuffer(allowEmptyBuffer: allowEmptyBuffer) { decoder, buffer in
let decoderResult: DecodingState
if decodeMode == .normal {
assert(self.state.isActive, "illegal state for normal decode: \(self.state)")
return try decoder.decode(context: context, buffer: &buffer)
decoderResult = try decoder.decode(context: context, buffer: &buffer)
} else {
allowEmptyBuffer = false
return try decoder.decodeLast(context: context, buffer: &buffer, seenEOF: self.seenEOF)
decoderResult = try decoder.decodeLast(context: context, buffer: &buffer, seenEOF: self.seenEOF)
}
if decoderResult == .needMoreData, let maximumBufferSize = self.maximumBufferSize, buffer.readableBytes > maximumBufferSize {
throw ByteToMessageDecoderPayloadTooLargeError()
}
return decoderResult
}
switch result {
case .didProcess(.continue):
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOTests/CodecTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ extension ByteToMessageDecoderTest {
("testErrorInDecodeLastWhenCloseIsReceivedReentrantlyInDecode", testErrorInDecodeLastWhenCloseIsReceivedReentrantlyInDecode),
("testWeAreOkayWithReceivingDataAfterHalfClosureEOF", testWeAreOkayWithReceivingDataAfterHalfClosureEOF),
("testWeAreOkayWithReceivingDataAfterFullClose", testWeAreOkayWithReceivingDataAfterFullClose),
("testPayloadTooLarge", testPayloadTooLarge),
("testPayloadTooLargeButHandlerOk", testPayloadTooLargeButHandlerOk),
]
}
}
Expand Down
51 changes: 51 additions & 0 deletions Tests/NIOTests/CodecTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,7 @@ public final class ByteToMessageDecoderTest: XCTestCase {
return .needMoreData
}
}

let decoder = Decoder()
let channel = EmbeddedChannel(handler: ByteToMessageHandler(decoder))
var buffer = channel.allocator.buffer(capacity: 16)
Expand Down Expand Up @@ -1400,7 +1401,57 @@ public final class ByteToMessageDecoderTest: XCTestCase {
XCTAssertEqual(1, decoder.decodeCalls)
XCTAssertEqual(1, decoder.decodeLastCalls)
}

func testPayloadTooLarge() {
struct Decoder: ByteToMessageDecoder {
typealias InboundOut = Never

func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
return .needMoreData
}

func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState {
return .needMoreData
}
}

let max = 100
let channel = EmbeddedChannel(handler: ByteToMessageHandler(Decoder(), maximumBufferSize: max))
var buffer = channel.allocator.buffer(capacity: max + 1)
buffer.writeString(String(repeating: "*", count: max + 1))
XCTAssertThrowsError(try channel.writeInbound(buffer)) { error in
XCTAssertTrue(error is ByteToMessageDecoderPayloadTooLargeError)
}
}

func testPayloadTooLargeButHandlerOk() {
class Decoder: ByteToMessageDecoder {
typealias InboundOut = ByteBuffer

var decodeCalls = 0

func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
self.decodeCalls += 1
buffer.moveReaderIndex(to: buffer.readableBytes)
return .continue
}

func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState {
self.decodeCalls += 1
buffer.moveReaderIndex(to: buffer.readableBytes)
return .continue
}
}

let max = 100
let decoder = Decoder()
let channel = EmbeddedChannel(handler: ByteToMessageHandler(decoder, maximumBufferSize: max))
var buffer = channel.allocator.buffer(capacity: max + 1)
buffer.writeString(String(repeating: "*", count: max + 1))
XCTAssertNoThrow(try channel.writeInbound(buffer))
XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean))
XCTAssertGreaterThan(decoder.decodeCalls, 0)
}
}

public final class MessageToByteEncoderTest: XCTestCase {
Expand Down

0 comments on commit e32a436

Please sign in to comment.