-
Notifications
You must be signed in to change notification settings - Fork 340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve drift detection perf, tests, and add some typehints #1186
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
import logging | ||
import os | ||
from typing import List | ||
from typing import Union | ||
|
||
from marshmallow import ValidationError | ||
|
||
from cartography.driftdetect.config import GetDriftConfig | ||
from cartography.driftdetect.model import State | ||
from cartography.driftdetect.reporter import report_drift | ||
from cartography.driftdetect.serializers import ShortcutSchema | ||
from cartography.driftdetect.serializers import StateSchema | ||
|
@@ -12,7 +16,7 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def run_drift_detection(config): | ||
def run_drift_detection(config: GetDriftConfig) -> None: | ||
try: | ||
if not valid_directory(config.query_directory): | ||
logger.error("Invalid Drift Detection Directory") | ||
|
@@ -59,7 +63,7 @@ def run_drift_detection(config): | |
logger.exception(msg) | ||
|
||
|
||
def perform_drift_detection(start_state, end_state): | ||
def perform_drift_detection(start_state: State, end_state: State): | ||
""" | ||
Returns differences (additions and missing results) between two States. | ||
|
||
|
@@ -81,7 +85,7 @@ def perform_drift_detection(start_state, end_state): | |
return new_results, missing_results | ||
|
||
|
||
def compare_states(start_state, end_state): | ||
def compare_states(start_state: State, end_state: State): | ||
""" | ||
Helper function for comparing differences between two States. | ||
|
||
|
@@ -92,10 +96,12 @@ def compare_states(start_state, end_state): | |
:return: list of tuples of differences between states in the form (dictionary, State object) | ||
""" | ||
differences = [] | ||
# Use set for faster membership check | ||
start_state_results = {tuple(res) for res in start_state.results} | ||
for result in end_state.results: | ||
if result in start_state.results: | ||
if tuple(result) in start_state_results: | ||
continue | ||
drift = [] | ||
drift: List[Union[str, List[str]]] = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not a fan of how the list can have more than one type (and mypy isn't particularly happy about this either), but I don't want to break anything. |
||
for field in result: | ||
value = field.split("|") | ||
if len(value) > 1: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,18 @@ | ||
import logging | ||
import os.path | ||
import time | ||
from typing import Any | ||
from typing import Dict | ||
from typing import List | ||
|
||
import neo4j.exceptions | ||
from marshmallow import ValidationError | ||
from neo4j import GraphDatabase | ||
|
||
from cartography.client.core.tx import read_list_of_dicts_tx | ||
from cartography.driftdetect.add_shortcut import add_shortcut | ||
from cartography.driftdetect.config import UpdateConfig | ||
from cartography.driftdetect.model import State | ||
from cartography.driftdetect.serializers import ShortcutSchema | ||
from cartography.driftdetect.serializers import StateSchema | ||
from cartography.driftdetect.storage import FileSystem | ||
|
@@ -15,7 +21,7 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def run_get_states(config): | ||
def run_get_states(config: UpdateConfig) -> None: | ||
""" | ||
Handles neo4j errors and then updates detectors. | ||
|
||
|
@@ -90,7 +96,13 @@ def run_get_states(config): | |
logger.exception(err) | ||
|
||
|
||
def get_query_state(session, query_directory, state_serializer, storage, filename): | ||
def get_query_state( | ||
session: neo4j.Session, | ||
query_directory: str, | ||
state_serializer: StateSchema, | ||
storage, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not adding a type hint here because it seems like we're doing this duck-typing thing where we expect the … |
||
filename: str, | ||
) -> State: | ||
""" | ||
Gets the most recent state of a query. | ||
|
||
|
@@ -115,7 +127,7 @@ def get_query_state(session, query_directory, state_serializer, storage, filenam | |
return state | ||
|
||
|
||
def get_state(session, state): | ||
def get_state(session: neo4j.Session, state: State) -> None: | ||
""" | ||
Connects to a neo4j session, runs the validation query, then saves the results to a state. | ||
|
||
|
@@ -126,10 +138,14 @@ def get_state(session, state): | |
:return: | ||
""" | ||
|
||
new_results = session.run(state.validation_query) | ||
new_results: List[Dict[str, Any]] = session.read_transaction( | ||
read_list_of_dicts_tx, | ||
state.validation_query, | ||
) | ||
logger.debug(f"Updating results for {state.name}") | ||
|
||
state.properties = new_results.keys() | ||
# The keys will be the same across all items in the returned list | ||
state.properties = list(new_results[0].keys()) | ||
results = [] | ||
|
||
for record in new_results: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main change of this PR. Spend more memory for faster lookup.