From 579d5fc7458111f5b63b80f1ef6836e79c9596d8 Mon Sep 17 00:00:00 2001 From: Christoph Kempfer <christoph.kempfer@adesso.de> Date: Thu, 25 Jan 2024 09:26:17 +0100 Subject: [PATCH] Added tests for dns resolver --- .../com/hivemq/security/ssl/DnsResolver.java | 32 ++++++++++++ .../security/ssl/SslContextFactory.java | 29 ----------- .../hivemq/security/ssl/DnsResolverTest.java | 51 +++++++++++++++++++ 3 files changed, 83 insertions(+), 29 deletions(-) create mode 100644 src/main/java/com/hivemq/security/ssl/DnsResolver.java create mode 100644 src/test/java/com/hivemq/security/ssl/DnsResolverTest.java diff --git a/src/main/java/com/hivemq/security/ssl/DnsResolver.java b/src/main/java/com/hivemq/security/ssl/DnsResolver.java new file mode 100644 index 000000000..55210ef42 --- /dev/null +++ b/src/main/java/com/hivemq/security/ssl/DnsResolver.java @@ -0,0 +1,32 @@ +package com.hivemq.security.ssl; + +import java.util.Map; + +public class DnsResolver { + + private final Map<String, String> dnsMap; + + DnsResolver(final Map<String, String> dnsMap) { + this.dnsMap = dnsMap; + } + + String resolve(final String domain) { + String alias = dnsMap.get(domain); + if (alias != null) { + return alias; + } + + int index = domain.indexOf('.'); + while (index >= 0) { + final String wildcardDomain = "*" + domain.substring(index); + alias = dnsMap.get(wildcardDomain); + if (alias != null) { + return alias; + } + index = domain.indexOf('.', index + 1); + } + + return null; + } + +} diff --git a/src/main/java/com/hivemq/security/ssl/SslContextFactory.java b/src/main/java/com/hivemq/security/ssl/SslContextFactory.java index 295f24687..a09d04885 100644 --- a/src/main/java/com/hivemq/security/ssl/SslContextFactory.java +++ b/src/main/java/com/hivemq/security/ssl/SslContextFactory.java @@ -242,33 +242,4 @@ private static Set<String> getDnsHostnamesFromCertificate(final X509Certificate throw new SslException("Invalid auth mode: " + clientAuthMode); } - private static class DnsResolver { - - private final Map<String, String> dnsMap; - - DnsResolver(final Map<String, String> dnsMap) { - this.dnsMap = dnsMap; - } - - String resolve(final String domain) { - String alias = dnsMap.get(domain); - if (alias != null) { - return alias; - } - - int index = domain.indexOf('.'); - while (index >= 0) { - final String wildcardDomain = "*" + domain.substring(index); - alias = dnsMap.get(wildcardDomain); - if (alias != null) { - return alias; - } - index = domain.indexOf('.', index + 1); - } - - return null; - } - - } - } diff --git a/src/test/java/com/hivemq/security/ssl/DnsResolverTest.java b/src/test/java/com/hivemq/security/ssl/DnsResolverTest.java new file mode 100644 index 000000000..3b3db28da --- /dev/null +++ b/src/test/java/com/hivemq/security/ssl/DnsResolverTest.java @@ -0,0 +1,51 @@ +package com.hivemq.security.ssl; + +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.*; + +public class DnsResolverTest { + public static final String ALIAS_1 = "alias1"; + public static final String TEST_EXAMPLE_COM = "test.example.com"; + + @Test + public void test_resolve_simple_dns_name() { + final DnsResolver dnsResolver = new DnsResolver(Map.of(TEST_EXAMPLE_COM, ALIAS_1)); + + final String resolve = dnsResolver.resolve(TEST_EXAMPLE_COM); + + assertNotNull(resolve); + assertEquals(ALIAS_1, resolve); + } + + @Test + public void test_resolve_non_matching_dns_name() { + final DnsResolver dnsResolver = new DnsResolver(Map.of(TEST_EXAMPLE_COM, ALIAS_1)); + + final String resolve = dnsResolver.resolve("other.example.com"); + + assertNull(resolve); + } + + @Test + public void test_resolve_wildcard_dns_name() { + final DnsResolver dnsResolver = new DnsResolver(Map.of("*.example.com", ALIAS_1)); + + final String resolve = dnsResolver.resolve(TEST_EXAMPLE_COM); + + assertNotNull(resolve); + assertEquals(ALIAS_1, resolve); + } + + @Test + public void test_resolve_nested_wildcard_dns_name() { + final DnsResolver dnsResolver = new DnsResolver(Map.of("*.example.com", ALIAS_1)); + + final String resolve = dnsResolver.resolve("sub.test.example.com"); + + assertNotNull(resolve); + assertEquals(ALIAS_1, resolve); + } +}