diff --git a/pom.xml b/pom.xml old mode 100644 new mode 100755 index e165812..a4a3332 --- a/pom.xml +++ b/pom.xml @@ -26,9 +26,9 @@ redis-field-engineering redis-sql-trino - 3.3.1 - 6.2.2.RELEASE - 1.6.3 + 3.7.3 + 6.2.6.RELEASE + 1.6.4 5.1.0 true diff --git a/src/main/java/com/redis/trino/RediSearchConfig.java b/src/main/java/com/redis/trino/RediSearchConfig.java old mode 100644 new mode 100755 index 2122532..9314b6c --- a/src/main/java/com/redis/trino/RediSearchConfig.java +++ b/src/main/java/com/redis/trino/RediSearchConfig.java @@ -24,7 +24,6 @@ package com.redis.trino; import java.time.Duration; -import java.util.Optional; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; @@ -36,216 +35,228 @@ public class RediSearchConfig { - public static final String DEFAULT_SCHEMA = "default"; - public static final long DEFAULT_LIMIT = 10000; - public static final long DEFAULT_CURSOR_COUNT = 1000; - public static final Duration DEFAULT_TABLE_CACHE_EXPIRATION = Duration.ofHours(1); - public static final Duration DEFAULT_TABLE_CACHE_REFRESH = Duration.ofMinutes(1); - - private String defaultSchema = DEFAULT_SCHEMA; - private Optional uri = Optional.empty(); - private Optional username = Optional.empty(); - private Optional password = Optional.empty(); - private boolean insecure; - private boolean caseInsensitiveNames; - private long defaultLimit = DEFAULT_LIMIT; - private long cursorCount = DEFAULT_CURSOR_COUNT; - private long tableCacheExpiration = DEFAULT_TABLE_CACHE_EXPIRATION.toSeconds(); - private long tableCacheRefresh = DEFAULT_TABLE_CACHE_REFRESH.toSeconds(); - private boolean cluster; - private String caCertPath; - private String keyPath; - private String certPath; - private String keyPassword; - private boolean resp2; - - @Min(0) - public long getCursorCount() { - return cursorCount; - } - - @Config("redisearch.cursor-count") - public RediSearchConfig setCursorCount(long cursorCount) { - this.cursorCount = cursorCount; - return this; - } - - public long getDefaultLimit() { - return defaultLimit; - } - - @Config("redisearch.default-limit") - @ConfigDescription("Default search limit number to use") - public RediSearchConfig setDefaultLimit(long defaultLimit) { - this.defaultLimit = defaultLimit; - return this; - } - - public boolean isCaseInsensitiveNames() { - return caseInsensitiveNames; - } - - @Config("redisearch.case-insensitive-names") - @ConfigDescription("Case-insensitive name-matching") - public RediSearchConfig setCaseInsensitiveNames(boolean caseInsensitiveNames) { - this.caseInsensitiveNames = caseInsensitiveNames; - return this; - } - - public boolean isResp2() { - return resp2; - } - - @Config("redisearch.resp2") - @ConfigDescription("Force Redis protocol version to RESP2") - public RediSearchConfig setResp2(boolean resp2) { - this.resp2 = resp2; - return this; - } - - @Config("redisearch.table-cache-expiration") - @ConfigDescription("Duration in seconds since the entry creation after which a table should be automatically removed from the cache.") - public RediSearchConfig setTableCacheExpiration(long expirationDuration) { - this.tableCacheExpiration = expirationDuration; - return this; - } - - public long getTableCacheExpiration() { - return tableCacheExpiration; - } - - @Config("redisearch.table-cache-refresh") - @ConfigDescription("Duration in seconds since the entry creation after which to automatically refresh the table cache.") - public RediSearchConfig setTableCacheRefresh(long refreshDuration) { - this.tableCacheRefresh = refreshDuration; - return this; - } - - public long getTableCacheRefresh() { - return tableCacheRefresh; - } - - @NotNull - public String getDefaultSchema() { - return defaultSchema; - } - - @Config("redisearch.default-schema-name") - @ConfigDescription("Default schema name to use") - public RediSearchConfig setDefaultSchema(String defaultSchema) { - this.defaultSchema = defaultSchema; - return this; - } - - @NotNull - public Optional<@Pattern(message = "Invalid Redis URI. Expected redis:// rediss://", regexp = "^rediss?://.*") String> getUri() { - return uri; - } - - @Config("redisearch.uri") - @ConfigDescription("Redis connection URI e.g. 'redis://localhost:6379'") - @ConfigSecuritySensitive - public RediSearchConfig setUri(String uri) { - this.uri = Optional.ofNullable(uri); - return this; - } - - public Optional getUsername() { - return username; - } - - @Config("redisearch.username") - @ConfigDescription("Redis connection username") - @ConfigSecuritySensitive - public RediSearchConfig setUsername(String username) { - this.username = Optional.ofNullable(username); - return this; - } - - public Optional getPassword() { - return password; - } - - @Config("redisearch.password") - @ConfigDescription("Redis connection password") - @ConfigSecuritySensitive - public RediSearchConfig setPassword(String password) { - this.password = Optional.ofNullable(password); - return this; - } - - public boolean isCluster() { - return cluster; - } - - @Config("redisearch.cluster") - @ConfigDescription("Connect to a Redis Cluster") - public RediSearchConfig setCluster(boolean cluster) { - this.cluster = cluster; - return this; - } - - public boolean isInsecure() { - return insecure; - } - - @Config("redisearch.insecure") - @ConfigDescription("Allow insecure connections (e.g. invalid certificates) to Redis when using SSL") - public RediSearchConfig setInsecure(boolean insecure) { - this.insecure = insecure; - return this; - } - - public Optional getCaCertPath() { - return optionalPath(caCertPath); - } - - private Optional optionalPath(String path) { - if (path == null || path.isEmpty()) { - return Optional.empty(); - } - return Optional.of(path); - } - - @Config("redisearch.cacert-path") - @ConfigDescription("X.509 CA certificate file to verify with") - public RediSearchConfig setCaCertPath(String caCertPath) { - this.caCertPath = caCertPath; - return this; - } - - public Optional getKeyPath() { - return optionalPath(keyPath); - } - - @Config("redisearch.key-path") - @ConfigDescription("PKCS#8 private key file to authenticate with (PEM format)") - public RediSearchConfig setKeyPath(String keyPath) { - this.keyPath = keyPath; - return this; - } - - public Optional getKeyPassword() { - return Optional.ofNullable(keyPassword); - } - - @Config("redisearch.key-password") - @ConfigSecuritySensitive - @ConfigDescription("Password of the private key file, or null if it's not password-protected") - public RediSearchConfig setKeyPassword(String keyPassword) { - this.keyPassword = keyPassword; - return this; - } - - public Optional getCertPath() { - return optionalPath(certPath); - } - - @Config("redisearch.cert-path") - @ConfigDescription("X.509 certificate chain file to authenticate with (PEM format)") - public RediSearchConfig setCertPath(String certPath) { - this.certPath = certPath; - return this; - } + public static final String DEFAULT_SCHEMA = "default"; + + public static final long DEFAULT_LIMIT = 10000; + + public static final long DEFAULT_CURSOR_COUNT = 1000; + + public static final Duration DEFAULT_TABLE_CACHE_EXPIRATION = Duration.ofHours(1); + + public static final Duration DEFAULT_TABLE_CACHE_REFRESH = Duration.ofMinutes(1); + + private String defaultSchema = DEFAULT_SCHEMA; + + private String uri; + + private String username; + + private String password; + + private boolean insecure; + + private boolean cluster; + + private String caCertPath; + + private String keyPath; + + private String certPath; + + private String keyPassword; + + private boolean resp2; + + private boolean caseInsensitiveNames; + + private long defaultLimit = DEFAULT_LIMIT; + + private long cursorCount = DEFAULT_CURSOR_COUNT; + + private long tableCacheExpiration = DEFAULT_TABLE_CACHE_EXPIRATION.toSeconds(); + + private long tableCacheRefresh = DEFAULT_TABLE_CACHE_REFRESH.toSeconds(); + + @Min(0) + public long getCursorCount() { + return cursorCount; + } + + @Config("redisearch.cursor-count") + public RediSearchConfig setCursorCount(long cursorCount) { + this.cursorCount = cursorCount; + return this; + } + + public long getDefaultLimit() { + return defaultLimit; + } + + @Config("redisearch.default-limit") + @ConfigDescription("Default search limit number to use") + public RediSearchConfig setDefaultLimit(long defaultLimit) { + this.defaultLimit = defaultLimit; + return this; + } + + public boolean isCaseInsensitiveNames() { + return caseInsensitiveNames; + } + + @Config("redisearch.case-insensitive-names") + @ConfigDescription("Case-insensitive name-matching") + public RediSearchConfig setCaseInsensitiveNames(boolean caseInsensitiveNames) { + this.caseInsensitiveNames = caseInsensitiveNames; + return this; + } + + public boolean isResp2() { + return resp2; + } + + @Config("redisearch.resp2") + @ConfigDescription("Force Redis protocol version to RESP2") + public RediSearchConfig setResp2(boolean resp2) { + this.resp2 = resp2; + return this; + } + + @Config("redisearch.table-cache-expiration") + @ConfigDescription("Duration in seconds since the entry creation after which a table should be automatically removed from the cache.") + public RediSearchConfig setTableCacheExpiration(long expirationDuration) { + this.tableCacheExpiration = expirationDuration; + return this; + } + + public long getTableCacheExpiration() { + return tableCacheExpiration; + } + + @Config("redisearch.table-cache-refresh") + @ConfigDescription("Duration in seconds since the entry creation after which to automatically refresh the table cache.") + public RediSearchConfig setTableCacheRefresh(long refreshDuration) { + this.tableCacheRefresh = refreshDuration; + return this; + } + + public long getTableCacheRefresh() { + return tableCacheRefresh; + } + + @NotNull + public String getDefaultSchema() { + return defaultSchema; + } + + @Config("redisearch.default-schema-name") + @ConfigDescription("Default schema name to use") + public RediSearchConfig setDefaultSchema(String defaultSchema) { + this.defaultSchema = defaultSchema; + return this; + } + + @NotNull + public @Pattern(message = "Invalid Redis URI. Expected redis:// rediss://", regexp = "^rediss?://.*") String getUri() { + return uri; + } + + @Config("redisearch.uri") + @ConfigDescription("Redis connection URI e.g. 'redis://localhost:6379'") + @ConfigSecuritySensitive + public RediSearchConfig setUri(String uri) { + this.uri = uri; + return this; + } + + public String getUsername() { + return username; + } + + @Config("redisearch.username") + @ConfigDescription("Redis connection username") + @ConfigSecuritySensitive + public RediSearchConfig setUsername(String username) { + this.username = username; + return this; + } + + public String getPassword() { + return password; + } + + @Config("redisearch.password") + @ConfigDescription("Redis connection password") + @ConfigSecuritySensitive + public RediSearchConfig setPassword(String password) { + this.password = password; + return this; + } + + public boolean isCluster() { + return cluster; + } + + @Config("redisearch.cluster") + @ConfigDescription("Connect to a Redis Cluster") + public RediSearchConfig setCluster(boolean cluster) { + this.cluster = cluster; + return this; + } + + public boolean isInsecure() { + return insecure; + } + + @Config("redisearch.insecure") + @ConfigDescription("Allow insecure connections (e.g. invalid certificates) to Redis when using SSL") + public RediSearchConfig setInsecure(boolean insecure) { + this.insecure = insecure; + return this; + } + + public String getCaCertPath() { + return caCertPath; + } + + @Config("redisearch.cacert-path") + @ConfigDescription("X.509 CA certificate file to verify with") + public RediSearchConfig setCaCertPath(String caCertPath) { + this.caCertPath = caCertPath; + return this; + } + + public String getKeyPath() { + return keyPath; + } + + @Config("redisearch.key-path") + @ConfigDescription("PKCS#8 private key file to authenticate with (PEM format)") + public RediSearchConfig setKeyPath(String keyPath) { + this.keyPath = keyPath; + return this; + } + + public String getKeyPassword() { + return keyPassword; + } + + @Config("redisearch.key-password") + @ConfigSecuritySensitive + @ConfigDescription("Password of the private key file, or null if it's not password-protected") + public RediSearchConfig setKeyPassword(String keyPassword) { + this.keyPassword = keyPassword; + return this; + } + + public String getCertPath() { + return certPath; + } + + @Config("redisearch.cert-path") + @ConfigDescription("X.509 certificate chain file to authenticate with (PEM format)") + public RediSearchConfig setCertPath(String certPath) { + this.certPath = certPath; + return this; + } } diff --git a/src/main/java/com/redis/trino/RediSearchSession.java b/src/main/java/com/redis/trino/RediSearchSession.java old mode 100644 new mode 100755 index 1103d53..37b008b --- a/src/main/java/com/redis/trino/RediSearchSession.java +++ b/src/main/java/com/redis/trino/RediSearchSession.java @@ -49,7 +49,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.UncheckedExecutionException; +import com.redis.lettucemod.RedisModulesClient; import com.redis.lettucemod.api.StatefulRedisModulesConnection; +import com.redis.lettucemod.cluster.RedisModulesClusterClient; import com.redis.lettucemod.search.AggregateOperation; import com.redis.lettucemod.search.AggregateOptions; import com.redis.lettucemod.search.AggregateWithCursorResults; @@ -60,16 +62,17 @@ import com.redis.lettucemod.search.Group; import com.redis.lettucemod.search.IndexInfo; import com.redis.lettucemod.search.SearchResults; -import com.redis.lettucemod.util.ClientBuilder; import com.redis.lettucemod.util.RedisModulesUtils; -import com.redis.lettucemod.util.RedisURIBuilder; import com.redis.trino.RediSearchTranslator.Aggregation; import com.redis.trino.RediSearchTranslator.Search; import io.airlift.log.Logger; import io.lettuce.core.AbstractRedisClient; +import io.lettuce.core.ClientOptions; import io.lettuce.core.RedisURI; -import io.lettuce.core.SslVerifyMode; +import io.lettuce.core.SslOptions; +import io.lettuce.core.SslOptions.Builder; +import io.lettuce.core.cluster.ClusterClientOptions; import io.lettuce.core.protocol.ProtocolVersion; import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.spi.HostAddress; @@ -98,334 +101,365 @@ public class RediSearchSession { - private static final Logger log = Logger.get(RediSearchSession.class); - - private final TypeManager typeManager; - private final RediSearchConfig config; - private final RediSearchTranslator translator; - private final AbstractRedisClient client; - private final StatefulRedisModulesConnection connection; - private final Cache tableCache; - - public RediSearchSession(TypeManager typeManager, RediSearchConfig config) { - this.typeManager = requireNonNull(typeManager, "typeManager is null"); - this.config = requireNonNull(config, "config is null"); - this.translator = new RediSearchTranslator(config); - this.client = client(config); - this.connection = RedisModulesUtils.connection(client); - this.tableCache = EvictableCacheBuilder.newBuilder() - .expireAfterWrite(config.getTableCacheRefresh(), TimeUnit.SECONDS).build(); - } - - private AbstractRedisClient client(RediSearchConfig config) { - ClientBuilder builder = ClientBuilder.create(redisURI(config)); - builder.cluster(config.isCluster()); - config.getKeyPath().map(File::new).ifPresent(builder::key); - config.getCertPath().map(File::new).ifPresent(builder::keyCert); - config.getKeyPassword().ifPresent(p -> builder.keyPassword(p.toCharArray())); - config.getCaCertPath().map(File::new).ifPresent(builder::trustManager); - if (config.isResp2()) { - builder.protocolVersion(ProtocolVersion.RESP2); - } - return builder.build(); - } - - private RedisURI redisURI(RediSearchConfig config) { - RedisURIBuilder uri = RedisURIBuilder.create(); - uri.uriString(config.getUri()); - config.getUsername().ifPresent(uri::username); - config.getPassword().ifPresent(p -> uri.password(p.toCharArray())); - if (config.isInsecure()) { - uri.sslVerifyMode(SslVerifyMode.NONE); - } - return uri.build(); - } - - public StatefulRedisModulesConnection getConnection() { - return connection; - } - - public RediSearchConfig getConfig() { - return config; - } - - public void shutdown() { - connection.close(); - client.shutdown(); - client.getResources().shutdown(); - } - - public List getAddresses() { - Optional uri = config.getUri(); - if (uri.isPresent()) { - RedisURI redisURI = RedisURI.create(uri.get()); - return Collections.singletonList(HostAddress.fromParts(redisURI.getHost(), redisURI.getPort())); - } - return Collections.emptyList(); - } - - private Set listIndexNames() throws SchemaNotFoundException { - ImmutableSet.Builder builder = ImmutableSet.builder(); - builder.addAll(connection.sync().ftList()); - return builder.build(); - } - - /** - * - * @param schemaTableName SchemaTableName to load - * @return RediSearchTable describing the RediSearch index - * @throws TableNotFoundException if no index by that name was found - */ - public RediSearchTable getTable(SchemaTableName tableName) throws TableNotFoundException { - try { - return tableCache.get(tableName, () -> loadTableSchema(tableName)); - } catch (ExecutionException | UncheckedExecutionException e) { - throwIfInstanceOf(e.getCause(), TrinoException.class); - throw new RuntimeException(e); - } - } - - public Set getAllTables() { - return listIndexNames().stream().collect(toSet()); - } - - @SuppressWarnings("unchecked") - public void createTable(SchemaTableName schemaTableName, List columns) { - String index = schemaTableName.getTableName(); - if (!connection.sync().ftList().contains(index)) { - List> fields = columns.stream().filter(c -> !RediSearchBuiltinField.isKeyColumn(c.getName())) - .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList()); - CreateOptions.Builder options = CreateOptions.builder(); - options.prefix(index + ":"); - connection.sync().ftCreate(index, options.build(), fields.toArray(Field[]::new)); - } - } - - public void dropTable(SchemaTableName tableName) { - connection.sync().ftDropindexDeleteDocs(toRemoteTableName(tableName.getTableName())); - tableCache.invalidate(tableName); - } - - public void addColumn(SchemaTableName schemaTableName, ColumnMetadata columnMetadata) { - String tableName = toRemoteTableName(schemaTableName.getTableName()); - connection.sync().ftAlter(tableName, buildField(columnMetadata.getName(), columnMetadata.getType())); - tableCache.invalidate(schemaTableName); - } - - private String toRemoteTableName(String tableName) { - verify(tableName.equals(tableName.toLowerCase(ENGLISH)), "tableName not in lower-case: %s", tableName); - if (!config.isCaseInsensitiveNames()) { - return tableName; - } - for (String remoteTableName : listIndexNames()) { - if (tableName.equals(remoteTableName.toLowerCase(ENGLISH))) { - return remoteTableName; - } - } - return tableName; - } - - public void dropColumn(SchemaTableName schemaTableName, String columnName) { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support dropping columns"); - } - - /** - * - * @param schemaTableName SchemaTableName to load - * @return RediSearchTable describing the RediSearch index - * @throws TableNotFoundException if no index by that name was found - */ - private RediSearchTable loadTableSchema(SchemaTableName schemaTableName) throws TableNotFoundException { - String index = toRemoteTableName(schemaTableName.getTableName()); - Optional indexInfoOptional = indexInfo(index); - if (indexInfoOptional.isEmpty()) { - throw new TableNotFoundException(schemaTableName, format("Index '%s' not found", index), null); - } - IndexInfo indexInfo = indexInfoOptional.get(); - Set fields = new HashSet<>(); - ImmutableList.Builder columns = ImmutableList.builder(); - for (RediSearchBuiltinField builtinfield : RediSearchBuiltinField.values()) { - fields.add(builtinfield.getName()); - columns.add(builtinfield.getColumnHandle()); - } - for (Field indexedField : indexInfo.getFields()) { - RediSearchColumnHandle column = buildColumnHandle(indexedField); - fields.add(column.getName()); - columns.add(column); - } - SearchResults results = connection.sync().ftSearch(index, "*"); - for (Document doc : results) { - for (String docField : doc.keySet()) { - if (fields.contains(docField)) { - continue; - } - columns.add(new RediSearchColumnHandle(docField, VarcharType.VARCHAR, Field.Type.TEXT, false, false)); - fields.add(docField); - } - } - RediSearchTableHandle tableHandle = new RediSearchTableHandle(schemaTableName, index); - return new RediSearchTable(tableHandle, columns.build(), indexInfo); - } - - private Optional indexInfo(String index) { - try { - List indexInfoList = connection.sync().ftInfo(index); - if (indexInfoList != null) { - return Optional.of(RedisModulesUtils.indexInfo(indexInfoList)); - } - } catch (Exception e) { - // Ignore as index might not exist - } - return Optional.empty(); - } - - private RediSearchColumnHandle buildColumnHandle(Field field) { - return buildColumnHandle(name(field), field.getType(), false, true); - } - - private String name(Field field) { - Optional as = field.getAs(); - if (as.isEmpty()) { - return field.getName(); - } - return as.get(); - } - - private RediSearchColumnHandle buildColumnHandle(String name, Field.Type type, boolean hidden, - boolean supportsPredicates) { - return new RediSearchColumnHandle(name, columnType(type), type, hidden, supportsPredicates); - } - - private Type columnType(Field.Type type) { - return columnType(typeSignature(type)); - } - - private Type columnType(TypeSignature typeSignature) { - return typeManager.fromSqlType(typeSignature.toString()); - } - - public SearchResults search(RediSearchTableHandle tableHandle, String[] columns) { - Search search = translator.search(tableHandle, columns); - log.info("Running %s", search); - return connection.sync().ftSearch(search.getIndex(), search.getQuery(), search.getOptions()); - } - - public AggregateWithCursorResults aggregate(RediSearchTableHandle table, String[] columnNames) { - Aggregation aggregation = translator.aggregate(table, columnNames); - log.info("Running %s", aggregation); - String index = aggregation.getIndex(); - String query = aggregation.getQuery(); - CursorOptions cursor = aggregation.getCursorOptions(); - AggregateOptions options = aggregation.getOptions(); - AggregateWithCursorResults results = connection.sync().ftAggregate(index, query, cursor, options); - List> groupBys = aggregation.getOptions().getOperations().stream() - .filter(this::isGroupOperation).collect(Collectors.toList()); - if (results.isEmpty() && !groupBys.isEmpty()) { - Group groupBy = (Group) groupBys.get(0); - Optional as = groupBy.getReducers()[0].getAs(); - if (as.isPresent()) { - Map doc = new HashMap<>(); - doc.put(as.get(), 0); - results.add(doc); - } - } - return results; - } - - private boolean isGroupOperation(AggregateOperation operation) { - return operation.getType() == AggregateOperation.Type.GROUP; - } - - public AggregateWithCursorResults cursorRead(RediSearchTableHandle tableHandle, long cursor) { - String index = tableHandle.getIndex(); - if (config.getCursorCount() > 0) { - return connection.sync().ftCursorRead(index, cursor, config.getCursorCount()); - } - return connection.sync().ftCursorRead(index, cursor); - } - - private Field buildField(String columnName, Type columnType) { - Field.Type fieldType = toFieldType(columnType); - switch (fieldType) { - case GEO: - return Field.geo(columnName).build(); - case NUMERIC: - return Field.numeric(columnName).build(); - case TAG: - return Field.tag(columnName).build(); - case TEXT: - return Field.text(columnName).build(); - } - throw new IllegalArgumentException(String.format("Field type %s not supported", fieldType)); - } - - public static Field.Type toFieldType(Type type) { - if (type.equals(BooleanType.BOOLEAN)) { - return Field.Type.NUMERIC; - } - if (type.equals(BigintType.BIGINT)) { - return Field.Type.NUMERIC; - } - if (type.equals(IntegerType.INTEGER)) { - return Field.Type.NUMERIC; - } - if (type.equals(SmallintType.SMALLINT)) { - return Field.Type.NUMERIC; - } - if (type.equals(TinyintType.TINYINT)) { - return Field.Type.NUMERIC; - } - if (type.equals(DoubleType.DOUBLE)) { - return Field.Type.NUMERIC; - } - if (type.equals(RealType.REAL)) { - return Field.Type.NUMERIC; - } - if (type instanceof DecimalType) { - return Field.Type.NUMERIC; - } - if (type instanceof VarcharType) { - return Field.Type.TAG; - } - if (type instanceof CharType) { - return Field.Type.TAG; - } - if (type.equals(DateType.DATE)) { - return Field.Type.NUMERIC; - } - if (type.equals(TimestampType.TIMESTAMP_MILLIS)) { - return Field.Type.NUMERIC; - } - if (type.equals(TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS)) { - return Field.Type.NUMERIC; - } - if (type.equals(UuidType.UUID)) { - return Field.Type.TAG; - } - throw new IllegalArgumentException("unsupported type: " + type); - } - - private TypeSignature typeSignature(Field.Type type) { - if (type == Field.Type.NUMERIC) { - return doubleType(); - } - return varcharType(); - } - - private TypeSignature doubleType() { - return DOUBLE.getTypeSignature(); - } - - private TypeSignature varcharType() { - return createUnboundedVarcharType().getTypeSignature(); - } - - public void cursorDelete(RediSearchTableHandle tableHandle, long cursor) { - connection.sync().ftCursorDelete(tableHandle.getIndex(), cursor); - } - - public Long deleteDocs(List docIds) { - return connection.sync().del(docIds.toArray(String[]::new)); - } + private static final Logger log = Logger.get(RediSearchSession.class); + + private final TypeManager typeManager; + + private final RediSearchConfig config; + + private final RediSearchTranslator translator; + + private final AbstractRedisClient client; + + private final StatefulRedisModulesConnection connection; + + private final Cache tableCache; + + public RediSearchSession(TypeManager typeManager, RediSearchConfig config) { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.config = requireNonNull(config, "config is null"); + this.translator = new RediSearchTranslator(config); + this.client = client(config); + this.connection = RedisModulesUtils.connection(client); + this.tableCache = EvictableCacheBuilder.newBuilder().expireAfterWrite(config.getTableCacheRefresh(), TimeUnit.SECONDS) + .build(); + } + + private AbstractRedisClient client(RediSearchConfig config) { + RedisURI redisURI = redisURI(config); + if (config.isCluster()) { + RedisModulesClusterClient clusterClient = RedisModulesClusterClient.create(redisURI); + clusterClient.setOptions(ClusterClientOptions.builder(clientOptions(config)).build()); + return clusterClient; + } + RedisModulesClient redisClient = RedisModulesClient.create(redisURI); + redisClient.setOptions(clientOptions(config)); + return redisClient; + } + + private ClientOptions clientOptions(RediSearchConfig config) { + ClientOptions.Builder builder = ClientOptions.builder(); + builder.sslOptions(sslOptions(config)); + builder.protocolVersion(protocolVersion(config)); + return builder.build(); + } + + private ProtocolVersion protocolVersion(RediSearchConfig config) { + if (config.isResp2()) { + return ProtocolVersion.RESP2; + } + return RedisModulesClient.DEFAULT_PROTOCOL_VERSION; + } + + public SslOptions sslOptions(RediSearchConfig config) { + Builder ssl = SslOptions.builder(); + if (config.getKeyPath() != null) { + ssl.keyManager(new File(config.getCertPath()), new File(config.getKeyPath()), + config.getKeyPassword().toCharArray()); + } + if (config.getCaCertPath() != null) { + ssl.trustManager(new File(config.getCaCertPath())); + } + return ssl.build(); + } + + private RedisURI redisURI(RediSearchConfig config) { + RedisURI.Builder uri = RedisURI.builder(RedisURI.create(config.getUri())); + if (config.getPassword() != null) { + if (config.getUsername() != null) { + uri.withAuthentication(config.getUsername(), config.getPassword()); + } else { + uri.withPassword(config.getPassword().toCharArray()); + } + } + if (config.isInsecure()) { + uri.withVerifyPeer(false); + } + return uri.build(); + } + + public StatefulRedisModulesConnection getConnection() { + return connection; + } + + public RediSearchConfig getConfig() { + return config; + } + + public void shutdown() { + connection.close(); + client.shutdown(); + client.getResources().shutdown(); + } + + public List getAddresses() { + RedisURI redisURI = RedisURI.create(config.getUri()); + return Collections.singletonList(HostAddress.fromParts(redisURI.getHost(), redisURI.getPort())); + } + + private Set listIndexNames() throws SchemaNotFoundException { + ImmutableSet.Builder builder = ImmutableSet.builder(); + builder.addAll(connection.sync().ftList()); + return builder.build(); + } + + /** + * + * @param schemaTableName SchemaTableName to load + * @return RediSearchTable describing the RediSearch index + * @throws TableNotFoundException if no index by that name was found + */ + public RediSearchTable getTable(SchemaTableName tableName) throws TableNotFoundException { + try { + return tableCache.get(tableName, () -> loadTableSchema(tableName)); + } catch (ExecutionException | UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), TrinoException.class); + throw new RuntimeException(e); + } + } + + public Set getAllTables() { + return listIndexNames().stream().collect(toSet()); + } + + @SuppressWarnings("unchecked") + public void createTable(SchemaTableName schemaTableName, List columns) { + String index = schemaTableName.getTableName(); + if (!connection.sync().ftList().contains(index)) { + List> fields = columns.stream().filter(c -> !RediSearchBuiltinField.isKeyColumn(c.getName())) + .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList()); + CreateOptions.Builder options = CreateOptions. builder(); + options.prefix(index + ":"); + connection.sync().ftCreate(index, options.build(), fields.toArray(Field[]::new)); + } + } + + public void dropTable(SchemaTableName tableName) { + connection.sync().ftDropindexDeleteDocs(toRemoteTableName(tableName.getTableName())); + tableCache.invalidate(tableName); + } + + public void addColumn(SchemaTableName schemaTableName, ColumnMetadata columnMetadata) { + String tableName = toRemoteTableName(schemaTableName.getTableName()); + connection.sync().ftAlter(tableName, buildField(columnMetadata.getName(), columnMetadata.getType())); + tableCache.invalidate(schemaTableName); + } + + private String toRemoteTableName(String tableName) { + verify(tableName.equals(tableName.toLowerCase(ENGLISH)), "tableName not in lower-case: %s", tableName); + if (!config.isCaseInsensitiveNames()) { + return tableName; + } + for (String remoteTableName : listIndexNames()) { + if (tableName.equals(remoteTableName.toLowerCase(ENGLISH))) { + return remoteTableName; + } + } + return tableName; + } + + public void dropColumn(SchemaTableName schemaTableName, String columnName) { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support dropping columns"); + } + + /** + * + * @param schemaTableName SchemaTableName to load + * @return RediSearchTable describing the RediSearch index + * @throws TableNotFoundException if no index by that name was found + */ + private RediSearchTable loadTableSchema(SchemaTableName schemaTableName) throws TableNotFoundException { + String index = toRemoteTableName(schemaTableName.getTableName()); + Optional indexInfoOptional = indexInfo(index); + if (indexInfoOptional.isEmpty()) { + throw new TableNotFoundException(schemaTableName, format("Index '%s' not found", index), null); + } + IndexInfo indexInfo = indexInfoOptional.get(); + Set fields = new HashSet<>(); + ImmutableList.Builder columns = ImmutableList.builder(); + for (RediSearchBuiltinField builtinfield : RediSearchBuiltinField.values()) { + fields.add(builtinfield.getName()); + columns.add(builtinfield.getColumnHandle()); + } + for (Field indexedField : indexInfo.getFields()) { + RediSearchColumnHandle column = buildColumnHandle(indexedField); + fields.add(column.getName()); + columns.add(column); + } + SearchResults results = connection.sync().ftSearch(index, "*"); + for (Document doc : results) { + for (String docField : doc.keySet()) { + if (fields.contains(docField)) { + continue; + } + columns.add(new RediSearchColumnHandle(docField, VarcharType.VARCHAR, Field.Type.TEXT, false, false)); + fields.add(docField); + } + } + RediSearchTableHandle tableHandle = new RediSearchTableHandle(schemaTableName, index); + return new RediSearchTable(tableHandle, columns.build(), indexInfo); + } + + private Optional indexInfo(String index) { + try { + List indexInfoList = connection.sync().ftInfo(index); + if (indexInfoList != null) { + return Optional.of(RedisModulesUtils.indexInfo(indexInfoList)); + } + } catch (Exception e) { + // Ignore as index might not exist + } + return Optional.empty(); + } + + private RediSearchColumnHandle buildColumnHandle(Field field) { + return buildColumnHandle(name(field), field.getType(), false, true); + } + + private String name(Field field) { + Optional as = field.getAs(); + if (as.isEmpty()) { + return field.getName(); + } + return as.get(); + } + + private RediSearchColumnHandle buildColumnHandle(String name, Field.Type type, boolean hidden, boolean supportsPredicates) { + return new RediSearchColumnHandle(name, columnType(type), type, hidden, supportsPredicates); + } + + private Type columnType(Field.Type type) { + return columnType(typeSignature(type)); + } + + private Type columnType(TypeSignature typeSignature) { + return typeManager.fromSqlType(typeSignature.toString()); + } + + public SearchResults search(RediSearchTableHandle tableHandle, String[] columns) { + Search search = translator.search(tableHandle, columns); + log.info("Running %s", search); + return connection.sync().ftSearch(search.getIndex(), search.getQuery(), search.getOptions()); + } + + public AggregateWithCursorResults aggregate(RediSearchTableHandle table, String[] columnNames) { + Aggregation aggregation = translator.aggregate(table, columnNames); + log.info("Running %s", aggregation); + String index = aggregation.getIndex(); + String query = aggregation.getQuery(); + CursorOptions cursor = aggregation.getCursorOptions(); + AggregateOptions options = aggregation.getOptions(); + AggregateWithCursorResults results = connection.sync().ftAggregate(index, query, cursor, options); + List> groupBys = aggregation.getOptions().getOperations().stream() + .filter(this::isGroupOperation).collect(Collectors.toList()); + if (results.isEmpty() && !groupBys.isEmpty()) { + Group groupBy = (Group) groupBys.get(0); + Optional as = groupBy.getReducers()[0].getAs(); + if (as.isPresent()) { + Map doc = new HashMap<>(); + doc.put(as.get(), 0); + results.add(doc); + } + } + return results; + } + + private boolean isGroupOperation(AggregateOperation operation) { + return operation.getType() == AggregateOperation.Type.GROUP; + } + + public AggregateWithCursorResults cursorRead(RediSearchTableHandle tableHandle, long cursor) { + String index = tableHandle.getIndex(); + if (config.getCursorCount() > 0) { + return connection.sync().ftCursorRead(index, cursor, config.getCursorCount()); + } + return connection.sync().ftCursorRead(index, cursor); + } + + private Field buildField(String columnName, Type columnType) { + Field.Type fieldType = toFieldType(columnType); + switch (fieldType) { + case GEO: + return Field.geo(columnName).build(); + case NUMERIC: + return Field.numeric(columnName).build(); + case TAG: + return Field.tag(columnName).build(); + case TEXT: + return Field.text(columnName).build(); + case VECTOR: + throw new UnsupportedOperationException("Vector field not supported"); + } + throw new IllegalArgumentException(String.format("Field type %s not supported", fieldType)); + } + + public static Field.Type toFieldType(Type type) { + if (type.equals(BooleanType.BOOLEAN)) { + return Field.Type.NUMERIC; + } + if (type.equals(BigintType.BIGINT)) { + return Field.Type.NUMERIC; + } + if (type.equals(IntegerType.INTEGER)) { + return Field.Type.NUMERIC; + } + if (type.equals(SmallintType.SMALLINT)) { + return Field.Type.NUMERIC; + } + if (type.equals(TinyintType.TINYINT)) { + return Field.Type.NUMERIC; + } + if (type.equals(DoubleType.DOUBLE)) { + return Field.Type.NUMERIC; + } + if (type.equals(RealType.REAL)) { + return Field.Type.NUMERIC; + } + if (type instanceof DecimalType) { + return Field.Type.NUMERIC; + } + if (type instanceof VarcharType) { + return Field.Type.TAG; + } + if (type instanceof CharType) { + return Field.Type.TAG; + } + if (type.equals(DateType.DATE)) { + return Field.Type.NUMERIC; + } + if (type.equals(TimestampType.TIMESTAMP_MILLIS)) { + return Field.Type.NUMERIC; + } + if (type.equals(TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS)) { + return Field.Type.NUMERIC; + } + if (type.equals(UuidType.UUID)) { + return Field.Type.TAG; + } + throw new IllegalArgumentException("unsupported type: " + type); + } + + private TypeSignature typeSignature(Field.Type type) { + if (type == Field.Type.NUMERIC) { + return doubleType(); + } + return varcharType(); + } + + private TypeSignature doubleType() { + return DOUBLE.getTypeSignature(); + } + + private TypeSignature varcharType() { + return createUnboundedVarcharType().getTypeSignature(); + } + + public void cursorDelete(RediSearchTableHandle tableHandle, long cursor) { + connection.sync().ftCursorDelete(tableHandle.getIndex(), cursor); + } + + public Long deleteDocs(List docIds) { + return connection.sync().del(docIds.toArray(String[]::new)); + } } diff --git a/src/test/java/com/redis/trino/RediSearchServer.java b/src/test/java/com/redis/trino/RediSearchServer.java old mode 100644 new mode 100755 index 1b8e24b..645c2e4 --- a/src/test/java/com/redis/trino/RediSearchServer.java +++ b/src/test/java/com/redis/trino/RediSearchServer.java @@ -2,10 +2,9 @@ import java.io.Closeable; -import org.testcontainers.utility.DockerImageName; - +import com.redis.lettucemod.RedisModulesClient; import com.redis.lettucemod.api.StatefulRedisModulesConnection; -import com.redis.lettucemod.util.ClientBuilder; +import com.redis.lettucemod.cluster.RedisModulesClusterClient; import com.redis.lettucemod.util.RedisModulesUtils; import com.redis.testcontainers.RedisStackContainer; @@ -14,12 +13,9 @@ public class RediSearchServer implements Closeable { - private static final String TAG = "6.2.6-v9"; - - private static final DockerImageName DOCKER_IMAGE_NAME = RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(TAG); - - private final RedisStackContainer container = new RedisStackContainer(DOCKER_IMAGE_NAME).withEnv("REDISEARCH_ARGS", - "MAXAGGREGATERESULTS -1"); + private final RedisStackContainer container = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)).withEnv("REDISEARCH_ARGS", + "MAXAGGREGATERESULTS -1"); private final AbstractRedisClient client; @@ -27,7 +23,8 @@ public class RediSearchServer implements Closeable { public RediSearchServer() { this.container.start(); - this.client = ClientBuilder.create(RedisURI.create(container.getRedisURI())).cluster(container.isCluster()).build(); + RedisURI uri = RedisURI.create(container.getRedisURI()); + this.client = container.isCluster() ? RedisModulesClusterClient.create(uri) : RedisModulesClient.create(uri); this.connection = RedisModulesUtils.connection(client); }