Skip to content
This repository has been archived by the owner on Jan 23, 2024. It is now read-only.

feat: Support default-rtdb instance. #78

Merged
merged 5 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions src/googleclouddebugger/firebase_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self):
self._breakpoint_subscription = None

# Events for unit testing.
self.connection_complete = threading.Event()
self.registration_complete = threading.Event()
self.subscription_complete = threading.Event()

Expand Down Expand Up @@ -193,7 +194,10 @@ def SetupAuth(self,
service_account_json_file: JSON file to use for credentials. If not
provided, will default to application default credentials.
database_url: Firebase realtime database URL to be used. If not
provided, will default to https://{project_id}-cdbg.firebaseio.com
provided, connect attempts to the following DBs will be made, in
order:
https://{project_id}-cdbg.firebaseio.com
https://{project_id}-default-rtdb.firebaseio.com
Raises:
NoProjectIdError: If the project id cannot be determined.
"""
Expand All @@ -220,11 +224,7 @@ def SetupAuth(self,
'Please specify the project id using the --project_id flag.')

self._project_id = project_id

if database_url:
self._database_url = database_url
else:
self._database_url = f'https://{self._project_id}-cdbg.firebaseio.com'
self._database_url = database_url

def Start(self):
"""Starts the worker thread."""
Expand Down Expand Up @@ -287,15 +287,11 @@ def _MainThreadProc(self):
which will run in its own thread. That thread will be owned by
self._breakpoint_subscription.
"""
# Note: if self._credentials is None, default app credentials will be used.
try:
firebase_admin.initialize_app(self._credentials,
{'databaseURL': self._database_url})
except ValueError:
native.LogWarning(
f'Failed to initialize firebase: {traceback.format_exc()}')
native.LogError('Failed to start debugger agent. Giving up.')
return
connection_required, delay = True, 0
while connection_required:
time.sleep(delay)
connection_required, delay = self._ConnectToDb()
self.connection_complete.set()

registration_required, delay = True, 0
while registration_required:
Expand Down Expand Up @@ -343,6 +339,40 @@ def _StartMarkActiveTimer(self):
self._MarkActiveTimerFunc)
self._mark_active_timer.start()

def _ConnectToDb(self):
urls = [self._database_url] if self._database_url is not None else \
[f'https://{self._project_id}-cdbg.firebaseio.com',
f'https://{self._project_id}-default-rtdb.firebaseio.com']

for url in urls:
native.LogInfo(f'Attempting to connect to DB with url: {url}')
if self._TryInitializeDbForUrl(url):
native.LogInfo(f'Successfully connected to DB with url: {url}')
self._database_url = url
return (False, 0) # Proceed immediately to registering the debuggee.

return (True, self.register_backoff.Failed())
jasonborg marked this conversation as resolved.
Show resolved Hide resolved

def _TryInitializeDbForUrl(self, database_url):
# Note: if self._credentials is None, default app credentials will be used.
app = None
try:
app = firebase_admin.initialize_app(self._credentials,
jasonborg marked this conversation as resolved.
Show resolved Hide resolved
{'databaseURL': database_url})

if self._CheckSchemaVersionPresence():
return True

except ValueError:
native.LogWarning(
f'Failed to initialize firebase: {traceback.format_exc()}')

# This is the failure path, if we hit here we must cleanup the app handle
if app is not None:
firebase_admin.delete_app(app)

return False

