Skip to content
This repository has been archived by the owner on Jul 18, 2024. It is now read-only.

Commit

Permalink
feat: CLI general location support
Browse files Browse the repository at this point in the history
fixes #141
  • Loading branch information
jasonborg committed May 16, 2023
1 parent 1f759f6 commit 35f9ce4
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 41 deletions.
18 changes: 18 additions & 0 deletions snapshot_dbg_cli/firebase_management_rest_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@
{response}
"""

VALID_LOCATIONS_URL = (
"https://firebase.google.com/docs/projects/locations#rtdb-locations"
)


class FirebaseManagementRestService:
"""Implements a service for making Firebase RTDB management REST requests.
Expand Down Expand Up @@ -282,6 +286,20 @@ def rtdb_instance_create(self, database_id, location):
self._user_output.debug("Got 400:", parsed_error)
return DatabaseCreateResponse(
DatabaseCreateStatus.FAILED_PRECONDITION)

if parsed_error["error"]["status"] == "INVALID_ARGUMENT":
print_http_error(
self._user_output, request, err, error_message=error_message)

self._user_output.error(
"This was attempting to create the database instance. One "
"potential reason for this is if an invalid location was "
"specified, valid locations for RTDBs can be found at "
f"{VALID_LOCATIONS_URL}. To note, valid RTDB locations are a "
"subset of valid Google Cloud regions.")

raise SilentlyExitError from err

except (TypeError, KeyError, ValueError):
pass

Expand Down
15 changes: 14 additions & 1 deletion snapshot_dbg_cli/firebase_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
"""

from enum import Enum
import re

FIREBASE_MANAGMENT_API_SERVICE = 'firebase.googleapis.com'
FIREBASE_RTDB_MANAGMENT_API_SERVICE = 'firebasedatabase.googleapis.com'


class FirebaseProjectStatus(Enum):
ENABLED = 1
NOT_ENABLED = 2
Expand Down Expand Up @@ -123,12 +123,25 @@ def __init__(self, database_instance):
self.database_url = database_instance['databaseUrl']
self.type = database_instance['type']
self.state = database_instance['state']
self.location = self.extract_location(self.name)

if self.location is None:
raise ValueError(
f"Failed to extract location from project name '{self.name}'")

except KeyError as e:
missing_key = e.args[0]
error_message = ('DatabaseInstance is missing expected field '
f"'{missing_key}' instance: {database_instance}")
raise ValueError(error_message) from e

@staticmethod
def extract_location(name):
location_search = re.search('/locations/([^/]+)/', name)
if not location_search or len(location_search.groups()) != 1:
return None

return location_search.groups()[0]

class DatabaseCreateResponse:
"""Represents the response of a database create request.
Expand Down
38 changes: 25 additions & 13 deletions snapshot_dbg_cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@
state: {db_state}
"""

LOCTION_MISMATCH_ERROR_MSG = (
"ERROR the following database already exists: '{full_database_name}', "
"however its location '{existing_location}' does not match the requested "
"location '{requested_location}'.\n\n"
"The database ID ('{database_id}' in this case) must be unique across "
"locations and projects.\n\n"
"If you meant to finalize the initialization of '{database_id}' in "
"'{requested_location}' (or verify it is already correctly initialized) "
"rerun the init command and specify '--location={existing_location}'.\n\n"
"If you meant to create a new database in '{requested_location}', given "
"'{database_id}' is already in use, you'll need to use a new name by "
"providing the '--database-id' argument to the init command. Note, even "
"if you delete '{database_id}', the name will remain reserved and will not "
"be available to be reused in a new location, a new name must be chosen.")


class InitCommand:
"""This class implements the init command.
Expand Down Expand Up @@ -222,14 +237,6 @@ def register(self, args_subparsers, required_parsers, common_parsers):

# Only some locations are supported, see:
# https://firebase.google.com/docs/projects/locations#rtdb-locations
#
# If unsupported location is used, this error occurs
# "error": {
# "code": 400,
# "message": "Request contains an invalid argument.",
# "status": "INVALID_ARGUMENT"
# }
# For now however we only support 'us-central1'
parser.add_argument(
'-l', '--location', help=LOCATION_HELP, default=DEFAULT_LOCATION)
self.args_parser = parser
Expand All @@ -243,11 +250,6 @@ def cmd(self, args, cli_services):
self.permissions_service = cli_services.permissions_service
self.project_id = cli_services.project_id

if args.location != DEFAULT_LOCATION:
self.user_output.error('ERROR: Currently the only supported location is '
f"'{DEFAULT_LOCATION}'")
raise SilentlyExitError

# If the user does not have the required permissions this will emit an error
# message and exit.
self.permissions_service.check_required_permissions(REQUIRED_PERMISSIONS)
Expand Down Expand Up @@ -347,6 +349,16 @@ def check_and_handle_database_instance(self, args, firebase_project):

if status == DatabaseGetStatus.EXISTS:
database_instance = instance_response.database_instance
if args.location != database_instance.location:
self.user_output.error(
LOCTION_MISMATCH_ERROR_MSG.format(
full_database_name=database_instance.name,
existing_location=database_instance.location,
requested_location=args.location,
database_id=database_id))

raise SilentlyExitError

elif status == DatabaseGetStatus.DOES_NOT_EXIST:
create_response = self.firebase_management_service.rtdb_instance_create(
database_id=database_id, location=args.location)
Expand Down
37 changes: 31 additions & 6 deletions snapshot_dbg_cli_tests/test_firebase_management_rest_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@

