Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support chacha20-poly1305@openssh.com. #215

Merged
merged 5 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Supported key exchange methods:
Supported encryption algorithms:
- aes256-gcm@openssh.com
- aes128-gcm@openssh.com
- chacha20-poly1305@openssh.com

Supported message authentication code algorithms:
- none
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/AlgorithmNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ static class AlgorithmNames // TODO: rename to KnownNames
public static Name Aes128Gcm => new Name(Aes128GcmBytes);
private static readonly byte[] Aes256GcmBytes = "aes256-gcm@openssh.com"u8.ToArray();
public static Name Aes256Gcm => new Name(Aes256GcmBytes);
private static readonly byte[] ChaCha20Poly1305Bytes = "chacha20-poly1305@openssh.com"u8.ToArray();
public static Name ChaCha20Poly1305 => new Name(ChaCha20Poly1305Bytes);

// KDF algorithms:
private static readonly byte[] BCryptBytes = "bcrypt"u8.ToArray();
Expand Down
130 changes: 130 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketDecoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Security.Cryptography;

namespace Tmds.Ssh;

sealed class ChaCha20Poly1305PacketDecoder : ChaCha20Poly1305PacketEncDecBase, IPacketDecoder
{
private readonly SequencePool _sequencePool;
private int _currentPacketLength = -1;

public ChaCha20Poly1305PacketDecoder(SequencePool sequencePool, byte[] key) :
base(key)
{
_sequencePool = sequencePool;
}

public void Dispose()
{ }

public bool TryDecodePacket(Sequence receiveBuffer, uint sequenceNumber, int maxLength, out Packet packet)
{
packet = new Packet(null);

// Wait for the length.
if (receiveBuffer.Length < LengthSize)
{
return false;
}

// Decrypt length.
int packetLength = _currentPacketLength;
Span<byte> length_unencrypted = stackalloc byte[LengthSize];
if (packetLength == -1)
{
ConfigureCiphers(sequenceNumber);

Span<byte> length_encrypted = stackalloc byte[LengthSize];
if (receiveBuffer.FirstSpan.Length >= LengthSize)
{
receiveBuffer.FirstSpan.Slice(0, LengthSize).CopyTo(length_encrypted);
}
else
{
receiveBuffer.AsReadOnlySequence().Slice(0, LengthSize).CopyTo(length_encrypted);
}

LengthCipher.ProcessBytes(length_encrypted, length_unencrypted);

// Verify the packet length isn't too long and properly padded.
uint packet_length = BinaryPrimitives.ReadUInt32BigEndian(length_unencrypted);
if (packet_length > maxLength || (packet_length % PaddTo) != 0)
{
ThrowHelper.ThrowProtocolPacketTooLong();
}

_currentPacketLength = packetLength = (int)packet_length;
}
else
{
BinaryPrimitives.WriteInt32BigEndian(length_unencrypted, _currentPacketLength);
}

// Wait for the full encrypted packet.
int total_length = LengthSize + packetLength + TagSize;
if (receiveBuffer.Length < total_length)
{
return false;
}

// Check the mac.
ReadOnlySequence<byte> receiveBufferROSequence = receiveBuffer.AsReadOnlySequence();
ReadOnlySequence<byte> hashed = receiveBufferROSequence.Slice(0, LengthSize + packetLength);
Span<byte> packetTag = stackalloc byte[TagSize];
receiveBufferROSequence.Slice(LengthSize + packetLength, TagSize).CopyTo(packetTag);
if (hashed.IsSingleSegment)
{
Mac.BlockUpdate(hashed.FirstSpan);
}
else
{
foreach (var memory in hashed)
{
Mac.BlockUpdate(memory.Span);
}
}
Span<byte> tag = stackalloc byte[TagSize];
Mac.DoFinal(tag);
if (!CryptographicOperations.FixedTimeEquals(packetTag, tag))
{
throw new CryptographicException();
}

int decodedLength = total_length - TagSize;
Sequence decoded = _sequencePool.RentSequence();
Span<byte> dst = decoded.AllocGetSpan(decodedLength);

// Decrypt length.
length_unencrypted.CopyTo(dst);

// Decrypt payload.
Span<byte> plaintext = dst.Slice(LengthSize, packetLength);
ReadOnlySequence<byte> ciphertext = receiveBufferROSequence.Slice(LengthSize, packetLength);
if (ciphertext.IsSingleSegment)
{
PayloadCipher.ProcessBytes(ciphertext.FirstSpan, plaintext);
}
else
{
foreach (var memory in ciphertext)
{
PayloadCipher.ProcessBytes(memory.Span, plaintext);
plaintext = plaintext.Slice(memory.Length);
}
}

decoded.AppendAlloced(decodedLength);
packet = new Packet(decoded);

receiveBuffer.Remove(total_length);

_currentPacketLength = -1; // start decoding a new packet

return true;
}
}
60 changes: 60 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketEncDecBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers.Binary;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Macs;
using Org.BouncyCastle.Crypto.Parameters;

