Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable native JDBC connections #28

Merged
merged 3 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ public static void sendStartupMessage(
new ParameterStatusResponse(output, "session_authorization".getBytes(), "PGAdapter".getBytes())
.send();
new ParameterStatusResponse(output, "integer_datetimes".getBytes(), "on".getBytes()).send();
new ParameterStatusResponse(output, "server_encoding".getBytes(), "utf8".getBytes()).send();
new ParameterStatusResponse(output, "client_encoding".getBytes(), "utf8".getBytes()).send();
new ParameterStatusResponse(output, "DateStyle".getBytes(), "ISO".getBytes()).send();
new ParameterStatusResponse(output, "server_encoding".getBytes(), "UTF8".getBytes()).send();
new ParameterStatusResponse(output, "client_encoding".getBytes(), "UTF8".getBytes()).send();
new ParameterStatusResponse(output, "DateStyle".getBytes(), "ISO,YMD".getBytes()).send();
new ParameterStatusResponse(output, "IntervalStyle".getBytes(), "iso_8601".getBytes()).send();
new ParameterStatusResponse(output, "standard_conforming_strings".getBytes(), "true".getBytes())
new ParameterStatusResponse(output, "standard_conforming_strings".getBytes(), "on".getBytes())
.send();
new ParameterStatusResponse(
output,
Expand Down
146 changes: 146 additions & 0 deletions src/test/java/com/google/cloud/spanner/pgadapter/ITJdbcTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.cloud.spanner.pgadapter;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import com.google.cloud.ByteArray;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Database;
import com.google.cloud.spanner.KeySet;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@Category(IntegrationTest.class)
@RunWith(JUnit4.class)
public class ITJdbcTest implements IntegrationTest {
private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv();
private static ProxyServer server;
private static Database database;

@BeforeClass
public static void setup() throws Exception {
// Make sure the PG JDBC driver is loaded.
Class.forName("org.postgresql.Driver");

testEnv.setUp();
if (testEnv.isUseExistingDb()) {
database = testEnv.getExistingDatabase();
} else {
database = testEnv.createDatabase();
testEnv.updateDdl(
database.getId().getDatabase(),
Arrays.asList(
"create table numbers (num int not null primary key, name varchar(100))",
"create table all_types ("
+ "col_bigint bigint not null primary key, "
+ "col_bool bool, "
+ "col_bytea bytea, "
+ "col_float8 float8, "
+ "col_int int, "
+ "col_numeric numeric, "
+ "col_timestamptz timestamptz, "
+ "col_varchar varchar(100))"));
}
String credentials = testEnv.getCredentials();
ImmutableList.Builder<String> argsListBuilder =
ImmutableList.<String>builder()
.add(
"-p",
testEnv.getProjectId(),
"-i",
testEnv.getInstanceId(),
"-d",
database.getId().getDatabase(),
"-s",
String.valueOf(testEnv.getPort()),
"-e",
testEnv.getUrl().getHost());
if (credentials != null) {
argsListBuilder.add("-c", testEnv.getCredentials());
}
String[] args = argsListBuilder.build().toArray(new String[0]);
server = new ProxyServer(new OptionsMetadata(args));
server.startServer();
}

@AfterClass
public static void teardown() {
if (server != null) {
server.stopServer();
}
testEnv.cleanUp();
}

@Before
public void insertTestData() {
String databaseId = database.getId().getDatabase();
testEnv.write(databaseId, Collections.singleton(Mutation.delete("numbers", KeySet.all())));
testEnv.write(databaseId, Collections.singleton(Mutation.delete("all_types", KeySet.all())));
testEnv.write(
databaseId,
Arrays.asList(
Mutation.newInsertBuilder("numbers").set("num").to(1L).set("name").to("One").build(),
Mutation.newInsertBuilder("all_types")
.set("col_bigint")
.to(1L)
.set("col_bool")
.to(true)
.set("col_bytea")
.to(ByteArray.copyFrom("test"))
.set("col_float8")
.to(3.14d)
.set("col_int")
.to(1)
.set("col_numeric")
.to(new BigDecimal("3.14"))
.set("col_timestamptz")
.to(Timestamp.parseTimestamp("2022-01-27T17:51:30+01:00"))
.set("col_varchar")
.to("test")
.build()));
}

@Test
public void testSelectHelloWorld() throws SQLException {
try (Connection connection =
DriverManager.getConnection(
String.format("jdbc:postgresql://localhost:%d/", testEnv.getPort()))) {
try (ResultSet resultSet =
connection.createStatement().executeQuery("SELECT 'Hello World!'")) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work with all other queries as well using pg-jdbc? I assume the wireprotocol is the same for pg-jdbc?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the wireprotocol is the same for pg-jdbc, so yes it will in theory work with other queries. However:

  1. The pg-jdbc driver uses binary format for many query parameter types, so we need to support that for parameterized queries to work. That is included in feat: add support incoming binary values #27
  2. The pg-jdbc driver sends timestamp and date parameter values as untyped parameters. Support for that is also included in feat: add support incoming binary values #27 (only for timestamp at this moment, as we currently do not support date).
  3. The pg-jdbc driver prefers the extended query mode over simple query mode (which is normally used in psql). That will probably also need some tweaking.

assertTrue(resultSet.next());
assertEquals("Hello World!", resultSet.getString(1));
assertFalse(resultSet.next());
}
}
}
}
103 changes: 76 additions & 27 deletions src/test/java/com/google/cloud/spanner/pgadapter/PgAdapterTestEnv.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.testing.RemoteSpannerHelper;
import com.google.common.base.Strings;
import com.google.common.primitives.Bytes;
import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata;
Expand All @@ -40,6 +41,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
Expand Down Expand Up @@ -71,6 +73,12 @@ public final class PgAdapterTestEnv {
// PgAdapter port should be set through this system property.
public static final String SERVICE_PORT = "PG_ADAPTER_PORT";

// Environment variable that can be used to force the test env to assume that the test database
// already exists. This can be used to speed up local testing by manually creating the test
// database and re-run the tests multiple times against the same database without the need to
// recreate it for every test run.
public static final String USE_EXISTING_DB = "PG_ADAPTER_USE_EXISTING_DB";

// Default fallback project Id will be used if one isn't set via the system property.
private static final String DEFAULT_PROJECT_ID = "span-cloud-testing";

Expand Down Expand Up @@ -102,6 +110,9 @@ public final class PgAdapterTestEnv {
// Port used by the pgadapter.
private int port = 0;

// Shared Spanner instance that is automatically created and closed.
private Spanner spanner;

// Spanner options for creating a client.
private SpannerOptions options;

Expand All @@ -111,24 +122,27 @@ public final class PgAdapterTestEnv {
// Log stream for the test process.
private static final Logger logger = Logger.getLogger(PgAdapterTestEnv.class.getName());

// Utility class for setting up test connection.
private RemoteSpannerHelper spannerHelper;

private final List<Database> databases = new ArrayList<>();

public void setUp() throws Exception {
spannerURL = new URL(getHostUrl());
options = createSpannerOptions();
}

public SpannerOptions spannerOptions() {
return options;
public Spanner getSpanner() {
if (spanner == null) {
spanner = options.getService();
}
return spanner;
}

public String getCredentials() throws Exception {
public String getCredentials() {
if (System.getenv().get(GCP_CREDENTIALS) == null) {
return null;
}

if (gcpCredentials == null) {
Map<String, String> env = System.getenv();
gcpCredentials = env.get(GCP_CREDENTIALS);
gcpCredentials = System.getenv().get(GCP_CREDENTIALS);
if (gcpCredentials.isEmpty()) {
throw new IllegalArgumentException("Invalid GCP credentials file.");
}
Expand Down Expand Up @@ -178,27 +192,41 @@ public URL getUrl() {
return spannerURL;
}

public boolean isUseExistingDb() {
return Boolean.parseBoolean(System.getProperty(USE_EXISTING_DB, "false"));
}

public Database getExistingDatabase() {
if (databaseId == null) {
databaseId = System.getProperty(TEST_DATABASE_PROPERTY, DEFAULT_DATABASE_ID);
}
return getSpanner().getDatabaseAdminClient().getDatabase(instanceId, databaseId);
}

// Create database.
public Database createDatabase() throws Exception {
if (isUseExistingDb()) {
throw new IllegalStateException(
"Cannot create a new test database if " + USE_EXISTING_DB + " is true.");
}
String databaseId = getDatabaseId();
Spanner spanner = options.getService();
Spanner spanner = getSpanner();
DatabaseAdminClient client = spanner.getDatabaseAdminClient();
Database db = null;
OperationFuture<Database, CreateDatabaseMetadata> op =
client.createDatabase(
client
.newDatabaseBuilder(DatabaseId.of(projectId, instanceId, databaseId))
.setDialect(Dialect.POSTGRESQL)
.build(),
Arrays.asList());
db = op.get();
Collections.emptyList());
Database db = op.get();
databases.add(db);
logger.log(Level.INFO, "Created database [" + db.getId() + "]");
return db;
}

public void updateDdl(String databaseId, Iterable<String> statements) throws Exception {
Spanner spanner = options.getService();
Spanner spanner = getSpanner();
DatabaseAdminClient client = spanner.getDatabaseAdminClient();
OperationFuture<Void, UpdateDatabaseDdlMetadata> op =
client.updateDatabaseDdl(instanceId, databaseId, statements, null);
Expand All @@ -207,8 +235,8 @@ public void updateDdl(String databaseId, Iterable<String> statements) throws Exc
}

// Update tables of the database.
public void updateTables(String databaseId, Iterable<String> statements) throws Exception {
Spanner spanner = options.getService();
public void updateTables(String databaseId, Iterable<String> statements) {
Spanner spanner = getSpanner();
DatabaseId db = DatabaseId.of(projectId, instanceId, databaseId);
DatabaseClient dbClient = spanner.getDatabaseClient(db);
dbClient
Expand All @@ -228,6 +256,19 @@ public void updateTables(String databaseId, Iterable<String> statements) throws
});
}

/**
* Writes data to the given test database.
*
* @param databaseId The id of the database to write to
* @param mutations The mutations to write
*/
public void write(String databaseId, Iterable<Mutation> mutations) {
Spanner spanner = getSpanner();
DatabaseId db = DatabaseId.of(projectId, instanceId, databaseId);
DatabaseClient dbClient = spanner.getDatabaseClient(db);
dbClient.write(mutations);
}

// Setup spanner options.
private SpannerOptions createSpannerOptions() throws Exception {
projectId = getProjectId();
Expand All @@ -236,12 +277,15 @@ private SpannerOptions createSpannerOptions() throws Exception {
Map<String, String> env = System.getenv();
gcpCredentials = env.get(GCP_CREDENTIALS);
GoogleCredentials credentials = null;
credentials = GoogleCredentials.fromStream(new FileInputStream(gcpCredentials));
return SpannerOptions.newBuilder()
.setProjectId(projectId)
.setHost(spannerURL.toString())
.setCredentials(credentials)
.build();
if (!Strings.isNullOrEmpty(gcpCredentials)) {
credentials = GoogleCredentials.fromStream(new FileInputStream(gcpCredentials));
}
SpannerOptions.Builder builder =
SpannerOptions.newBuilder().setProjectId(projectId).setHost(spannerURL.toString());
if (credentials != null) {
builder.setCredentials(credentials);
}
return builder.build();
}

public static class PGMessage {
Expand Down Expand Up @@ -365,12 +409,17 @@ void consumeStartUpMessages(DataInputStream in) throws java.io.IOException {

// Drop all the databases we created explicitly.
public void cleanUp() {
for (Database db : databases) {
try {
db.drop();
} catch (Exception e) {
logger.log(Level.WARNING, "Failed to drop test database " + db.getId(), e);
if (!isUseExistingDb()) {
for (Database db : databases) {
try {
db.drop();
} catch (Exception e) {
logger.log(Level.WARNING, "Failed to drop test database " + db.getId(), e);
}
}
}
if (spanner != null) {
spanner.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1370,25 +1370,25 @@ public void testStartUpMessage() throws Exception {
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 25);
Assert.assertEquals(readUntil(outputResult, "server_encoding\0".length()), "server_encoding\0");
Assert.assertEquals(readUntil(outputResult, "utf8\0".length()), "utf8\0");
Assert.assertEquals(readUntil(outputResult, "UTF8\0".length()), "UTF8\0");
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 25);
Assert.assertEquals(readUntil(outputResult, "client_encoding\0".length()), "client_encoding\0");
Assert.assertEquals(readUntil(outputResult, "utf8\0".length()), "utf8\0");
Assert.assertEquals(readUntil(outputResult, "UTF8\0".length()), "UTF8\0");
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 18);
Assert.assertEquals(outputResult.readInt(), 22);
Assert.assertEquals(readUntil(outputResult, "DateStyle\0".length()), "DateStyle\0");
Assert.assertEquals(readUntil(outputResult, "ISO\0".length()), "ISO\0");
Assert.assertEquals(readUntil(outputResult, "ISO,YMD\0".length()), "ISO,YMD\0");
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 27);
Assert.assertEquals(readUntil(outputResult, "IntervalStyle\0".length()), "IntervalStyle\0");
Assert.assertEquals(readUntil(outputResult, "iso_8601\0".length()), "iso_8601\0");
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 37);
Assert.assertEquals(outputResult.readInt(), 35);
Assert.assertEquals(
readUntil(outputResult, "standard_conforming_strings\0".length()),
"standard_conforming_strings\0");
Assert.assertEquals(readUntil(outputResult, "true\0".length()), "true\0");
Assert.assertEquals(readUntil(outputResult, "on\0".length()), "on\0");
Assert.assertEquals(outputResult.readByte(), 'S');
Assert.assertEquals(outputResult.readInt(), 17);
Assert.assertEquals(readUntil(outputResult, "TimeZone\0".length()), "TimeZone\0");
Expand Down