VALID_PROJECT_RESPONSE = {'state': 'ACTIVE'}

DB_NAME = "projects/1111111111/locations/us-central1/instances/db-name"

VALID_DB_RESPONSE = {
'name': 'db-name',
'name': DB_NAME,
'project': 'project-name',
'databaseUrl': 'project-default-rtdb.firebaseio.com',
'type': 'DEFAULT_DATABASE',
Expand Down Expand Up @@ -252,7 +254,7 @@ def test_rtdb_instance_get_raises_when_state_is_not_active_in_response(self):

def test_rtdb_instance_get_returns_expected_value_on_success(self):
self.http_service_mock.send.return_value = {
'name': 'db-name',
'name': DB_NAME,
'project': 'project-name',
'databaseUrl': 'project-default-rtdb.firebaseio.com',
'type': 'DEFAULT_DATABASE',
Expand All @@ -263,7 +265,7 @@ def test_rtdb_instance_get_returns_expected_value_on_success(self):
'db-name')

self.assertEqual(DatabaseGetStatus.EXISTS, obtained_response.status)
self.assertEqual('db-name', obtained_response.database_instance.name)
self.assertEqual(DB_NAME, obtained_response.database_instance.name)
self.assertEqual('project-name',
obtained_response.database_instance.project)
self.assertEqual('project-default-rtdb.firebaseio.com',
Expand Down Expand Up @@ -327,7 +329,30 @@ def test_rtdb_instance_create_returns_status_failed_precondition(self):
self.assertEqual(DatabaseCreateStatus.FAILED_PRECONDITION,
obtained_response.status)

def test_rtdb_instance_create_raises_on_400_non_failed_precondition(self):
def test_rtdb_instance_create_returns_status_invalid_argument(self):
error_message = json.dumps({'error': {'status': 'INVALID_ARGUMENT', 'message': 'Invalid location "bad-location"'}})

http_error = HTTPError('https://foo.com', 400, 'Invalid Argument', {},
BytesIO(bytes(f'{error_message}', 'utf-8')))

self.http_service_mock.send.side_effect = http_error

with self.assertRaises(SilentlyExitError), \
patch('sys.stdout', new_callable=StringIO) as out, \
patch('sys.stderr', new_callable=StringIO) as err:
self.firebase_management_rest_service.rtdb_instance_create(
'db-name', 'bad-location')

self.assertIn('Invalid location', err.getvalue())
self.assertIn((
'One potential reason for this is if an invalid location was specified, '
'valid locations for RTDBs can be found at '
'https://firebase.google.com/docs/projects/locations#rtdb-locations.'),
err.getvalue())

self.assertEqual('', out.getvalue())

def test_rtdb_instance_create_raises_on_400_non_special_type(self):
http_error = HTTPError('https://foo.com', 400, 'Internal Server Error', {},
BytesIO(b'Fake Error Message'))
self.http_service_mock.send.side_effect = http_error
Expand Down Expand Up @@ -357,7 +382,7 @@ def test_rtdb_instance_create_raises_on_non_400(self):

def test_rtdb_instance_create_returns_expected_value_on_success(self):
self.http_service_mock.send.return_value = {
'name': 'db-name',
'name': DB_NAME,
'project': 'project-name',
'databaseUrl': 'project-default-rtdb.firebaseio.com',
'type': 'DEFAULT_DATABASE',
Expand All @@ -369,7 +394,7 @@ def test_rtdb_instance_create_returns_expected_value_on_success(self):
'db-name', 'us-central1')

self.assertEqual(DatabaseCreateStatus.SUCCESS, obtained_response.status)
self.assertEqual('db-name', obtained_response.database_instance.name)
self.assertEqual(DB_NAME, obtained_response.database_instance.name)
self.assertEqual('project-name',
obtained_response.database_instance.project)
self.assertEqual('project-default-rtdb.firebaseio.com',
Expand Down
58 changes: 58 additions & 0 deletions snapshot_dbg_cli_tests/test_firebase_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Unit test file for the firebase_types module.
"""

import unittest
from snapshot_dbg_cli.firebase_types import DatabaseInstance


class DatabaseInstanceTests(unittest.TestCase):
""" Contains the unit tests for the DatabaseInstance class.
"""

def test_location_success(self):
locations = ['us-central1', 'europe-west1']
for location in locations:
with self.subTest(location):
db_instance = DatabaseInstance({
'name': f'projects/1111111111/locations/{location}/instances/foo-cdbg',
'project': 'projects/1111111111',
'databaseUrl': 'https://foo-cdbg.firebaseio.com',
'type': 'USER_DATABASE',
'state': 'ACTIVE'
})
self.assertEqual(location, db_instance.location)

def test_location_could_not_be_found(self):
invalid_names = [
'',
'foo',
# missing the /locations/
'projects/1111111111/us-central1/instances/foo-cdbg',
'projects/1111111111/locations',
'projects/1111111111/locations/'
]

for full_db_name in invalid_names:
with self.subTest(full_db_name):
with self.assertRaises(ValueError) as ctxt:
DatabaseInstance({
'name': full_db_name,
'project': 'projects/1111111111',
'databaseUrl': 'https://foo-cdbg.firebaseio.com',
'type': 'USER_DATABASE',
'state': 'ACTIVE'
})
self.assertEqual(f"Failed to extract location from project name '{full_db_name}'", str(ctxt.exception))
Loading

0 comments on commit 35f9ce4

Please sign in to comment.