namespace Tmds.Ssh;

class ChaCha20Poly1305PacketEncDecBase
{
public const int TagSize = 16; // Poly1305 hash length.
protected const int PaddTo = 8; // We're not a block cipher. Padd to 8 octets per rfc4253.
protected const int LengthSize = 4; // SSH packet length field is 4 bytes.

protected readonly MyChaCha20 LengthCipher;
protected readonly MyChaCha20 PayloadCipher;
protected readonly Poly1305 Mac;
private readonly byte[] _iv;

protected ChaCha20Poly1305PacketEncDecBase(byte[] key)
{
_iv = new byte[12];
byte[] K_1 = key.AsSpan(32, 32).ToArray();
byte[] K_2 = key.AsSpan(0, 32).ToArray();
LengthCipher = new(K_1, _iv);
PayloadCipher = new(K_2, _iv);
Mac = new();
}

protected void ConfigureCiphers(uint sequenceNumber)
{
BinaryPrimitives.WriteUInt64BigEndian(_iv.AsSpan(4), sequenceNumber);
LengthCipher.SetIv(_iv);
PayloadCipher.SetIv(_iv);

// note: encrypting 64 bytes increments the ChaCha20 block counter.
Span<byte> polyKey = stackalloc byte[64];
PayloadCipher.ProcessBytes(input: polyKey, output: polyKey);
Mac.Init(new KeyParameter(polyKey[..32]));
}

// This class eliminates per packet ParametersWithIV/KeyParameter allocations.
sealed protected class MyChaCha20 : ChaCha7539Engine
{
public MyChaCha20(byte[] key, byte[] dummyIv)
{
Init(forEncryption: true, new ParametersWithIV(new KeyParameter(key), dummyIv));
}

public void SetIv(byte[] iv)
{
SetKey(null, iv);

Reset();
}
}
}
68 changes: 68 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketEncoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers;

namespace Tmds.Ssh;

// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD
sealed class ChaCha20Poly1305PacketEncoder : ChaCha20Poly1305PacketEncDecBase, IPacketEncoder
{
public ChaCha20Poly1305PacketEncoder(byte[] key) :
base(key)
{ }

public void Dispose()
{ }

public void Encode(uint sequenceNumber, Packet packet, Sequence output)
{
using var pkt = packet.Move(); // Dispose the packet.

ConfigureCiphers(sequenceNumber);

// Padding.
uint payload_length = (uint)pkt.PayloadLength;
// PT (Plain Text)
// byte padding_length; // 4 <= padding_length < 256
// byte[n1] payload; // n1 = packet_length-padding_length-1
// byte[n2] random_padding; // n2 = padding_length
byte padding_length = IPacketEncoder.DeterminePaddingLength(payload_length + 1, multipleOf: PaddTo);
pkt.WriteHeaderAndPadding(padding_length);

var unencrypted_packet = pkt.AsReadOnlySequence();
ReadOnlySpan<byte> packet_length = unencrypted_packet.FirstSpan.Slice(0, LengthSize); // packet_length
ReadOnlySequence<byte> pt = unencrypted_packet.Slice(LengthSize); // PT (Plain Text)

int textLength = (int)pt.Length;
int encodedLength = LengthSize + textLength + TagSize;
Span<byte> dst = output.AllocGetSpan(encodedLength);

// Encrypt length.
Span<byte> length_encrypted = dst.Slice(0, LengthSize);
LengthCipher.ProcessBytes(packet_length, length_encrypted);

// Encrypt payload.
Span<byte> ciphertext = dst.Slice(LengthSize, textLength);
if (pt.IsSingleSegment)
{
PayloadCipher.ProcessBytes(pt.FirstSpan, ciphertext);
}
else
{
foreach (var memory in pt)
{
PayloadCipher.ProcessBytes(memory.Span, ciphertext);
ciphertext = ciphertext.Slice(memory.Length);
}
}

// Mac.
Span<byte> tag = dst.Slice(LengthSize + textLength, TagSize);
Mac.BlockUpdate(dst.Slice(0, LengthSize + textLength));
Mac.DoFinal(tag);

output.AppendAlloced(encodedLength);
}
}
41 changes: 26 additions & 15 deletions src/Tmds.Ssh/ECDHKeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ public async Task<KeyExchangeOutput> TryExchangeAsync(SshConnection connection,
}

