Skip to content

Commit

Permalink
Add support for reading Chunked encoded responses over UDS (#1089)
Browse files Browse the repository at this point in the history
* Add Http Chunked stream reader

* Remove Http response content

* Add test

* Fix tests
  • Loading branch information
varunpuranik authored Apr 15, 2019
1 parent 5bf5324 commit 83118b2
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Util.Uds
{
using System;
using System.Globalization;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

class HttpChunkedStreamReader : Stream
{
readonly HttpBufferedStream stream;
int chunkBytes;
bool eos;

public HttpChunkedStreamReader(HttpBufferedStream stream)
{
this.stream = Preconditions.CheckNotNull(stream, nameof(stream));
}

public override bool CanRead => true;

public override bool CanSeek => false;

public override bool CanWrite => false;

public override long Length => throw new NotSupportedException();

public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (this.eos)
{
return 0;
}

if (this.chunkBytes == 0)
{
string line = await this.stream.ReadLineAsync(cancellationToken);
if (!int.TryParse(line, NumberStyles.HexNumber, CultureInfo.InvariantCulture, out this.chunkBytes))
{
throw new IOException($"Cannot parse chunk header - {line}");
}
}

int bytesRead = 0;
if (this.chunkBytes > 0)
{
int bytesToRead = Math.Min(count, this.chunkBytes);
bytesRead = await this.stream.ReadAsync(buffer, offset, bytesToRead, cancellationToken);
if (bytesToRead == 0)
{
throw new EndOfStreamException();
}

this.chunkBytes -= bytesToRead;
}

if (this.chunkBytes == 0)
{
await this.stream.ReadLineAsync(cancellationToken);
if (bytesRead == 0)
{
this.eos = true;
}
}

return bytesRead;
}

public override void Flush() => throw new NotImplementedException();

public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();

public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();

public override void SetLength(long value) => throw new NotImplementedException();

public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ async Task SetHeadersAndContent(HttpResponseMessage httpResponse, HttpBufferedSt
}

httpResponse.Content = new StreamContent(bufferedStream);
var contentHeaders = new Dictionary<string, string>();
foreach (string header in headers)
{
if (string.IsNullOrWhiteSpace(header))
Expand All @@ -99,17 +100,28 @@ async Task SetHeadersAndContent(HttpResponseMessage httpResponse, HttpBufferedSt
bool headerAdded = httpResponse.Headers.TryAddWithoutValidation(headerName, headerValue);
if (!headerAdded)
{
if (string.Equals(headerName, ContentLengthHeaderName, StringComparison.InvariantCultureIgnoreCase))
{
if (!long.TryParse(headerValue, out long contentLength))
{
throw new HttpRequestException($"Header value is invalid for {headerName}.");
}
contentHeaders.Add(headerName, headerValue);
}
}

bool isChunked = httpResponse.Headers.TransferEncodingChunked.HasValue
&& httpResponse.Headers.TransferEncodingChunked.Value;

await httpResponse.Content.LoadIntoBufferAsync(contentLength);
httpResponse.Content = isChunked
? new StreamContent(new HttpChunkedStreamReader(bufferedStream))
: new StreamContent(bufferedStream);

foreach (KeyValuePair<string, string> contentHeader in contentHeaders)
{
httpResponse.Content.Headers.TryAddWithoutValidation(contentHeader.Key, contentHeader.Value);
if (string.Equals(contentHeader.Key, ContentLengthHeaderName, StringComparison.InvariantCultureIgnoreCase))
{
if (!long.TryParse(contentHeader.Value, out long contentLength))
{
throw new HttpRequestException($"Header value {contentHeader.Value} is invalid for {ContentLengthHeaderName}.");
}

httpResponse.Content.Headers.TryAddWithoutValidation(headerName, headerValue);
await httpResponse.Content.LoadIntoBufferAsync(contentLength);
}
}
}
Expand All @@ -136,7 +148,7 @@ async Task SetResponseStatusLine(HttpResponseMessage httpResponse, HttpBufferedS

httpResponse.Version = versionNumber;

if (!Enum.TryParse(statusLineParts[1], out HttpStatusCode statusCode))
if (!Enum.TryParse(statusLineParts[1], out HttpStatusCode statusCode) || !Enum.IsDefined(typeof(HttpStatusCode), statusCode))
{
throw new HttpRequestException($"StatusCode is not valid {statusLineParts[1]}.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ namespace Microsoft.Azure.Devices.Edge.Util.Test.Uds
[Unit]
public class HttpRequestResponseSerializerTest
{
static readonly string ChunkedResponseContentText = $"This is test content\nSecond chunk received from server\n";

static readonly byte[] ChunkedResponseBytes =
{
0x48, 0x54, 0x54, 0x50, 0x2F, 0x31, 0x2E, 0x31, 0x20, 0x32, 0x30, 0x30, 0x20, 0x4F, 0x4B, 0x0D,
0x0A, 0x74, 0x72, 0x61, 0x6E, 0x73, 0x66, 0x65, 0x72, 0x2D, 0x65, 0x6E, 0x63, 0x6F, 0x64, 0x69,
0x6E, 0x67, 0x3A, 0x20, 0x63, 0x68, 0x75, 0x6E, 0x6B, 0x65, 0x64, 0x0D, 0x0A, 0x64, 0x61, 0x74,
0x65, 0x3A, 0x20, 0x46, 0x72, 0x69, 0x2C, 0x20, 0x31, 0x32, 0x20, 0x41, 0x70, 0x72, 0x20, 0x32,
0x30, 0x31, 0x39, 0x20, 0x32, 0x32, 0x3A, 0x31, 0x36, 0x3A, 0x34, 0x33, 0x20, 0x47, 0x4D, 0x54,
0x0D, 0x0A, 0x0D, 0x0A, 0x31, 0x35, 0x0D, 0x0A, 0x54, 0x68, 0x69, 0x73, 0x20, 0x69, 0x73, 0x20,
0x74, 0x65, 0x73, 0x74, 0x20, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x0A, 0x0D, 0x0A, 0x32,
0x32, 0x0D, 0x0A, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x20,
0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x20, 0x66, 0x72, 0x6f, 0x6d, 0x20, 0x73, 0x65,
0x72, 0x76, 0x65, 0x72, 0x0A, 0x0D, 0x0A, 0x30, 0x0D, 0x0A, 0x0D, 0x0A
};

[Fact]
public void TestSerializeRequest_MethodMissing_ShouldSerializeRequest()
{
Expand Down Expand Up @@ -113,80 +129,80 @@ public void TestSerializeRequest_ShouldSerializeRequest()
}

[Fact]
public void TestDeserializeResponse_InvalidEndOfStream_ShouldThrow()
public async Task TestDeserializeResponse_InvalidEndOfStream_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("invalid");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<IOException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_InvalidStatusLine_ShouldThrow()
public async Task TestDeserializeResponse_InvalidStatusLine_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("invalid\r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_InvalidVersion_ShouldThrow()
public async Task TestDeserializeResponse_InvalidVersion_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("HTTP/11 200 OK\r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_InvalidProtocolVersionSeparator_ShouldThrow()
public async Task TestDeserializeResponse_InvalidProtocolVersionSeparator_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("HTTP-1.1 200 OK\r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_InvalidStatusCode_ShouldThrow()
public async Task TestDeserializeResponse_InvalidStatusCode_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("HTTP/1.1 2000 OK\r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_MissingReasonPhrase_ShouldThrow()
public async Task TestDeserializeResponse_MissingReasonPhrase_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("HTTP/1.1 200\r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
public void TestDeserializeResponse_InvalidEndOfStatusMessage_ShouldThrow()
public async Task TestDeserializeResponse_InvalidEndOfStatusMessage_ShouldThrow()
{
byte[] expected = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK \r\n");
var memory = new MemoryStream(expected, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
Assert.ThrowsAsync<HttpRequestException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
await Assert.ThrowsAsync<IOException>(() => new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken));
}

[Fact]
Expand Down Expand Up @@ -271,6 +287,41 @@ public async Task TestDeserializeResponse_ValidContent_ShouldDeserialize()
Assert.Equal("Test", await response.Content.ReadAsStringAsync());
}

[Fact]
public async Task TestDeserializeChunkedResponse_ValidContent_ShouldDeserialize()
{
var memory = new MemoryStream(ChunkedResponseBytes, true);
var stream = new HttpBufferedStream(memory);

CancellationToken cancellationToken = default(CancellationToken);
HttpResponseMessage response = await new HttpRequestResponseSerializer().DeserializeResponse(stream, cancellationToken);

Assert.Equal(response.Version, Version.Parse("1.1"));
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("OK", response.ReasonPhrase);
Assert.False(response.Content.Headers.ContentLength.HasValue);

Stream responseStream = await response.Content.ReadAsStreamAsync();
byte[] responseBytes = await ReadStream(responseStream);
string responseText = Encoding.UTF8.GetString(responseBytes);
Assert.Equal(ChunkedResponseContentText, responseText);
}

static async Task<byte[]> ReadStream(Stream s)
{
byte[] buffer = new byte[16 * 1024];
using (MemoryStream ms = new MemoryStream())
{
int read;
while ((read = await s.ReadAsync(buffer, 0, buffer.Length)) > 0)
{
ms.Write(buffer, 0, read);
}

return ms.ToArray();
}
}

static void AssertNormalizedValues(string expected, string actual)
{
// Remove metacharacters before assertion to allow to run on both Windows and Linux; which Linux will return additional carriage return character.
Expand Down

0 comments on commit 83118b2

Please sign in to comment.