Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve drift detection perf, tests, and add some typehints #1186

Merged
merged 3 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions cartography/driftdetect/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional


class UpdateConfig:
"""
A common interface for the drift-detection update configuration.
Expand All @@ -18,10 +21,10 @@ class UpdateConfig:

def __init__(
self,
drift_detection_directory,
neo4j_uri,
neo4j_user=None,
neo4j_password=None,
drift_detection_directory: str,
neo4j_uri: str,
neo4j_user: Optional[str] = None,
neo4j_password: Optional[str] = None,
):
self.neo4j_uri = neo4j_uri
self.neo4j_user = neo4j_user
Expand All @@ -46,13 +49,13 @@ class GetDriftConfig:

def __init__(
self,
query_directory,
start_state,
end_state,
query_directory: str,
start_state: str,
end_state: str,
):
self.query_directory = query_directory
self.start_state = start_state
self.end_state = end_state
self.query_directory: str = query_directory
self.start_state: str = start_state
self.end_state: str = end_state


class AddShortcutConfig:
Expand All @@ -72,10 +75,10 @@ class AddShortcutConfig:

def __init__(
self,
query_directory,
shortcut,
filename,
query_directory: str,
shortcut: str,
filename: str,
):
self.query_directory = query_directory
self.shortcut = shortcut
self.filename = filename
self.query_directory: str = query_directory
self.shortcut: str = shortcut
self.filename: str = filename
16 changes: 11 additions & 5 deletions cartography/driftdetect/detect_deviations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
import os
from typing import List
from typing import Union

from marshmallow import ValidationError

from cartography.driftdetect.config import GetDriftConfig
from cartography.driftdetect.model import State
from cartography.driftdetect.reporter import report_drift
from cartography.driftdetect.serializers import ShortcutSchema
from cartography.driftdetect.serializers import StateSchema
Expand All @@ -12,7 +16,7 @@
logger = logging.getLogger(__name__)


def run_drift_detection(config):
def run_drift_detection(config: GetDriftConfig) -> None:
try:
if not valid_directory(config.query_directory):
logger.error("Invalid Drift Detection Directory")
Expand Down Expand Up @@ -59,7 +63,7 @@ def run_drift_detection(config):
logger.exception(msg)


def perform_drift_detection(start_state, end_state):
def perform_drift_detection(start_state: State, end_state: State):
"""
Returns differences (additions and missing results) between two States.

Expand All @@ -81,7 +85,7 @@ def perform_drift_detection(start_state, end_state):
return new_results, missing_results


def compare_states(start_state, end_state):
def compare_states(start_state: State, end_state: State):
"""
Helper function for comparing differences between two States.

Expand All @@ -92,10 +96,12 @@ def compare_states(start_state, end_state):
:return: list of tuples of differences between states in the form (dictionary, State object)
"""
differences = []
# Use set for faster membership check
start_state_results = {tuple(res) for res in start_state.results}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change of this PR: build a set of result tuples up front, spending more memory in exchange for faster membership lookups.

