diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequest.java b/src/main/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequest.java index 82e513f0..83aa3acc 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequest.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequest.java @@ -9,6 +9,7 @@ package org.opensearch.geospatial.ip2geo.action; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.net.URL; @@ -19,8 +20,10 @@ import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; @@ -38,6 +41,7 @@ public class PutDatasourceRequest extends AcknowledgedRequest { private static final ParseField ENDPOINT_FIELD = new ParseField("endpoint"); private static final ParseField UPDATE_INTERVAL_IN_DAYS_FIELD = new ParseField("update_interval_in_days"); + private static final int MAX_DATASOURCE_NAME_BYTES = 255; /** * @param datasourceName the datasource name * @return the datasource name @@ -95,11 +99,50 @@ public void writeTo(final StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException errors = new ActionRequestValidationException(); + validateDatasourceName(errors); validateEndpoint(errors); validateUpdateInterval(errors); return errors.validationErrors().isEmpty() ? null : errors; } + private void validateDatasourceName(final ActionRequestValidationException errors) { + if (!Strings.validFileName(datasourceName)) { + errors.addValidationError("Datasource name must not contain the following characters " + Strings.INVALID_FILENAME_CHARS); + return; + } + if (datasourceName.isEmpty()) { + errors.addValidationError("Datasource name must not be empty"); + return; + } + if (datasourceName.contains("#")) { + errors.addValidationError("Datasource name must not contain '#'"); + return; + } + if (datasourceName.contains(":")) { + errors.addValidationError("Datasource name must not contain ':'"); + return; + } + if (datasourceName.charAt(0) == '_' || datasourceName.charAt(0) == '-' || datasourceName.charAt(0) == '+') { + errors.addValidationError("Datasource name must not start with '_', '-', or '+'"); + return; + } + int byteCount = 0; + try { + byteCount = datasourceName.getBytes("UTF-8").length; + } catch (UnsupportedEncodingException e) { + // UTF-8 should always be supported, but rethrow this if it is not for some reason + throw new OpenSearchException("Unable to determine length of datasource name", e); + } + if (byteCount > MAX_DATASOURCE_NAME_BYTES) { + errors.addValidationError("Datasource name is too long, (" + byteCount + " > " + MAX_DATASOURCE_NAME_BYTES + ")"); + return; + } + if (datasourceName.equals(".") || datasourceName.equals("..")) { + errors.addValidationError("Datasource name must not be '.' or '..'"); + return; + } + } + /** * Conduct following validation on endpoint * 1. endpoint format complies with RFC-2396 diff --git a/src/test/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequestTests.java b/src/test/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequestTests.java index 383d832a..634d9581 100644 --- a/src/test/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequestTests.java +++ b/src/test/java/org/opensearch/geospatial/ip2geo/action/PutDatasourceRequestTests.java @@ -8,10 +8,13 @@ package org.opensearch.geospatial.ip2geo.action; +import java.util.Arrays; import java.util.Locale; +import java.util.Map; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Randomness; +import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamInput; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.unit.TimeValue; @@ -97,6 +100,60 @@ public void testValidateWithInvalidUrlInsideManifest() throws Exception { assertTrue(exception.validationErrors().get(0).contains("Invalid URL format")); } + public void testValidateDatasourceNames() throws Exception { + String validDatasourceName = GeospatialTestHelper.randomLowerCaseString(); + String domain = GeospatialTestHelper.randomLowerCaseString(); + PutDatasourceRequest request = new PutDatasourceRequest(validDatasourceName); + request.setEndpoint(sampleManifestUrl()); + request.setUpdateInterval(TimeValue.timeValueDays(Randomness.get().nextInt(30) + 1)); + + // Run + ActionRequestValidationException exception = request.validate(); + + // Verify + assertNull(exception); + + String fileNameChar = validDatasourceName + Strings.INVALID_FILENAME_CHARS.stream() + .skip(Randomness.get().nextInt(Strings.INVALID_FILENAME_CHARS.size() - 1)) + .findFirst(); + String startsWith = Arrays.asList("_", "-", "+").get(Randomness.get().nextInt(3)) + validDatasourceName; + String empty = ""; + String hash = validDatasourceName + "#"; + String colon = validDatasourceName + ":"; + StringBuilder longName = new StringBuilder(); + while (longName.length() < 256) { + longName.append(GeospatialTestHelper.randomLowerCaseString()); + } + String point = Arrays.asList(".", "..").get(Randomness.get().nextInt(2)); + Map nameToError = Map.of( + fileNameChar, + "not contain the following characters", + empty, + "must not be empty", + hash, + "must not contain '#'", + colon, + "must not contain ':'", + startsWith, + "must not start with", + longName.toString(), + "name is too long", + point, + "must not be '.' or '..'" + ); + + for (Map.Entry entry : nameToError.entrySet()) { + request.setName(entry.getKey()); + + // Run + exception = request.validate(); + + // Verify + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains(entry.getValue())); + } + } + public void testStreamInOut() throws Exception { String datasourceName = GeospatialTestHelper.randomLowerCaseString(); String domain = GeospatialTestHelper.randomLowerCaseString();