Skip to content

Commit

Permalink
feat: add cloud sql cloudbuild workflow (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Jan 16, 2024
1 parent 693c19d commit 3dd3444
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 8 deletions.
84 changes: 84 additions & 0 deletions retrieval_service/cloudsql.tests.cloudbuild.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

steps:
- id: Install dependencies
name: python:3.11
dir: retrieval_service
entrypoint: pip
args:
[
"install",
"-r",
"requirements.txt",
"-r",
"requirements-test.txt",
"--user",
]

- id: Update config
name: python:3.11
dir: retrieval_service
secretEnv:
- PGUSER
- PGPASSWORD
entrypoint: /bin/bash
args:
- "-c"
- |
# Create config
cp example-config-cloudsql.yml config.yml
sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml
sed -i "s/my-user/$$PGUSER/g" config.yml
sed -i "s/my-password/$$PGPASSWORD/g" config.yml
sed -i "s/my-project/$PROJECT_ID/g" config.yml
sed -i "s/my-region/${_CLOUDSQL_REGION}/g" config.yml
sed -i "s/my-instance/${_CLOUDSQL_INSTANCE}/g" config.yml
- id: Run Cloud SQL DB integration tests
name: python:3.11
dir: retrieval_service
env: # Set env var expected by tests
- "DB_NAME=${_DATABASE_NAME}"
- "DB_PROJECT=$PROJECT_ID"
- "DB_REGION=${_CLOUDSQL_REGION}"
- "DB_INSTANCE=${_CLOUDSQL_INSTANCE}"
secretEnv:
- PGUSER
- PGPASSWORD
entrypoint: /bin/bash
args:
- "-c"
- |
# Set env var expected by tests
export DB_USER=$$PGUSER
export DB_PASS=$$PGPASSWORD
python -m pytest datastore/providers/cloudsql_postgres_test.py
substitutions:
_DATABASE_NAME: test_${SHORT_SHA}
_DATABASE_USER: postgres
_CLOUDSQL_REGION: "us-central1"
_CLOUDSQL_INSTANCE: "my-cloudsql-instance"

availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloudsql_pass/versions/latest
env: PGPASSWORD
- versionName: projects/$PROJECT_ID/secrets/cloudsql_user/versions/latest
env: PGUSER

options:
substitutionOption: 'ALLOW_LOOSE'
dynamic_substitutions: true
104 changes: 96 additions & 8 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, AsyncGenerator, List

import asyncpg
import pytest
import pytest_asyncio
from csv_diff import compare, load_csv # type: ignore
from google.cloud.sql.connector import Connector

import models

Expand All @@ -39,11 +43,6 @@ def db_pass() -> str:
return get_env_var("DB_PASS", "password for the postgres user")


@pytest.fixture(scope="module")
def db_name() -> str:
return get_env_var("DB_NAME", "name of a postgres database")


@pytest.fixture(scope="module")
def db_project() -> str:
return get_env_var("DB_PROJECT", "project id for google cloud")
Expand All @@ -59,15 +58,47 @@ def db_instance() -> str:
return get_env_var("DB_INSTANCE", "instance for cloud sql")


@pytest.fixture(scope="module")
async def create_db(
db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str
) -> AsyncGenerator[str, None]:
db_name = get_env_var("DB_NAME", "name of a postgres database")
loop = asyncio.get_running_loop()
connector = Connector(loop=loop)
# Database does not exist, create it.
sys_conn: asyncpg.Connection = await connector.connect_async(
f"{db_project}:{db_region}:{db_instance}",
"asyncpg",
user=f"{db_user}",
password=f"{db_pass}",
db="postgres",
)
await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";')
await sys_conn.execute(f'CREATE DATABASE "{db_name}";')
await sys_conn.close()
conn: asyncpg.Connection = await connector.connect_async(
f"{db_project}:{db_region}:{db_instance}",
"asyncpg",
user=f"{db_user}",
password=f"{db_pass}",
db=f"{db_name}",
)
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
yield db_name
await conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";')
await conn.close()


@pytest_asyncio.fixture(scope="module")
async def ds(
create_db: AsyncGenerator[str, None],
db_user: str,
db_pass: str,
db_name: str,
db_project: str,
db_region: str,
db_instance: str,
) -> AsyncGenerator[datastore.Client, None]:
db_name = await create_db.__anext__()
cfg = cloudsql_postgres.Config(
kind="cloudsql-postgres",
user=db_user,
Expand All @@ -77,13 +108,70 @@ async def ds(
region=db_region,
instance=db_instance,
)
t = create_db
ds = await datastore.create(cfg)

airports_ds_path = "../data/airport_dataset.csv"
amenities_ds_path = "../data/amenity_dataset.csv"
flights_ds_path = "../data/flights_dataset.csv"
airports, amenities, flights = await ds.load_dataset(
airports_ds_path, amenities_ds_path, flights_ds_path
)
await ds.initialize_data(airports, amenities, flights)

if ds is None:
raise TypeError("datastore creation failure")
yield ds
print("after yield")
await ds.close()
print("closed database")


async def test_export_dataset(ds: cloudsql_postgres.Client):
airports, amenities, flights = await ds.export_data()

airports_ds_path = "../data/airport_dataset.csv"
amenities_ds_path = "../data/amenity_dataset.csv"
flights_ds_path = "../data/flights_dataset.csv"

airports_new_path = "../data/airport_dataset.csv.new"
amenities_new_path = "../data/amenity_dataset.csv.new"
flights_new_path = "../data/flights_dataset.csv.new"

await ds.export_dataset(
airports,
amenities,
flights,
airports_new_path,
amenities_new_path,
flights_new_path,
)

diff_airports = compare(
load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id")
)
assert diff_airports["added"] == []
assert diff_airports["removed"] == []
assert diff_airports["changed"] == []
assert diff_airports["columns_added"] == []
assert diff_airports["columns_removed"] == []

diff_amenities = compare(
load_csv(open(amenities_ds_path), "id"),
load_csv(open(amenities_new_path), "id"),
)
assert diff_amenities["added"] == []
assert diff_amenities["removed"] == []
assert diff_amenities["changed"] == []
assert diff_amenities["columns_added"] == []
assert diff_amenities["columns_removed"] == []

diff_flights = compare(
load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id")
)
assert diff_flights["added"] == []
assert diff_flights["removed"] == []
assert diff_flights["changed"] == []
assert diff_flights["columns_added"] == []
assert diff_flights["columns_removed"] == []


async def test_get_airport_by_id(ds: cloudsql_postgres.Client):
Expand Down
10 changes: 10 additions & 0 deletions retrieval_service/example-config-cloudsql.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
host: 0.0.0.0
datastore:
# Example for Cloud SQL
kind: "cloudsql-postgres"
project: "my-project"
region: "my-region"
instance: "my-instance"
database: "my_database"
user: "my-user"
password: "my-password"

0 comments on commit 3dd3444

Please sign in to comment.