byte[] sessionId = input.ConnectionInfo.SessionId ?? exchangeHash;
byte[] initialIVC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength);
byte[] initialIVS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength);
byte[] encryptionKeyC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength);
byte[] encryptionKeyS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength);
byte[] integrityKeyC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength);
byte[] integrityKeyS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength);
byte[] initialIVC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength);
byte[] initialIVS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength);
byte[] encryptionKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength);
byte[] encryptionKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength);
byte[] integrityKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength);
byte[] integrityKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength);

return new KeyExchangeOutput(exchangeHash,
initialIVS2C, encryptionKeyS2C, integrityKeyS2C,
Expand Down Expand Up @@ -117,14 +117,13 @@ private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInf
return hash.GetHashAndReset();
}

private byte[] Hash(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int hashLength)
private byte[] CalculateKey(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int keyLength)
{
// https://tools.ietf.org/html/rfc4253#section-7.2

byte[] hashRv = new byte[hashLength];
int hashOffset = 0;
byte[] key = new byte[keyLength];
int keyOffset = 0;

// TODO: handle 'If the key length needed is longer than the output of the HASH'
// HASH(K || H || c || session_id)
using Sequence sequence = sequencePool.RentSequence();
var writer = new SequenceWriter(sequence);
Expand All @@ -139,16 +138,28 @@ private byte[] Hash(SequencePool sequencePool, BigInteger sharedSecret, byte[] e
hash.AppendData(segment.Span);
}
byte[] K1 = hash.GetHashAndReset();
Append(hashRv, K1, ref hashOffset);
Append(key, K1, ref keyOffset);

while (hashOffset != hashRv.Length)
while (keyOffset != key.Length)
{
// TODO: handle 'If the key length needed is longer than the output of the HASH'
sequence.Clear();

// K3 = HASH(K || H || K1 || K2)
throw new NotSupportedException();
writer = new SequenceWriter(sequence);
writer.WriteMPInt(sharedSecret);
writer.Write(exchangeHash);
writer.Write(key.AsSpan(0, keyOffset));

foreach (var segment in sequence.AsReadOnlySequence())
{
hash.AppendData(segment.Span);
}
byte[] Kn = hash.GetHashAndReset();

Append(key, Kn, ref keyOffset);
}

return hashRv;
return key;

static void Append(byte[] key, byte[] append, ref int offset)
{
Expand Down
8 changes: 8 additions & 0 deletions src/Tmds.Ssh/EncryptionAlgorithm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,13 @@ public static EncryptionAlgorithm Find(Name name)
=> new AesGcmPacketDecoder(sequencePool, key, iv, algorithm.TagLength),
isAuthenticated: true,
tagLength: 16) },
{ AlgorithmNames.ChaCha20Poly1305,
new EncryptionAlgorithm(keyLength: 512 / 8, ivLength: 0,
(EncryptionAlgorithm algorithm, byte[] key, byte[] iv, HMacAlgorithm? hmac, byte[] hmacKey)
=> new ChaCha20Poly1305PacketEncoder(key),
(EncryptionAlgorithm algorithm, SequencePool sequencePool, byte[] key, byte[] iv, HMacAlgorithm? hmac, byte[] hmacKey)
=> new ChaCha20Poly1305PacketDecoder(sequencePool, key),
isAuthenticated: true,
tagLength: ChaCha20Poly1305PacketEncoder.TagSize) },
};
}
12 changes: 10 additions & 2 deletions src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken
int sendWindow = Volatile.Read(ref _sendWindow);
if (sendWindow > 0)
{
// We need to check the cancellation token in case we send a huge amount of data
// and the peer can keep up (and the send window never becomes zero).
if (cancellationToken.IsCancellationRequested)
{
Cancel();

cancellationToken.ThrowIfCancellationRequested();
}

int toSend = Math.Min(sendWindow, memory.Length);
toSend = Math.Min(toSend, SendMaxPacket);
if (Interlocked.CompareExchange(ref _sendWindow, sendWindow - toSend, sendWindow) == sendWindow)
Expand All @@ -213,8 +222,7 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken
{
Cancel();

cancellationToken.ThrowIfCancellationRequested();
throw CreateCloseException();
throw;
}
}
}
Expand Down
Loading
Loading