Skip to content

Commit

Permalink
[Inference API] Use extractOptionalPositiveInteger instead of removeA…
Browse files Browse the repository at this point in the history
…sType in AzureAiStudioEmbeddingsServiceSettings (elastic#110366)
  • Loading branch information
timgrein authored Jul 4, 2024
1 parent c629947 commit f3c811c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,13 @@ public static int randomNonNegativeInt() {
return randomInt() & Integer.MAX_VALUE;
}

/**
* @return an <code>int</code> between <code>Integer.MIN_VALUE</code> and <code>-1</code> (inclusive) chosen uniformly at random.
*/
public static int randomNegativeInt() {
return randomInt() | Integer.MIN_VALUE;
}

public static float randomFloat() {
return random().nextFloat();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
Expand Down Expand Up @@ -185,6 +186,10 @@ public void testRandomNonNegativeInt() {
assertThat(randomNonNegativeInt(), greaterThanOrEqualTo(0));
}

public void testRandomNegativeInt() {
assertThat(randomNegativeInt(), lessThan(0));
}

public void testRandomValueOtherThan() {
// "normal" way of calling where the value is not null
int bad = randomInt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;

public class AzureAiStudioEmbeddingsServiceSettings extends AzureAiStudioServiceSettings {

Expand All @@ -59,10 +59,15 @@ private static AzureAiStudioEmbeddingCommonFields embeddingSettingsFromMap(
ConfigurationParseContext context
) {
var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);

SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer maxTokens = extractOptionalPositiveInteger(
map,
MAX_INPUT_TOKENS,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);
Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);

switch (context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,92 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() {
);
}

public void testFromMap_ThrowsException_WhenDimensionsAreZero() {
var target = "http://sometarget.local";
var provider = "openai";
var endpointType = "token";
var dimensions = 0;

var settingsMap = createRequestSettingsMap(target, provider, endpointType, dimensions, true, null, SimilarityMeasure.COSINE);

var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
);

assertThat(
thrownException.getMessage(),
containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;")
);
}

public void testFromMap_ThrowsException_WhenDimensionsAreNegative() {
var target = "http://sometarget.local";
var provider = "openai";
var endpointType = "token";
var dimensions = randomNegativeInt();

var settingsMap = createRequestSettingsMap(target, provider, endpointType, dimensions, true, null, SimilarityMeasure.COSINE);

var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
);

assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;",
dimensions
)
)
);
}

public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() {
var target = "http://sometarget.local";
var provider = "openai";
var endpointType = "token";
var maxInputTokens = 0;

var settingsMap = createRequestSettingsMap(target, provider, endpointType, null, true, maxInputTokens, SimilarityMeasure.COSINE);

var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
);

assertThat(
thrownException.getMessage(),
containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;")
);
}

public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() {
var target = "http://sometarget.local";
var provider = "openai";
var endpointType = "token";
var maxInputTokens = randomNegativeInt();

var settingsMap = createRequestSettingsMap(target, provider, endpointType, null, true, maxInputTokens, SimilarityMeasure.COSINE);

var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
);

assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;",
maxInputTokens
)
)
);
}

public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() {
var target = "http://sometarget.local";
var provider = "openai";
Expand Down

0 comments on commit f3c811c

Please sign in to comment.