def _RegisterDebuggee(self):
"""Single attempt to register the debuggee.

Expand Down Expand Up @@ -371,7 +401,7 @@ def _RegisterDebuggee(self):
else:
debuggee_path = f'cdbg/debuggees/{self._debuggee_id}'
native.LogInfo(
f'registering at {self._database_url}, path: {debuggee_path}')
f'Registering at {self._database_url}, path: {debuggee_path}')
debuggee_data = copy.deepcopy(debuggee)
debuggee_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'}
debuggee_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'}
Expand All @@ -388,6 +418,17 @@ def _RegisterDebuggee(self):
native.LogInfo(f'Failed to register debuggee: {repr(e)}')
return (True, self.register_backoff.Failed())

def _CheckSchemaVersionPresence(self):
path = f'cdbg/schema_version'
try:
snapshot = firebase_admin.db.reference(path).get()
# The value doesn't matter; just return true if there's any value.
return snapshot is not None
except BaseException as e:
native.LogInfo(
f'Failed to check schema version presence at {path}: {repr(e)}')
return False

def _CheckDebuggeePresence(self):
path = f'cdbg/debuggees/{self._debuggee_id}/registrationTimeUnixMsec'
try:
Expand Down
126 changes: 108 additions & 18 deletions tests/py/firebase_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import firebase_admin.credentials
from firebase_admin.exceptions import FirebaseError
from firebase_admin.exceptions import NotFoundError

TEST_PROJECT_ID = 'test-project-id'
METADATA_PROJECT_URL = ('http://metadata.google.internal/computeMetadata/'
Expand Down Expand Up @@ -73,19 +74,32 @@ def setUp(self):
self._mock_initialize_app = patcher.start()
self.addCleanup(patcher.stop)

patcher = patch('firebase_admin.delete_app')
self._mock_delete_app = patcher.start()
self.addCleanup(patcher.stop)

patcher = patch('firebase_admin.db.reference')
self._mock_db_ref = patcher.start()
self.addCleanup(patcher.stop)

# Set up the mocks for the database refs.
self._mock_schema_version_ref = MagicMock()
self._mock_schema_version_ref.get.return_value = "2"
self._mock_presence_ref = MagicMock()
self._mock_presence_ref.get.return_value = None
self._mock_active_ref = MagicMock()
self._mock_register_ref = MagicMock()
self._fake_subscribe_ref = FakeReference()


# Setup common happy path reference sequence:
# cdbg/schema_version
# cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec
# cdbg/debuggees/{debuggee_id}
# cdbg/breakpoints/{debuggee_id}/active
self._mock_db_ref.side_effect = [
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref
self._mock_schema_version_ref, self._mock_presence_ref,
self._mock_register_ref, self._fake_subscribe_ref
]

def tearDown(self):
Expand All @@ -100,17 +114,13 @@ def testSetupAuthDefault(self):
self._client.SetupAuth()

self.assertEqual(TEST_PROJECT_ID, self._client._project_id)
self.assertEqual(f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com',
self._client._database_url)

def testSetupAuthOverrideProjectIdNumber(self):
# If a project id is provided, we use it.
project_id = 'project2'
self._client.SetupAuth(project_id=project_id)

self.assertEqual(project_id, self._client._project_id)
self.assertEqual(f'https://{project_id}-cdbg.firebaseio.com',
self._client._database_url)

def testSetupAuthServiceAccountJsonAuth(self):
# We'll load credentials from the provided file (mocked for simplicity)
Expand Down Expand Up @@ -144,6 +154,7 @@ def testStart(self):
self._mock_initialize_app.assert_called_with(
None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'})
self.assertEqual([
call(f'cdbg/schema_version'),
call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'),
call(f'cdbg/debuggees/{debuggee_id}'),
call(f'cdbg/breakpoints/{debuggee_id}/active')
Expand All @@ -155,13 +166,87 @@ def testStart(self):
expected_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'}
self._mock_register_ref.set.assert_called_once_with(expected_data)

def testStartCustomDbUrlConfigured(self):
self._client.SetupAuth(
project_id=TEST_PROJECT_ID,
database_url='https://custom-db.firebaseio.com')
self._client.Start()
self._client.connection_complete.wait()

debuggee_id = self._client._debuggee_id

self._mock_initialize_app.assert_called_once_with(
None, {'databaseURL': 'https://custom-db.firebaseio.com'})

def testStartConnectFallsBackToDefaultRtdb(self):
# A new schema_version ref will be fetched each time
self._mock_db_ref.side_effect = [
self._mock_schema_version_ref, self._mock_schema_version_ref,
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref
]

# Fail on the '-cdbg' instance test, succeed on the '-default-rtdb' one.
self._mock_schema_version_ref.get.side_effect = [
NotFoundError("Not found", http_response=404), '2'
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
self._client.Start()
self._client.connection_complete.wait()

self.assertEqual([
call(None,
{'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}),
call(None, {
'databaseURL':
f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com'
})
], self._mock_initialize_app.call_args_list)

self.assertEqual(1, self._mock_delete_app.call_count)

def testStartConnectFailsThenSucceeds(self):
# A new schema_version ref will be fetched each time
self._mock_db_ref.side_effect = [
self._mock_schema_version_ref, self._mock_schema_version_ref,
self._mock_schema_version_ref, self._mock_presence_ref,
self._mock_register_ref, self._fake_subscribe_ref
]

# Completely fail on the initial attempt at reaching a DB, then succeed on
# 2nd attempt. One full attempt will try the '-cdbg' db instance followed by
# the '-default-rtdb' one.
self._mock_schema_version_ref.get.side_effect = [
NotFoundError("Not found", http_response=404),
NotFoundError("Not found", http_response=404), '2'
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
self._client.Start()
self._client.connection_complete.wait()

self.assertEqual([
call(None,
{'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}),
call(None, {
'databaseURL':
f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com'
}),
call(None,
{'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'})
], self._mock_initialize_app.call_args_list)

self.assertEqual(2, self._mock_delete_app.call_count)

def testStartAlreadyPresent(self):
# Create a mock for just this test that claims the debuggee is registered.
mock_presence_ref = MagicMock()
mock_presence_ref.get.return_value = 'present!'

self._mock_db_ref.side_effect = [
mock_presence_ref, self._mock_active_ref, self._fake_subscribe_ref
self._mock_schema_version_ref, mock_presence_ref, self._mock_active_ref,
self._fake_subscribe_ref
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
Expand All @@ -171,6 +256,7 @@ def testStartAlreadyPresent(self):
debuggee_id = self._client._debuggee_id

self.assertEqual([
call(f'cdbg/schema_version'),
call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'),
call(f'cdbg/debuggees/{debuggee_id}/lastUpdateTimeUnixMsec'),
call(f'cdbg/breakpoints/{debuggee_id}/active')
Expand All @@ -182,6 +268,7 @@ def testStartAlreadyPresent(self):
def testStartRegisterRetry(self):
# A new set of db refs are fetched on each retry.
self._mock_db_ref.side_effect = [
self._mock_schema_version_ref,
self._mock_presence_ref, self._mock_register_ref,
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref
Expand All @@ -202,6 +289,7 @@ def testStartSubscribeRetry(self):

# A new db ref is fetched on each retry.
self._mock_db_ref.side_effect = [
self._mock_schema_version_ref,
self._mock_presence_ref,
self._mock_register_ref,
mock_subscribe_ref, # Fail the first time
Expand All @@ -212,7 +300,7 @@ def testStartSubscribeRetry(self):
self._client.Start()
self._client.subscription_complete.wait()

self.assertEqual(4, self._mock_db_ref.call_count)
self.assertEqual(5, self._mock_db_ref.call_count)

def testMarkActiveTimer(self):
# Make sure that there are enough refs queued up.
Expand Down Expand Up @@ -310,9 +398,9 @@ def testEnqueueBreakpointUpdate(self):
final_ref_mock = MagicMock()

self._mock_db_ref.side_effect = [
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref, active_ref_mock, snapshot_ref_mock,
final_ref_mock
self._mock_schema_version_ref, self._mock_presence_ref,
self._mock_register_ref, self._fake_subscribe_ref,
active_ref_mock, snapshot_ref_mock, final_ref_mock
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
Expand Down Expand Up @@ -370,13 +458,13 @@ def testEnqueueBreakpointUpdate(self):
db_ref_calls = self._mock_db_ref.call_args_list
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'),
db_ref_calls[3])
db_ref_calls[4])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}'),
db_ref_calls[4])
db_ref_calls[5])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'),
db_ref_calls[5])
db_ref_calls[6])

active_ref_mock.delete.assert_called_once()
snapshot_ref_mock.set.assert_called_once_with(full_breakpoint)
Expand All @@ -387,8 +475,9 @@ def testEnqueueBreakpointUpdateWithLogpoint(self):
final_ref_mock = MagicMock()

self._mock_db_ref.side_effect = [
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref, active_ref_mock, final_ref_mock
self._mock_schema_version_ref, self._mock_presence_ref,
self._mock_register_ref, self._fake_subscribe_ref,
active_ref_mock, final_ref_mock
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
Expand Down Expand Up @@ -437,10 +526,10 @@ def testEnqueueBreakpointUpdateWithLogpoint(self):
db_ref_calls = self._mock_db_ref.call_args_list
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'),
db_ref_calls[3])
db_ref_calls[4])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'),
db_ref_calls[4])
db_ref_calls[5])

active_ref_mock.delete.assert_called_once()
final_ref_mock.set.assert_called_once_with(output_breakpoint)
Expand Down Expand Up @@ -468,6 +557,7 @@ def testEnqueueBreakpointUpdateRetry(self):
]

self._mock_db_ref.side_effect = [
self._mock_schema_version_ref,
self._mock_presence_ref,
self._mock_register_ref,
self._fake_subscribe_ref, # setup
Expand Down