for result in end_state.results:
if result in start_state.results:
if tuple(result) in start_state_results:
continue
drift = []
drift: List[Union[str, List[str]]] = []
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of how the list can hold more than one type (and mypy isn't particularly happy about it either), but I don't want to break anything.

for field in result:
value = field.split("|")
if len(value) > 1:
Expand Down
26 changes: 21 additions & 5 deletions cartography/driftdetect/get_states.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import logging
import os.path
import time
from typing import Any
from typing import Dict
from typing import List

import neo4j.exceptions
from marshmallow import ValidationError
from neo4j import GraphDatabase

from cartography.client.core.tx import read_list_of_dicts_tx
from cartography.driftdetect.add_shortcut import add_shortcut
from cartography.driftdetect.config import UpdateConfig
from cartography.driftdetect.model import State
from cartography.driftdetect.serializers import ShortcutSchema
from cartography.driftdetect.serializers import StateSchema
from cartography.driftdetect.storage import FileSystem
Expand All @@ -15,7 +21,7 @@
logger = logging.getLogger(__name__)


def run_get_states(config):
def run_get_states(config: UpdateConfig) -> None:
"""
Handles neo4j errors and then updates detectors.

Expand Down Expand Up @@ -90,7 +96,13 @@ def run_get_states(config):
logger.exception(err)


def get_query_state(session, query_directory, state_serializer, storage, filename):
def get_query_state(
session: neo4j.Session,
query_directory: str,
state_serializer: StateSchema,
storage,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not adding a type hint here because it looks like we're relying on duck typing: we expect the storage object to have a write() function and other methods, and I don't want to restrict it to just the existing storage.FileSystem object.

filename: str,
) -> State:
"""
Gets the most recent state of a query.

Expand All @@ -115,7 +127,7 @@ def get_query_state(session, query_directory, state_serializer, storage, filenam
return state


def get_state(session, state):
def get_state(session: neo4j.Session, state: State) -> None:
"""
Connects to a neo4j session, runs the validation query, then saves the results to a state.

Expand All @@ -126,10 +138,14 @@ def get_state(session, state):
:return:
"""

new_results = session.run(state.validation_query)
new_results: List[Dict[str, Any]] = session.read_transaction(
read_list_of_dicts_tx,
state.validation_query,
)
logger.debug(f"Updating results for {state.name}")

state.properties = new_results.keys()
# The keys will be the same across all items in the returned list
state.properties = list(new_results[0].keys())
results = []

for record in new_results:
Expand Down
17 changes: 9 additions & 8 deletions cartography/driftdetect/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import List

logger = logging.getLogger(__name__)

Expand All @@ -19,13 +20,13 @@ class State:

def __init__(
self,
name,
validation_query,
properties,
results,
name: str,
validation_query: str,
properties: List[str],
results: List[List[str]],
):

self.name = name
self.validation_query = validation_query
self.properties = properties
self.results = results
self.name: str = name
self.validation_query: str = validation_query
self.properties: List[str] = properties
self.results: List[List[str]] = results
57 changes: 49 additions & 8 deletions tests/unit/driftdetect/test_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock

from cartography.client.core.tx import read_list_of_dicts_tx
from cartography.driftdetect.detect_deviations import compare_states
from cartography.driftdetect.get_states import get_state
from cartography.driftdetect.model import State
Expand All @@ -26,13 +27,13 @@ def test_state_no_drift():

mock_result.__getitem__.side_effect = results.__getitem__
mock_result.__iter__.side_effect = results.__iter__
mock_session.run.return_value = mock_result
mock_session.read_transaction.return_value = mock_result
data = FileSystem.load("tests/data/detectors/test_expectations.json")
state_old = StateSchema().load(data)
state_new = State(state_old.name, state_old.validation_query, state_old.properties, [])
get_state(mock_session, state_new)
drifts = compare_states(state_old, state_new)
mock_session.run.assert_called_with(state_new.validation_query)
mock_session.read_transaction.assert_called_with(read_list_of_dicts_tx, state_new.validation_query)
assert not drifts


Expand All @@ -54,20 +55,60 @@ def test_state_picks_up_drift():
{key: "7"},
]

# Arrange
mock_result.__getitem__.side_effect = results.__getitem__
mock_result.__iter__.side_effect = results.__iter__
mock_session.run.return_value = mock_result
mock_session.read_transaction.return_value = mock_result
data = FileSystem.load("tests/data/detectors/test_expectations.json")
state_old = StateSchema().load(data)
state_new = State(state_old.name, state_old.validation_query, state_old.properties, [])
get_state(mock_session, state_new)
state_new.properties = state_old.properties

# Act
drifts = compare_states(state_old, state_new)
mock_session.run.assert_called_with(state_new.validation_query)

# Assert
mock_session.read_transaction.assert_called_with(read_list_of_dicts_tx, state_new.validation_query)
assert drifts
assert ["7"] in drifts


def test_state_order_does_not_matter():
"""
Test that no drift is reported when the new query results come back in a
different order than before (one record is deliberately out of sequence).
:return:
"""
key = "d.test"
mock_session = MagicMock()
mock_result = MagicMock()
results = [
{key: "1"},
{key: "2"},
{key: "6"}, # This one is out of order
{key: "3"},
{key: "4"},
{key: "5"},
]

# Arrange
# Make the mock behave like the plain list for indexing and iteration.
mock_result.__getitem__.side_effect = results.__getitem__
mock_result.__iter__.side_effect = results.__iter__
mock_session.read_transaction.return_value = mock_result
data = FileSystem.load("tests/data/detectors/test_expectations.json")
state_old = StateSchema().load(data)
state_new = State(state_old.name, state_old.validation_query, state_old.properties, [])
get_state(mock_session, state_new)
# NOTE(review): get_state() appears to overwrite `properties` from the mocked
# result keys; reset them to the loaded state's properties before comparing.
state_new.properties = state_old.properties

# Act
drifts = compare_states(state_old, state_new)

# Assert
mock_session.read_transaction.assert_called_with(read_list_of_dicts_tx, state_new.validation_query)
assert not drifts


def test_state_multiple_expectations():
"""
Test that multiple fields runs properly.
Expand All @@ -89,14 +130,14 @@ def test_state_multiple_expectations():

mock_result.__getitem__.side_effect = results.__getitem__
mock_result.__iter__.side_effect = results.__iter__
mock_session.run.return_value = mock_result
mock_session.read_transaction.return_value = mock_result
data = FileSystem.load("tests/data/detectors/test_multiple_expectations.json")
state_old = StateSchema().load(data)
state_new = State(state_old.name, state_old.validation_query, state_old.properties, [])
get_state(mock_session, state_new)
state_new.properties = state_old.properties
drifts = compare_states(state_old, state_new)
mock_session.run.assert_called_with(state_new.validation_query)
mock_session.read_transaction.assert_called_with(read_list_of_dicts_tx, state_new.validation_query)
assert ["7", "14"] in drifts


Expand All @@ -121,14 +162,14 @@ def test_drift_from_multiple_properties():
]
mock_result.__getitem__.side_effect = results.__getitem__
mock_result.__iter__.side_effect = results.__iter__
mock_session.run.return_value = mock_result
mock_session.read_transaction.return_value = mock_result
data = FileSystem.load("tests/data/detectors/test_multiple_properties.json")
state_old = StateSchema().load(data)
state_new = State(state_old.name, state_old.validation_query, state_old.properties, [])
get_state(mock_session, state_new)
state_new.properties = state_old.properties
drifts = compare_states(state_old, state_new)
mock_session.run.assert_called_with(state_new.validation_query)
mock_session.read_transaction.assert_called_with(read_list_of_dicts_tx, state_new.validation_query)
assert ["7", "14", ["21", "28", "35"]] in drifts
assert ["3", "10", ["17", "24", "31"]] not in drifts

Expand Down