Merge pull request #1582 from blacklanternsecurity/neo4j-update
Update and Optimize Neo4j
TheTechromancer authored Jul 30, 2024
2 parents c64ecbc + ef368e2 commit f810f49
Showing 14 changed files with 202 additions and 98 deletions.
74 changes: 41 additions & 33 deletions bbot/core/engine.py
@@ -16,8 +16,8 @@
 
 from bbot.core import CORE
 from bbot.errors import BBOTEngineError
-from bbot.core.helpers.misc import rand_string
 from bbot.core.helpers.async_helpers import get_event_loop
+from bbot.core.helpers.misc import rand_string, in_exception_chain
 
 
 error_sentinel = object()
@@ -41,6 +41,7 @@ class EngineBase:
     ERROR_CLASS = BBOTEngineError
 
     def __init__(self):
+        self._shutdown_status = False
         self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")
 
     def pickle(self, obj):
@@ -62,7 +63,7 @@ def unpickle(self, binary):
 
     async def _infinite_retry(self, callback, *args, **kwargs):
         interval = kwargs.pop("_interval", 10)
-        while 1:
+        while not self._shutdown_status:
             try:
                 return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval)
             except (TimeoutError, asyncio.TimeoutError):
@@ -107,7 +108,6 @@ class EngineClient(EngineBase):
     SERVER_CLASS = None
 
     def __init__(self, **kwargs):
-        self._shutdown = False
         super().__init__()
         self.name = f"EngineClient {self.__class__.__name__}"
         self.process = None
@@ -135,7 +135,7 @@ def check_error(self, message):
     async def run_and_return(self, command, *args, **kwargs):
         fn_str = f"{command}({args}, {kwargs})"
         self.log.debug(f"{self.name}: executing run-and-return {fn_str}")
-        if self._shutdown and not command == "_shutdown":
+        if self._shutdown_status and not command == "_shutdown":
             self.log.verbose(f"{self.name} has been shut down and is not accepting new tasks")
             return
         async with self.new_socket() as socket:
@@ -163,7 +163,7 @@ async def run_and_return(self, command, *args, **kwargs):
     async def run_and_yield(self, command, *args, **kwargs):
         fn_str = f"{command}({args}, {kwargs})"
         self.log.debug(f"{self.name}: executing run-and-yield {fn_str}")
-        if self._shutdown:
+        if self._shutdown_status:
             self.log.verbose("Engine has been shut down and is not accepting new tasks")
             return
         message = self.make_message(command, args=args, kwargs=kwargs)
@@ -213,14 +213,16 @@ async def send_shutdown_message(self):
         async with self.new_socket() as socket:
             # -99 == special shutdown message
             message = pickle.dumps({"c": -99})
-            await self._infinite_retry(socket.send, message)
-            while 1:
-                response = await self._infinite_retry(socket.recv)
-                response = pickle.loads(response)
-                if isinstance(response, dict):
-                    response = response.get("m", "")
-                if response == "SHUTDOWN_OK":
-                    break
+            with suppress(TimeoutError, asyncio.TimeoutError):
+                await asyncio.wait_for(socket.send(message), 0.5)
+            with suppress(TimeoutError, asyncio.TimeoutError):
+                while 1:
+                    response = await asyncio.wait_for(socket.recv(), 0.5)
+                    response = pickle.loads(response)
+                    if isinstance(response, dict):
+                        response = response.get("m", "")
+                    if response == "SHUTDOWN_OK":
+                        break
 
     def check_stop(self, message):
         if isinstance(message, dict) and len(message) == 1 and "_s" in message:
@@ -280,7 +282,7 @@ def server_process(server_class, socket_path, **kwargs):
             else:
                 asyncio.run(engine_server.worker())
         except (asyncio.CancelledError, KeyboardInterrupt, CancelledError):
-            pass
+            return
         except Exception:
             import traceback
 
@@ -306,9 +308,9 @@ async def new_socket(self):
             socket.close()
 
     async def shutdown(self):
-        self.log.debug(f"{self.name}: shutting down...")
-        if not self._shutdown:
-            self._shutdown = True
+        if not self._shutdown_status:
+            self._shutdown_status = True
+            self.log.verbose(f"{self.name}: shutting down...")
            # send shutdown signal
            await self.send_shutdown_message()
            # then terminate context
@@ -446,6 +448,7 @@ def check_error(self, message):
             return True
 
     async def worker(self):
+        self.log.debug(f"{self.name}: starting worker")
         try:
             while 1:
                 client_id, binary = await self.socket.recv_multipart()
@@ -462,8 +465,8 @@ async def worker(self):
                 # -1 == cancel task
                 if cmd == -1:
                     self.log.debug(f"{self.name} got cancel signal")
-                    await self.cancel_task(client_id)
                     await self.send_socket_multipart(client_id, {"m": "CANCEL_OK"})
+                    await self.cancel_task(client_id)
                     continue
 
                 # -99 == shutdown task
@@ -500,24 +503,28 @@ async def worker(self):
                     task = asyncio.create_task(coroutine)
                     self.tasks[client_id] = task, command_fn, args, kwargs
                     # self.log.debug(f"{self.name}: finished creating task for {command_name}() coroutine")
-        except Exception as e:
-            self.log.error(f"{self.name}: error in EngineServer worker: {e}")
-            self.log.trace(traceback.format_exc())
+        except BaseException as e:
+            await self._shutdown()
+            if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)):
+                self.log.error(f"{self.name}: error in EngineServer worker: {e}")
+                self.log.trace(traceback.format_exc())
         finally:
             self.log.debug(f"{self.name}: finished worker()")
 
     async def _shutdown(self):
-        self.log.debug(f"{self.name}: shutting down...")
-        await self.cancel_all_tasks()
-        try:
-            self.context.destroy(linger=0)
-        except Exception:
-            self.log.trace(traceback.format_exc())
-        try:
-            self.context.term()
-        except Exception:
-            self.log.trace(traceback.format_exc())
-        self.log.debug(f"{self.name}: finished shutting down")
+        if not self._shutdown_status:
+            self.log.verbose(f"{self.name}: shutting down...")
+            self._shutdown_status = True
+            await self.cancel_all_tasks()
+            try:
+                self.context.destroy(linger=0)
+            except Exception:
+                self.log.trace(traceback.format_exc())
+            try:
+                self.context.term()
+            except Exception:
+                self.log.trace(traceback.format_exc())
+            self.log.debug(f"{self.name}: finished shutting down")
 
     def new_child_task(self, client_id, coro):
         task = asyncio.create_task(coro)
@@ -554,8 +561,9 @@ async def _cancel_task(self, task):
             await asyncio.wait_for(task, timeout=10)
         except (TimeoutError, asyncio.TimeoutError):
             self.log.debug(f"{self.name}: Timeout cancelling task")
+            return
         except (KeyboardInterrupt, asyncio.CancelledError):
-            pass
+            return
         except BaseException as e:
             self.log.error(f"Unhandled error in {task.get_coro().__name__}(): {e}")
             self.log.trace(traceback.format_exc())
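
The worker's new BaseException handler relies on in_exception_chain() to tell deliberate shutdowns (Ctrl+C, task cancellation) apart from real errors, even when asyncio has wrapped the original exception in another one. A minimal sketch of such a helper, assuming it simply walks the __cause__/__context__ chain (the real implementation lives in bbot/core/helpers/misc.py and may differ):

    import asyncio

    # Hypothetical stand-in for bbot.core.helpers.misc.in_exception_chain():
    # walk the exception chain and report whether any link matches exc_types.
    def in_exception_chain(e, exc_types):
        while e is not None:
            if isinstance(e, exc_types):
                return True
            # follow both explicit ("raise X from Y") and implicit chaining
            e = e.__cause__ or e.__context__
        return False

    try:
        try:
            raise asyncio.CancelledError()
        except asyncio.CancelledError as inner:
            raise RuntimeError("wrapped during cleanup") from inner
    except BaseException as e:
        assert in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError))

This is why the handler can call await self._shutdown() first and only log an error when the chain contains something other than a cancellation.
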
2 changes: 1 addition & 1 deletion bbot/modules/base.py
@@ -1478,7 +1478,7 @@ async def _worker(self):
                         self.scan.stats.event_consumed(event, self)
                     self.debug(f"Intercepting {event}")
                     async with self.scan._acatch(context), self._task_counter.count(context):
-                        forward_event = await self.handle_event(event, kwargs)
+                        forward_event = await self.handle_event(event, **kwargs)
                     with suppress(ValueError, TypeError):
                         forward_event, forward_event_reason = forward_event
 
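
This two-character fix matters because intercept modules now declare handle_event(self, event, **kwargs) (see the cloudcheck and dnsresolve hunks below): passing the dict positionally raises a TypeError, while **kwargs expands it into keyword arguments. A standalone illustration (not BBOT code):

    import asyncio

    # The callee accepts keyword arguments only, so the call site must unpack the dict.
    async def handle_event(event, **kwargs):
        return kwargs.get("forward", True)

    async def main():
        kwargs = {"forward": False}
        # await handle_event("EVENT", kwargs)   # TypeError: takes 1 positional argument but 2 were given
        print(await handle_event("EVENT", **kwargs))  # False

    asyncio.run(main())
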
2 changes: 1 addition & 1 deletion bbot/modules/internal/cloudcheck.py
@@ -21,7 +21,7 @@ async def filter_event(self, event):
             return False, "event does not have host attribute"
         return True
 
-    async def handle_event(self, event, kwargs):
+    async def handle_event(self, event, **kwargs):
         # don't hold up the event loop loading cloud IPs etc.
         if self.dummy_modules is None:
             self.make_dummy_modules()
2 changes: 1 addition & 1 deletion bbot/modules/internal/dnsresolve.py
@@ -63,7 +63,7 @@ async def filter_event(self, event):
             return False, "event does not have host attribute"
         return True
 
-    async def handle_event(self, event, kwargs):
+    async def handle_event(self, event, **kwargs):
         dns_tags = set()
         dns_children = dict()
         event_whitelisted = False
16 changes: 8 additions & 8 deletions bbot/modules/output/base.py
@@ -19,6 +19,7 @@ def human_event_str(self, event):
         return event_str
 
     def _event_precheck(self, event):
+        reason = "precheck succeeded"
         # special signal event types
         if event.type in ("FINISHED",):
             return True, "its type is FINISHED"
@@ -42,24 +43,23 @@ def _event_precheck(self, event):
         if event.type.startswith("URL") and self.name != "httpx" and "httpx-only" in event.tags:
             return False, (f"Omitting {event} from output because it's marked as httpx-only")
 
-        if event._omit:
-            return False, "_omit is True"
-
-        # omit certain event types
-        if event.type in self.scan.omitted_event_types:
+        if event._omit:
             if "target" in event.tags:
-                self.debug(f"Allowing omitted event: {event} because it's a target")
+                reason = "it's a target"
+                self.debug(f"Allowing omitted event: {event} because {reason}")
             elif event.type in self.get_watched_events():
-                self.debug(f"Allowing omitted event: {event} because its type is explicitly in watched_events")
+                reason = "its type is explicitly in watched_events"
+                self.debug(f"Allowing omitted event: {event} because {reason}")
             else:
-                return False, "its type is omitted in the config"
+                return False, "_omit is True"
 
         # internal events like those from speculate, ipneighbor
         # or events that are over our report distance
         if event._internal:
             return False, "_internal is True"
 
-        return True, "precheck succeeded"
+        return True, reason
 
     async def _event_postcheck(self, event):
         acceptable, reason = await super()._event_postcheck(event)
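
The net effect: the per-type scan.omitted_event_types lookup is replaced by the event's own _omit flag, omitted events that are targets or explicitly watched types survive, and the returned reason now records why an omitted event was allowed. A condensed, standalone restatement of the new branch (illustrative, not the module code):

    # Condensed sketch of the new _omit handling in _event_precheck():
    def precheck_omit(event, watched_events):
        if event._omit:
            if "target" in event.tags:
                return True, "it's a target"
            if event.type in watched_events:
                return True, "its type is explicitly in watched_events"
            return False, "_omit is True"
        return True, "precheck succeeded"
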
110 changes: 86 additions & 24 deletions bbot/modules/output/neo4j.py
@@ -34,6 +34,7 @@ class neo4j(BaseOutputModule):
         "password": "Neo4j password",
     }
     deps_pip = ["neo4j"]
+    _batch_size = 500
     _preserve_graph = True
 
     async def setup(self):
@@ -51,32 +52,93 @@ async def setup(self):
             return False, f"Error setting up Neo4j: {e}"
         return True
 
-    async def handle_event(self, event):
-        # create events
-        src_id = await self.merge_event(event.get_parent(), id_only=True)
-        dst_id = await self.merge_event(event)
-        # create relationship
-        cypher = f"""
-        MATCH (a) WHERE id(a) = $src_id
-        MATCH (b) WHERE id(b) = $dst_id
-        MERGE (a)-[_:{event.module}]->(b)
-        SET _.timestamp = $timestamp"""
-        await self.session.run(cypher, src_id=src_id, dst_id=dst_id, timestamp=event.timestamp)
-
-    async def merge_event(self, event, id_only=False):
+    async def handle_batch(self, *all_events):
+        await self.helpers.sleep(5)
+        # group events by type, since cypher doesn't allow dynamic labels
+        events_by_type = {}
+        parents_by_type = {}
+        relationships = []
+        for event in all_events:
+            parent = event.get_parent()
+            try:
+                events_by_type[event.type].append(event)
+            except KeyError:
+                events_by_type[event.type] = [event]
+            try:
+                parents_by_type[parent.type].append(parent)
+            except KeyError:
+                parents_by_type[parent.type] = [parent]
+
+            module = str(event.module)
+            timestamp = event.timestamp
+            relationships.append((parent, module, timestamp, event))
+
+        all_ids = {}
+        for event_type, events in events_by_type.items():
+            self.debug(f"{len(events):,} events of type {event_type}")
+            all_ids.update(await self.merge_events(events, event_type))
+        for event_type, parents in parents_by_type.items():
+            self.debug(f"{len(parents):,} parents of type {event_type}")
+            all_ids.update(await self.merge_events(parents, event_type, id_only=True))
+
+        rel_ids = []
+        for parent, module, timestamp, event in relationships:
+            try:
+                src_id = all_ids[parent.id]
+                dst_id = all_ids[event.id]
+            except KeyError as e:
+                self.critical(f'Error "{e}" correlating {parent.id}:{parent.data} --> {event.id}:{event.data}')
+                continue
+            rel_ids.append((src_id, module, timestamp, dst_id))
+
+        await self.merge_relationships(rel_ids)
+
+    async def merge_events(self, events, event_type, id_only=False):
         if id_only:
-            eventdata = {"type": event.type, "id": event.id}
+            insert_data = [{"data": str(e.data), "type": e.type, "id": e.id} for e in events]
         else:
-            eventdata = event.json(mode="graph")
-            # we pop the timestamp because it belongs on the relationship
-            eventdata.pop("timestamp")
-        cypher = f"""MERGE (_:{event.type} {{ id: $eventdata['id'] }})
-        SET _ += $eventdata
-        RETURN id(_)"""
-        # insert event
-        result = await self.session.run(cypher, eventdata=eventdata)
-        # get Neo4j id
-        return (await result.single()).get("id(_)")
+            insert_data = []
+            for e in events:
+                event_json = e.json(mode="graph")
+                # we pop the timestamp because it belongs on the relationship
+                event_json.pop("timestamp")
+                # nested data types aren't supported in neo4j
+                event_json.pop("dns_children", None)
+                insert_data.append(event_json)
+
+        cypher = f"""UNWIND $events AS event
+        MERGE (_:{event_type} {{ id: event.id }})
+        SET _ += event
+        RETURN event.data as event_data, event.id as event_id, elementId(_) as neo4j_id"""
+        # insert events
+        results = await self.session.run(cypher, events=insert_data)
+        # get Neo4j ids
+        neo4j_ids = {}
+        for result in await results.data():
+            event_id = result["event_id"]
+            neo4j_id = result["neo4j_id"]
+            neo4j_ids[event_id] = neo4j_id
+        return neo4j_ids
+
+    async def merge_relationships(self, relationships):
+        rels_by_module = {}
+        # group by module
+        for src_id, module, timestamp, dst_id in relationships:
+            data = {"src_id": src_id, "timestamp": timestamp, "dst_id": dst_id}
+            try:
+                rels_by_module[module].append(data)
+            except KeyError:
+                rels_by_module[module] = [data]
+
+        for module, rels in rels_by_module.items():
+            self.debug(f"{len(rels):,} relationships of type {module}")
+            cypher = f"""
+            UNWIND $rels AS rel
+            MATCH (a) WHERE elementId(a) = rel.src_id
+            MATCH (b) WHERE elementId(b) = rel.dst_id
+            MERGE (a)-[_:{module}]->(b)
+            SET _.timestamp = rel.timestamp"""
+            await self.session.run(cypher, rels=rels)
 
     async def cleanup(self):
         with suppress(Exception):
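
This rewrite turns one round-trip per event into one UNWIND ... MERGE statement per label, with _batch_size = 500 controlling how many events accumulate before handle_batch() fires; events are grouped by type because Cypher does not allow parameterized labels. A self-contained sketch of the same pattern against the bare neo4j async driver (the URI and credentials are placeholders, and the DNS_NAME label is just an example):

    import asyncio

    from neo4j import AsyncGraphDatabase  # pip install neo4j

    # Standalone sketch of the UNWIND+MERGE batching pattern used above.
    # The label is interpolated into the query text because Cypher labels
    # cannot be passed as parameters; all other values go in as parameters.
    async def main():
        driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
        events = [{"id": f"DNS_NAME:{i}", "data": f"host{i}.example.com", "type": "DNS_NAME"} for i in range(500)]
        async with driver.session() as session:
            cypher = """UNWIND $events AS event
            MERGE (_:DNS_NAME { id: event.id })
            SET _ += event
            RETURN event.id AS event_id, elementId(_) AS neo4j_id"""
            results = await session.run(cypher, events=events)
            # map event ids to Neo4j element ids, as merge_events() does
            neo4j_ids = {r["event_id"]: r["neo4j_id"] for r in await results.data()}
            print(f"merged {len(neo4j_ids):,} nodes")
        await driver.close()

    asyncio.run(main())

Relationships are batched the same way, grouped by module so each UNWIND uses a static relationship type, and matched via elementId(), the Neo4j 5 replacement for the deprecated id() used in the old code.
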
8 more changed files not shown.
