Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Vector DB Support: download automatically the AstraDB secure bundle (L…
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli authored Sep 8, 2023
1 parent 5a051d9 commit 5a5d75b
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 68 deletions.
10 changes: 8 additions & 2 deletions examples/applications/astradb-sink/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ This is handled by the 'cassandra-table' assets in the pipeline.yaml file.

## Configure the pipeline

Update the same file and set username, password and the other parameters.
Update the secrets.yaml file and set the Astra credentials and the database name:
- clientId
- secret
- token
- database

You can find the credentials in the Astra DB console when you create a Token.

## Deploy the LangStream application

Expand All @@ -38,7 +44,7 @@ Update the same file and set username, password and the other parameters.

## Verify the data on Cassandra

Query Cassandra to see the results
Query Cassandra to see the results using the Astra DB console or the cqlsh tool:

```
SELECT * FROM products.products;
Expand Down
2 changes: 0 additions & 2 deletions examples/applications/astradb-sink/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ configuration:
service: "astra"
clientId: "{{{ secrets.astra.clientId }}}"
secret: "{{{ secrets.astra.secret }}}"
secureBundle: "{{{ secrets.astra.secureBundle }}}"
# These are optional, but if you want to use the astra-keyspace asset you need them
token: "{{{ secrets.astra.token }}}"
database: "{{{ secrets.astra.database }}}"
environment: "{{{ secrets.astra.environment }}}"
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ protected TypeCodec<?> createCodec(
@Override
public void initialize(Map<String, Object> dataSourceConfig) {
log.info("Initializing AstraDBDataSource with config {}", dataSourceConfig);
this.session = buildCqlSession(dataSourceConfig);
this.astraToken = ConfigurationUtils.getString("token", null, dataSourceConfig);
this.astraToken = ConfigurationUtils.getString("token", "", dataSourceConfig);
this.astraEnvironment =
ConfigurationUtils.getString("environment", "PROD", dataSourceConfig);
this.astraDatabase = ConfigurationUtils.getString("database", null, dataSourceConfig);
this.astraDatabase = ConfigurationUtils.getString("database", "", dataSourceConfig);
this.session = buildCqlSession(dataSourceConfig);
}

@Override
Expand Down Expand Up @@ -192,10 +192,17 @@ public void executeStatement(String query, List<Object> params) {
session.execute(bind);
}

public static CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {
private CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {

String username = ConfigurationUtils.getString("username", null, dataSourceConfig);
String password = ConfigurationUtils.getString("password", null, dataSourceConfig);
// these are the values used by the Astra UI
if (username == null) {
username = ConfigurationUtils.getString("clientId", null, dataSourceConfig);
}
if (password == null) {
password = ConfigurationUtils.getString("secret", null, dataSourceConfig);
}
String secureBundle = ConfigurationUtils.getString("secureBundle", null, dataSourceConfig);
List<String> contactPoints = ConfigurationUtils.getList("contact-points", dataSourceConfig);
String loadBalancingLocalDc =
Expand All @@ -204,18 +211,17 @@ public static CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {

byte[] secureBundleDecoded = null;
if (secureBundle != null && !secureBundle.isEmpty()) {
// these are the values used by the Astra UI
if (username == null) {
username = ConfigurationUtils.getString("clientId", null, dataSourceConfig);
}
if (password == null) {
password = ConfigurationUtils.getString("secret", null, dataSourceConfig);
}
// Remove the base64: prefix if present
if (secureBundle.startsWith("base64:")) {
secureBundle = secureBundle.substring("base64:".length());
}
secureBundleDecoded = Base64.getDecoder().decode(secureBundle);
} else if (!astraDatabase.isEmpty() && !astraToken.isEmpty()) {
log.info(
"Automatically downloading the secure bundle for database {} from AstraDB",
astraDatabase);
DatabaseClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
}
CqlSessionBuilder builder = new CqlSessionBuilder().withCodecRegistry(CODEC_REGISTRY);

Expand Down Expand Up @@ -246,11 +252,24 @@ public CqlSession getSession() {
}

public DatabaseClient buildAstraClient() {
if (astraToken == null || astraDatabase == null) {
return buildAstraClient(astraToken, astraDatabase, astraEnvironment);
}

public static DatabaseClient buildAstraClient(
String astraToken, String astraDatabase, String astraEnvironment) {
if (astraToken.isEmpty() || astraDatabase.isEmpty()) {
throw new IllegalArgumentException(
"You must configure both astra-token and astra-database");
}
return new AstraDbClient(astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment))
.databaseByName(astraDatabase);
}

public static byte[] downloadSecureBundle(DatabaseClient databaseClient) {
long start = System.currentTimeMillis();
byte[] secureBundleDecoded = databaseClient.downloadDefaultSecureConnectBundle();
long delta = System.currentTimeMillis() - start;
log.info("Downloaded {} bytes in {} ms", secureBundleDecoded.length, delta);
return secureBundleDecoded;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
import ai.langstream.api.database.VectorDatabaseWriter;
import ai.langstream.api.database.VectorDatabaseWriterProvider;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.oss.common.sink.AbstractField;
import com.datastax.oss.common.sink.AbstractSchema;
import com.datastax.oss.common.sink.AbstractSinkRecord;
import com.datastax.oss.common.sink.AbstractSinkRecordHeader;
import com.datastax.oss.common.sink.AbstractSinkTask;
import com.datastax.oss.common.sink.config.CassandraSinkConfig;
import com.datastax.oss.common.sink.util.SinkUtil;
import com.datastax.oss.streaming.ai.datasource.CassandraDataSource;
import com.dtsx.astra.sdk.db.DatabaseClient;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -102,6 +106,32 @@ public void initialise(Map<String, Object> agentConfiguration) {
datasource
.getOrDefault("secureBundle", "")
.toString());
} else {
String token =
ConfigurationUtils.getString(
"token", "", datasource);
String database =
ConfigurationUtils.getString(
"database", "", datasource);
String environment =
ConfigurationUtils.getString(
"environment", "DEV", datasource);
if (!token.isEmpty() && !database.isEmpty()) {
DatabaseClient databaseClient =
CassandraDataSource.buildAstraClient(
token, database, environment);
log.info(
"Automatically downloading the secure bundle for database {} from AstraDB",
database);
byte[] secureBundle =
CassandraDataSource.downloadSecureBundle(
databaseClient);
configuration.put(
"cloud.secureConnectBundle",
"base64:"
+ Base64.getEncoder()
.encodeToString(secureBundle));
}
}

configuration.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,17 @@ protected void validateAsset(AssetDefinition assetDefinition, Map<String, Object
case "cassandra-keyspace" -> {
requiredNonEmptyField(assetDefinition, configuration, "keyspace");
requiredListField(assetDefinition, configuration, "create-statements");
if (datasourceConfiguration.containsKey("secureBundle")) {
throw new IllegalArgumentException(
"Use astra-keyspace for AstraDB services (not expecting a secureBundle in a Cassandra datasource).");
if (datasourceConfiguration.containsKey("secureBundle")
|| datasourceConfiguration.containsKey("database")) {
throw new IllegalArgumentException("Use astra-keyspace for AstraDB services");
}
}
case "astra-keyspace" -> {
requiredNonEmptyField(assetDefinition, configuration, "keyspace");
if (!datasourceConfiguration.containsKey("secureBundle")) {
if (!datasourceConfiguration.containsKey("secureBundle")
&& !datasourceConfiguration.containsKey("database")) {
throw new IllegalArgumentException(
"Use cassandra-keyspace for a standard Cassandra service (expecting a secureBundle, but found only "
+ datasourceConfiguration.keySet()
+ " .");
"Use cassandra-keyspace for a standard Cassandra service (not AstraDB)");
}
// are we are using the AstraDB SDK we need also the AstraCS token and
// the name of the database
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ private Map<String, Object> planAsset(
Resource resource = resources.get(resourceId);
if (resource != null) {
value = Map.of("configuration", resource.configuration());
} else {
throw new IllegalArgumentException(
"Resource with name="
+ resourceId
+ " not found, declared as "
+ key
+ " in asset "
+ assetDefinition.getId());
}
}
configuration.put(key, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,11 @@
*/
package ai.langstream.kafka;

import static org.junit.jupiter.api.Assertions.assertEquals;

import ai.langstream.AbstractApplicationRunner;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.streaming.ai.datasource.CassandraDataSource;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.producer.KafkaProducer;
Expand All @@ -35,12 +30,15 @@
@Disabled
class AstraDBAssetQueryWriteIT extends AbstractApplicationRunner {

static final String SECRETS_PATH = "";

@Test
@Disabled
public void testAstra() throws Exception {
String tenant = "tenant";
String[] expectedAgents = {"app-step1", "app-step2"};

String secrets = Files.readString(Paths.get(SECRETS_PATH));

Map<String, String> application =
Map.of(
"configuration.yaml",
Expand All @@ -51,12 +49,12 @@ public void testAstra() throws Exception {
name: "AstraDBDatasource"
configuration:
service: "astra"
secret: "xxx"
clientId: "xxx"
secureBundle: "base64:xxx"
token: "AstraCS:xxx"
database: "xxx"
environment: "DEV"
clientId: "{{{ secrets.astra.clientId }}}"
secret: "{{{ secrets.astra.secret }}}"
# These are optional, but if you want to use the astra-keyspace asset you need them
token: "{{{ secrets.astra.token }}}"
database: "{{{ secrets.astra.database }}}"
environment: "{{{ secrets.astra.environment }}}"
""",
"pipeline.yaml",
"""
Expand Down Expand Up @@ -117,8 +115,8 @@ public void testAstra() throws Exception {
""");

try (ApplicationRuntime applicationRuntime =
deployApplication(
tenant, "app", application, buildInstanceYaml(), expectedAgents)) {
deployApplicationWithSecrets(
tenant, "app", application, buildInstanceYaml(), secrets, expectedAgents)) {
try (KafkaProducer<String, String> producer = createProducer();
KafkaConsumer<String, String> consumer = createConsumer("output-topic")) {

Expand All @@ -129,35 +127,6 @@ public void testAstra() throws Exception {
consumer,
List.of(
"{\"documentId\":2,\"queryresult\":{\"name\":\"A\",\"description\":\"A description\",\"id\":\"1\"},\"name\":\"A\",\"description\":\"A description\"}"));

try (CassandraDataSource cassandraDataSource = new CassandraDataSource()) {
cassandraDataSource.initialize(
Map.of(
"service",
"astra",
"secret",
"xxxx",
"clientId",
"xxx",
"secureBundle",
"base64:xxx",
"token",
"AstraCS:xxx",
"database",
"xxxx",
"environment",
"DEV"));
ResultSet execute =
cassandraDataSource
.getSession()
.execute("SELECT * FROM vsearch.documents");
List<Row> all = execute.all();
Set<Integer> documentIds =
all.stream().map(row -> row.getInt("id")).collect(Collectors.toSet());
all.forEach(row -> log.info("row id {}", row.get("id", Integer.class)));
assertEquals(2, all.size());
assertEquals(Set.of(1, 2), documentIds);
}
}
}
}
Expand Down

0 comments on commit 5a5d75b

Please sign in to comment.