Skip to content

Commit

Permalink
Support for RFC 7766
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinert committed Jun 5, 2023
1 parent ba3e4a4 commit 5bff4c5
Show file tree
Hide file tree
Showing 36 changed files with 1,210 additions and 1,129 deletions.
4 changes: 2 additions & 2 deletions ARSoft.Tools.Net/ARSoft.Tools.Net.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
<Description>This project contains a complete managed .Net DNS and DNSSEC client, a DNS server and SPF and SenderID validation.</Description>
<PackageProjectUrl>https://github.com/alexreinert/ARSoft.Tools.Net</PackageProjectUrl>
<PackageTags>dns dnssec spf</PackageTags>
<PackageLicenseUrl>https://github.com/alexreinert/ARSoft.Tools.Net/blob/master/LICENSE</PackageLicenseUrl>
<PackageLicenseExpression>Apache-2.0</PackageLicenseExpression>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<Copyright>Copyright 2010..2023 Alexander Reinert</Copyright>
<VersionPrefix>3.3.0</VersionPrefix>
<VersionPrefix>3.4.0</VersionPrefix>
</PropertyGroup>

<ItemGroup>
Expand Down
34 changes: 11 additions & 23 deletions ARSoft.Tools.Net/Dns/DnsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,11 @@ public DnsClient(IEnumerable<IPAddress> dnsServers, IClientTransport[] transport

DnsMessage message = new DnsMessage() { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };

if (options == null)
{
message.IsRecursionDesired = true;
message.IsEDnsEnabled = true;
}
else
{
message.IsRecursionDesired = options.IsRecursionDesired;
message.IsCheckingDisabled = options.IsCheckingDisabled;
message.EDnsOptions = options.EDnsOptions;
}
options ??= DnsQueryOptions.DefaultQueryOptions;

message.IsRecursionDesired = options.IsRecursionDesired;
message.IsCheckingDisabled = options.IsCheckingDisabled;
message.EDnsOptions = options.EDnsOptions;

message.Questions.Add(new DnsQuestion(name, recordType, recordClass));

Expand All @@ -120,19 +114,13 @@ public DnsClient(IEnumerable<IPAddress> dnsServers, IClientTransport[] transport
{
_ = name ?? throw new ArgumentNullException(nameof(name), "Name must be provided");

DnsMessage message = new DnsMessage() { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };
options ??= DnsQueryOptions.DefaultQueryOptions;

if (options == null)
{
message.IsRecursionDesired = true;
message.IsEDnsEnabled = true;
}
else
{
message.IsRecursionDesired = options.IsRecursionDesired;
message.IsCheckingDisabled = options.IsCheckingDisabled;
message.EDnsOptions = options.EDnsOptions;
}
var message = new DnsMessage { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };

message.IsRecursionDesired = options.IsRecursionDesired;
message.IsCheckingDisabled = options.IsCheckingDisabled;
message.EDnsOptions = options.EDnsOptions;

message.Questions.Add(new DnsQuestion(name, recordType, recordClass));

Expand Down
155 changes: 9 additions & 146 deletions ARSoft.Tools.Net/Dns/DnsClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,151 +86,13 @@ internal DnsClientBase(IEnumerable<IPAddress> servers, int queryTimeout, IClient
protected TMessage? SendMessage<TMessage>(TMessage query)
where TMessage : DnsMessageBase, new()
{
SelectTsigKey? tsigKeySelector;

var package = PrepareMessage(query, out tsigKeySelector, out var tsigOriginalMac);

TMessage? response = null;

foreach (var connection in GetConnections(package, query.IsReliableSendingRequested))
{
try
{
var receivedMessage = SendMessage<TMessage>(package, connection, tsigKeySelector, tsigOriginalMac);

if ((receivedMessage != null) && ValidateResponse(query, receivedMessage.Message))
{
if (receivedMessage.Message.ReturnCode == ReturnCode.ServerFailure)
{
response = receivedMessage.Message;
continue;
}

if (!receivedMessage.Message.IsReliableResendingRequested)
return receivedMessage.Message;

var resendTransport = _transports.FirstOrDefault(t => t.SupportsReliableTransfer && t.MaximumAllowedQuerySize <= package.Length && t != connection.Transport);

if (resendTransport != null)
{
using (IClientConnection? resendConnection = resendTransport.Connect(new DnsClientEndpointInfo(false, receivedMessage.ResponderAddress.Address, receivedMessage.LocalAddress.Address), QueryTimeout))
{
if (resendConnection == null)
{
response = receivedMessage.Message;
}
else
{
var resendResponse = SendMessage<TMessage>(package, resendConnection, tsigKeySelector, tsigOriginalMac);

if ((resendResponse != null)
&& ValidateResponse(query, resendResponse.Message)
&& ((resendResponse.Message.ReturnCode != ReturnCode.ServerFailure)))
{
return resendResponse.Message;
}
else
{
resendConnection.MarkFaulty();
response = receivedMessage.Message;
}
}
}
}
}
else
{
connection.MarkFaulty();
}
}
catch (Exception e)
{
Trace.TraceError("Error on dns query: " + e);
connection.MarkFaulty();
}
finally
{
connection.Dispose();
}
}

return response;
}

private IEnumerable<IClientConnection> GetConnections(DnsRawPackage package, bool isReliableTransportRequested)
{
foreach (var transport in _transports)
{
if (transport.SupportsPooledConnections
&& package.Length <= transport.MaximumAllowedQuerySize
&& (!isReliableTransportRequested || transport.SupportsReliableTransfer))
{
foreach (var endpointInfo in _endpointInfos)
{
var connection = transport.GetPooledConnection(endpointInfo);
if (connection != null)
yield return connection;
}
}
}

foreach (var transport in _transports)
{
if (package.Length <= transport.MaximumAllowedQuerySize
&& (!isReliableTransportRequested || transport.SupportsReliableTransfer))
{
foreach (var endpointInfo in _endpointInfos)
{
var connection = transport.Connect(endpointInfo, QueryTimeout);
if (connection != null)
yield return connection;
}
}
}
}

private ReceivedMessage<TMessage>? SendMessage<TMessage>(DnsRawPackage package, IClientConnection connection, SelectTsigKey? tsigKeySelector, byte[]? tsigOriginalMac)
where TMessage : DnsMessageBase, new()
{
if (!connection.Send(package))
return null;

var resultData = connection.Receive();

if (resultData == null)
return null;

var response = DnsMessageBase.Parse<TMessage>(resultData.ToArraySegment(false), tsigKeySelector, tsigOriginalMac);

var isNextMessageWaiting = response.IsNextMessageWaiting(false);

while (isNextMessageWaiting)
{
resultData = connection.Receive();

if (resultData == null)
return null;

var nextResult = DnsMessageBase.Parse<TMessage>(resultData.ToArraySegment(false), tsigKeySelector, tsigOriginalMac);

if (nextResult.ReturnCode == ReturnCode.ServerFailure)
return null;

response.AddSubsequentResponse(nextResult);
isNextMessageWaiting = nextResult.IsNextMessageWaiting(true);
}

return new ReceivedMessage<TMessage>(resultData.RemoteEndpoint, resultData.LocalEndpoint, response);
return SendMessageAsync<TMessage>(query, CancellationToken.None).GetAwaiter().GetResult();
}

protected List<TMessage> SendMessageParallel<TMessage>(TMessage message)
where TMessage : DnsMessageBase, new()
{
var result = SendMessageParallelAsync(message, default);

result.Wait();

return result.Result;
return SendMessageParallelAsync(message, default).GetAwaiter().GetResult();
}

private bool ValidateResponse<TMessage>(TMessage message, TMessage response)
Expand Down Expand Up @@ -293,6 +155,8 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK

if ((receivedMessage != null) && ValidateResponse(query, receivedMessage.Message))
{
connection.RestartIdleTimeout(receivedMessage.Message.GetEDnsKeepAliveTimeout());

if (receivedMessage.Message.ReturnCode == ReturnCode.ServerFailure)
{
response = receivedMessage.Message;
Expand Down Expand Up @@ -320,6 +184,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
&& ValidateResponse(query, resendResponse.Message)
&& ((resendResponse.Message.ReturnCode != ReturnCode.ServerFailure)))
{
resendConnection.RestartIdleTimeout(receivedMessage.Message.GetEDnsKeepAliveTimeout());
return resendResponse.Message;
}
else
Expand Down Expand Up @@ -360,9 +225,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
{
foreach (var endpointInfo in _endpointInfos)
{
var connection = transport.GetPooledConnection(endpointInfo);
if (connection != null)
yield return Task.FromResult<IClientConnection?>(connection);
yield return transport.GetPooledConnectionAsync(endpointInfo, token);
}
}
}
Expand All @@ -386,7 +249,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
if (!await connection.SendAsync(package, token))
return null;

var resultData = await connection.ReceiveAsync(token);
var resultData = await connection.ReceiveAsync(package.MessageIdentification, token);

if (resultData == null)
return null;
Expand All @@ -397,7 +260,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK

while (isNextMessageWaiting)
{
resultData = await connection.ReceiveAsync(token);
resultData = await connection.ReceiveAsync(package.MessageIdentification, token);

if (resultData == null)
return null;
Expand Down Expand Up @@ -458,7 +321,7 @@ private async Task SendMessageParallelAsync<TMessage>(IClientTransport transport
if (token.IsCancellationRequested)
break;

var response = await connection.ReceiveAsync(token);
var response = await connection.ReceiveAsync(package.MessageIdentification, token);

if (response == null)
continue;
Expand Down
14 changes: 4 additions & 10 deletions ARSoft.Tools.Net/Dns/DnsMessageBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,7 @@ protected void ParseQuestionSection(IList<byte> data, ref int currentPosition, i

for (var i = 0; i < recordCount; i++)
{
DnsQuestion question = new DnsQuestion(
ParseDomainName(data, ref currentPosition),
(RecordType)ParseUShort(data, ref currentPosition),
(RecordClass)ParseUShort(data, ref currentPosition));

questions.Add(question);
questions.Add(DnsQuestion.Parse(data, ref currentPosition));
}

SetQuestionSection(questions);
Expand Down Expand Up @@ -971,7 +966,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
msg.IsQuery = !ReadBoolFlag(prop.Value);
break;
case "Opcode":
msg.OperationCodeInternal = (OperationCode)prop.Value.GetUInt16();
msg.OperationCodeInternal = (OperationCode) prop.Value.GetUInt16();
break;
case "AA":
msg.AAFlagInternal = ReadBoolFlag(prop.Value);
Expand All @@ -992,7 +987,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
msg.CDFlagInternal = ReadBoolFlag(prop.Value);
break;
case "RCODE":
msg.ReturnCode = (ReturnCode)prop.Value.GetUInt16();
msg.ReturnCode = (ReturnCode) prop.Value.GetUInt16();
break;
case "QNAME":
qname = DomainName.Parse(prop.Value.GetString() ?? String.Empty);
Expand All @@ -1004,7 +999,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
qtype = RecordTypeHelper.ParseShortString(prop.Value.GetString() ?? String.Empty);
break;
case "QCLASS":
qclass = (RecordClass)prop.Value.GetUInt16();
qclass = (RecordClass) prop.Value.GetUInt16();
break;
case "QCLASSname":
qclass = RecordClassHelper.ParseShortString(prop.Value.GetString() ?? String.Empty);
Expand Down Expand Up @@ -1082,6 +1077,5 @@ private static bool ReadBoolFlag(JsonElement json)
default: throw new JsonException("Not a valid boolean flag: '" + json.GetRawText() + "'");
}
}

}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright and License
#region Copyright and License
// Copyright 2010..2023 Alexander Reinert
//
// This file is part of the ARSoft.Tools.Net - C# DNS client/server and SPF Library (https://github.com/alexreinert/ARSoft.Tools.Net)
Expand All @@ -16,16 +16,12 @@
// limitations under the License.
#endregion

namespace ARSoft.Tools.Net.Dns
namespace ARSoft.Tools.Net.Dns;

internal static class DnsMessageBaseExtensions
{
/// <summary>
/// Interface of a pooled connection initiated by a client
/// </summary>
public interface IPoolableClientConnection : IClientConnection
public static TimeSpan? GetEDnsKeepAliveTimeout(this DnsMessageBase message)
{
/// <summary>
/// Returns a value indicating if the connection is still alive
/// </summary>
bool IsAlive { get; }
return message.EDnsOptions?.Options?.OfType<TcpKeepAliveOption>()?.FirstOrDefault()?.Timeout;
}
}
24 changes: 24 additions & 0 deletions ARSoft.Tools.Net/Dns/DnsQueryOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,29 @@ public bool IsDnsSecOk
}
}
}

internal static DnsQueryOptions DefaultQueryOptions { get; } = new()
{
IsRecursionDesired = true,
EDnsOptions = new(
1232,
new DnssecAlgorithmUnderstoodOption(EnumHelper<DnsSecAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray()),
new DsHashUnderstoodOption(EnumHelper<DnsSecDigestType>.Names.Keys.Where(d => d.IsSupported()).ToArray()),
new Nsec3HashUnderstoodOption(EnumHelper<NSec3HashAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray())
)
};

internal static DnsQueryOptions DefaultDnsSecQueryOptions { get; } = new()
{
IsRecursionDesired = true,
IsCheckingDisabled = true,
EDnsOptions = new(
1232,
new DnssecAlgorithmUnderstoodOption(EnumHelper<DnsSecAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray()),
new DsHashUnderstoodOption(EnumHelper<DnsSecDigestType>.Names.Keys.Where(d => d.IsSupported()).ToArray()),
new Nsec3HashUnderstoodOption(EnumHelper<NSec3HashAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray())
),
IsDnsSecOk = true
};
}
}
Loading

0 comments on commit 5bff4c5

Please sign in to comment.