Skip to content

Commit

Permalink
Add explicit protocol validation when reading RESP messages (#332)
Browse files Browse the repository at this point in the history
* Add explicit protocol validation when reading RESP messages

* Code cleanup

* Add unit tests for RespReadUtils

* Add length header fast-path for booleans

* Improve readability of RespParsingExceptions

* Add context to resp parsing integer overflow exceptions

* Code cleanup

* Misc. code cleanup for RESP read utils

* Allow empty string values for INCRBY and SET commands.

* Allow empty keys in RESP messages.

* Code cleanup

* Make RESP null parsing opt-in

---------

Co-authored-by: Badrish Chandramouli <badrishc@microsoft.com>
  • Loading branch information
lmaas and badrishc authored May 2, 2024
1 parent 5202499 commit 3a4b349
Show file tree
Hide file tree
Showing 16 changed files with 1,004 additions and 314 deletions.
12 changes: 9 additions & 3 deletions benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace BDN.benchmark.Resp
public unsafe class RespIntegerReadBenchmarks
{
[Benchmark]
[ArgumentsSource(nameof(SignedInt32EncodedValues))]
public int ReadInt32(AsciiTestCase testCase)
[ArgumentsSource(nameof(LengthHeaderValues))]
public int ReadLengthHeader(AsciiTestCase testCase)
{
fixed (byte* inputPtr = testCase.Bytes)
{
var start = inputPtr;
RespReadUtils.ReadInt(out var value, ref start, start + testCase.Bytes.Length);
RespReadUtils.ReadLengthHeader(out var value, ref start, start + testCase.Bytes.Length, allowNull: true);
return value;
}
}
Expand Down Expand Up @@ -72,6 +72,9 @@ public ulong ReadULongWithLengthHeader(AsciiTestCase testCase)
public static IEnumerable<object> SignedInt32EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> LengthHeaderValues
=> ToRespLengthHeaderTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> SignedInt64EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt64Values);

Expand All @@ -90,6 +93,9 @@ public static IEnumerable<object> UnsignedInt64EncodedValuesWithLengthHeader
public static IEnumerable<AsciiTestCase> ToRespIntegerTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($":{testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespLengthHeaderTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespIntegerWithLengthHeader<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase.ToString()?.Length ?? 0}\r\n{testCase}\r\n"));

Expand Down
2 changes: 1 addition & 1 deletion libs/client/GarnetClientProcessReplies.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ unsafe bool ProcessReplyAsString(ref byte* ptr, byte* end, out string result, ou
break;

case (byte)'$':
if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end))
if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end, allowNull: true))
return false;
break;

Expand Down
43 changes: 21 additions & 22 deletions libs/cluster/Session/ClusterSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Garnet.common;
using Garnet.common.Parsing;
using Garnet.networking;
using Garnet.server;
using Garnet.server.ACL;
Expand Down Expand Up @@ -225,38 +227,35 @@ bool CheckACLAdminPermissions()

ReadOnlySpan<byte> GetCommand(ReadOnlySpan<byte> bufSpan, out bool success)
{
if (bytesRead - readHead < 6)
success = false;

var ptr = recvBufferPtr + readHead;
var end = recvBufferPtr + bytesRead;

// Try to read the command length
if (!RespReadUtils.ReadLengthHeader(out int length, ref ptr, end))
{
success = false;
return default;
}

Debug.Assert(*(recvBufferPtr + readHead) == '$');
int psize = *(recvBufferPtr + readHead + 1) - '0';
readHead += 2;
while (*(recvBufferPtr + readHead) != '\r')
{
psize = psize * 10 + *(recvBufferPtr + readHead) - '0';
if (bytesRead - readHead < 1)
{
success = false;
return default;
}
readHead++;
}
if (bytesRead - readHead < 2 + psize + 2)
readHead = (int)(ptr - recvBufferPtr);

// Try to read the command value
ptr += length;
if (ptr + 2 > end)
{
success = false;
return default;
}
Debug.Assert(*(recvBufferPtr + readHead + 1) == '\n');

var result = bufSpan.Slice(readHead + 2, psize);
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize) == '\r');
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize + 1) == '\n');
if (*(ushort*)ptr != MemoryMarshal.Read<ushort>("\r\n"u8))
{
RespParsingException.ThrowUnexpectedToken(*ptr);
}

readHead += 2 + psize + 2;
success = true;
var result = bufSpan.Slice(readHead, length);
readHead += length + 2;

return result;
}
}
Expand Down
85 changes: 85 additions & 0 deletions libs/common/Parsing/RespParsingException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Text;

namespace Garnet.common.Parsing
{
/// <summary>
/// Exception wrapper for RESP parsing errors.
/// </summary>
public class RespParsingException : GarnetException
{
/// <summary>
/// Construct a new RESP parsing exception with the given message.
/// </summary>
/// <param name="message">Message that described the exception that has occurred.</param>
RespParsingException(string message) : base(message)
{
// Nothing...
}

/// <summary>
/// Throw an "Unexcepted Token" exception.
/// </summary>
/// <param name="token">The character that was unexpected.</param>
[DoesNotReturn]
public static void ThrowUnexpectedToken(byte token)
{
var c = (char)token;
var escaped = char.IsControl(c) ? $"\\x{token:x2}" : c.ToString();
Throw($"Unexpected character '{escaped}'.");
}

/// <summary>
/// Throw an invalid string length exception.
/// </summary>
/// <param name="len">The invalid string length.</param>
[DoesNotReturn]
public static void ThrowInvalidStringLength(long len)
{
Throw($"Invalid string length '{len}'.");
}

/// <summary>
/// Throw an invalid length exception.
/// </summary>
/// <param name="len">The invalid length.</param>
[DoesNotReturn]
public static void ThrowInvalidLength(long len)
{
Throw($"Invalid length '{len}'.");
}

/// <summary>
/// Throw NaN (not a number) exception.
/// </summary>
/// <param name="buffer">Pointer to an ASCII-encoded byte buffer containing the string that could not be converted.</param>
/// <param name="length">Length of the buffer.</param>
[DoesNotReturn]
public static unsafe void ThrowNotANumber(byte* buffer, int length)
{
Throw($"Unable to parse number: {Encoding.ASCII.GetString(buffer, length)}");
}

/// <summary>
/// Throw a exception indicating that an integer overflow has occurred.
/// </summary>
/// <param name="buffer">Pointer to an ASCII-encoded byte buffer containing the string that caused the overflow.</param>
/// <param name="length">Length of the buffer.</param>
[DoesNotReturn]
public static unsafe void ThrowIntegerOverflow(byte* buffer, int length)
{
Throw($"Unable to parse integer. The given number is larger than allowed: {Encoding.ASCII.GetString(buffer, length)}");
}

/// <summary>
/// Throw helper that throws a RespParsingException.
/// </summary>
/// <param name="message">Exception message.</param>
[DoesNotReturn]
public static void Throw(string message) =>
throw new RespParsingException(message);
}
}
Loading

0 comments on commit 3a4b349

Please sign in to comment.