Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a RuleCollection object instead of a "loader" module #1063

Merged
merged 20 commits into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions detection_rules/devtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import click
from elasticsearch import Elasticsearch
from eql import load_dump
from kibana.connector import Kibana

from kibana.connector import Kibana
from . import rule_loader
from .eswrap import CollectEvents, add_range_to_dsl
from .main import root
from .misc import PYTHON_LICENSE, add_client, GithubClient, Manifest, client_error, getdefault
from .packaging import PACKAGE_FILE, Package, manage_versions, RELEASE_DIR
from .rule import TOMLRule, TOMLRuleContents, BaseQueryRuleData
from .rule_loader import get_rule
from .rule import TOMLRule, BaseQueryRuleData
from .rule_loader import production_filter, RuleCollection
from .utils import get_path, dict_hash

RULES_DIR = get_path('rules')
Expand Down Expand Up @@ -68,7 +68,7 @@ def update_lock_versions(rule_ids):
if not click.confirm('Are you sure you want to update hashes without a version bump?'):
return

rules = [r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_ids]
rules = RuleCollection.default().filter(lambda r: r.id in rule_ids)
changed, new = manage_versions(rules, exclude_version_update=True, add_new=False, save_changes=True)

if not changed:
Expand All @@ -86,10 +86,12 @@ def kibana_diff(rule_id, repo, branch, threads):
"""Diff rules against their version represented in kibana if exists."""
from .misc import get_kibana_rules

rules = RuleCollection.default()

if rule_id:
rules = {r.id: r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_id}
rules = rules.filter(lambda r: r.id in rule_id)
else:
rules = {r.id: r for r in rule_loader.get_production_rules()}
rules = rules.filter(production_filter)

# add versions to the rules
manage_versions(list(rules.values()), verbose=False)
Expand Down Expand Up @@ -388,9 +390,11 @@ def rule_event_search(ctx, rule_file, rule_id, date_range, count, max_results, v
rule: TOMLRule

if rule_id:
rule = get_rule(rule_id, verbose=False)
rule = RuleCollection().load_by_id(rule_id)
if rule is None:
client_error(f"Unable to find rule with id {rule_id}")
elif rule_file:
rule = TOMLRule(path=rule_file, contents=TOMLRuleContents.from_dict(load_dump(rule_file)))
rule = RuleCollection().load_file(rule_file)
else:
client_error('Must specify a rule file or rule ID')

Expand Down Expand Up @@ -431,18 +435,17 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
"""Survey rule counts."""
from eql.table import Table
from kibana.resources import Signal
from . import rule_loader
from .main import search_rules

survey_results = []
start_time, end_time = date_range

if query:
rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)]
rules = rule_loader.load_rules(rule_loader.load_rule_files(paths=rule_paths, verbose=False), verbose=False)
rules = rules.values()
rules = RuleCollection()
paths = [Path(r['file']) for r in ctx.invoke(search_rules, query=query, verbose=False)]
rules.load_files(paths)
else:
rules = rule_loader.load_rules(verbose=False).values()
rules = RuleCollection.default().filter(production_filter)

click.echo(f'Running survey against {len(rules)} rules')
click.echo(f'Saving detailed dump to: {dump_file}')
Expand Down
10 changes: 5 additions & 5 deletions detection_rules/eswrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import json
import os
import time
from contextlib import contextmanager
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Union

Expand All @@ -20,10 +20,9 @@
import kql
from .main import root
from .misc import add_params, client_error, elasticsearch_options
from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path
from .rule import TOMLRule
from .rule_loader import get_rule, rta_mappings

from .rule_loader import rta_mappings, RuleCollection
from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path

COLLECTION_DIR = get_path('collections')
MATCH_ALL = {'bool': {'filter': [{'match_all': {}}]}}
Expand Down Expand Up @@ -88,7 +87,8 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr
"""Evaluate a rule against collected events and update mapping."""
from .utils import combine_sources, evaluate

rule = get_rule(rule_id, verbose=False)
rule = RuleCollection.load_by_id(rule_id)
assert rule is not None, f"Unable to find rule with ID {rule_id}"
merged_events = combine_sources(*self.events.values())
filtered = evaluate(rule, merged_events)

Expand Down
29 changes: 17 additions & 12 deletions detection_rules/kbwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
# 2.0.

"""Kibana cli commands."""
import uuid
from pathlib import Path

import click

import kql
from kibana import Kibana, Signal, RuleResource

from .main import root
from .misc import add_params, client_error, kibana_options
from .rule_loader import load_rule_files, load_rules
from .rule_loader import RuleCollection
from .schemas import downgrade
from .utils import format_command_options


Expand Down Expand Up @@ -54,25 +58,26 @@ def kibana_group(ctx: click.Context, **kibana_kwargs):
@click.pass_context
def upload_rule(ctx, toml_files, replace_id):
"""Upload a list of rule .toml files to Kibana."""
from .packaging import manage_versions

kibana = ctx.obj['kibana']
file_lookup = load_rule_files(paths=toml_files)
rules = list(load_rules(file_lookup=file_lookup).values())

# assign the versions from etc/versions.lock.json
# rules that have changed in hash get incremented, others stay as-is.
# rules that aren't in the lookup default to version 1
manage_versions(rules, verbose=False)
rules = RuleCollection()
rules.load_files(Path(p) for p in toml_files)

api_payloads = []

for rule in rules:
try:
payload = rule.get_payload(include_version=True, replace_id=replace_id, embed_metadata=True,
target_version=kibana.version)
payload = rule.contents.to_api_format()
payload.setdefault("meta", {}).update(rule.contents.metadata.to_dict())

if replace_id:
payload["rule_id"] = str(uuid.uuid4())

payload = downgrade(payload, target_version=kibana.version)

except ValueError as e:
client_error(f'{e} in version:{kibana.version}, for rule: {rule.name}', e, ctx=ctx)

rule = RuleResource(payload)
api_payloads.append(rule)

Expand Down
119 changes: 45 additions & 74 deletions detection_rules/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,27 @@
# 2.0.

"""CLI commands for detection_rules."""
import dataclasses
import glob
import json
import os
import re
import time
from pathlib import Path
from typing import Dict
from uuid import uuid4

import click
import jsonschema

from . import rule_loader
from .cli_utils import rule_prompt
from .misc import client_error, nested_set, parse_config
from .rule import TOMLRule
from .rule_formatter import toml_write
from .rule_loader import RuleCollection
from .schemas import CurrentSchema, available_versions
from .utils import get_path, clear_caches, load_rule_contents


RULES_DIR = get_path('rules')


Expand Down Expand Up @@ -59,15 +60,14 @@ def create_rule(path, config, required_only, rule_type):
@click.pass_context
def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True):
"""Generate enriched indexes of rules, based on a KQL search, for indexing/importing into elasticsearch/kibana."""
from . import rule_loader
from .packaging import load_current_package_version, Package

if query:
rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)]
rules = rule_loader.load_rules(rule_loader.load_rule_files(paths=rule_paths, verbose=False), verbose=False)
rules = rules.values()
rules = RuleCollection()
rules.load_files(Path(p) for p in rule_paths)
else:
rules = rule_loader.load_rules(verbose=False).values()
rules = RuleCollection.default()

rule_count = len(rules)
package = Package(rules, load_current_package_version(), verbose=False)
Expand Down Expand Up @@ -117,9 +117,10 @@ def name_to_filename(name):
def toml_lint(rule_file):
"""Cleanup files with some simple toml formatting."""
if rule_file:
rules = list(rule_loader.load_rules(rule_loader.load_rule_files(paths=[rule_file])).values())
rules = RuleCollection()
rules.load_files(Path(p) for p in rule_file)
else:
rules = list(rule_loader.load_rules().values())
rules = RuleCollection.default()

# removed unneeded defaults
# TODO: we used to remove "unneeded" defaults, but this is a potentially tricky thing.
Expand All @@ -131,7 +132,6 @@ def toml_lint(rule_file):
for rule in rules:
rule.save_toml()

rule_loader.reset()
click.echo('Toml file linting complete')


Expand All @@ -146,8 +146,10 @@ def toml_lint(rule_file):
@click.pass_context
def mass_update(ctx, query, metadata, language, field):
"""Update multiple rules based on eql results."""
rules = RuleCollection().default()
results = ctx.invoke(search_rules, query=query, language=language, verbose=False)
rules = [rule_loader.get_rule(r['rule_id'], verbose=False) for r in results]
matching_ids = set(r["rule_id"] for r in results)
rules = rules.filter(lambda r: r.id in matching_ids)

for rule in rules:
for key, value in field:
Expand All @@ -161,34 +163,17 @@ def mass_update(ctx, query, metadata, language, field):


@root.command('view-rule')
@click.argument('rule-id', required=False)
@click.option('--rule-file', '-f', type=click.Path(dir_okay=False), help='Optionally view a rule from a specified file')
@click.argument('rule-file')
@click.option('--api-format/--rule-format', default=True, help='Print the rule in final api or rule format')
@click.pass_context
def view_rule(ctx, rule_id, rule_file, api_format, verbose=True):
def view_rule(ctx, rule_file, api_format):
"""View an internal rule or specified rule file."""
rule = None
rule = RuleCollection().load_file(rule_file)

if rule_id:
rule = rule_loader.get_rule(rule_id, verbose=False)
elif rule_file:
contents = {k: v for k, v in load_rule_contents(rule_file, single_only=True)[0].items() if v}

try:
rule = TOMLRule(rule_file, contents)
except jsonschema.ValidationError as e:
client_error(f'Rule: {rule_id or os.path.basename(rule_file)} failed validation', e, ctx=ctx)
if api_format:
click.echo(json.dumps(rule.contents.to_api_format(), indent=2, sort_keys=True))
else:
client_error('Unknown rule!')

if not rule:
client_error('Unknown format!')

if verbose:
click.echo(toml_write(rule.rule_format()) if not api_format else
json.dumps(rule.get_payload(), indent=2, sort_keys=True))

return rule
click.echo(toml_write(rule.contents.to_dict()))


@root.command('export-rules')
Expand All @@ -211,39 +196,35 @@ def export_rules(rule_id, rule_file, directory, outfile, replace_id, stack_versi
if not (rule_id or rule_file or directory):
client_error('Required: at least one of --rule-id, --rule-file, or --directory')

rules = RuleCollection()

if rule_id:
all_rules = {r.id: r for r in rule_loader.load_rules(verbose=False).values()}
missing = [rid for rid in rule_id if rid not in all_rules]
rule_id = set(rule_id)
rules = RuleCollection.default().filter(lambda r: r.id in rule_id)
found_ids = {rule.id for rule in rules}
missing = rule_id.difference(found_ids)

if missing:
client_error(f'Unknown rules for rule IDs: {", ".join(missing)}')

rules = [r for r in all_rules.values() if r.id in rule_id]
rule_ids = [r.id for r in rules]
else:
rules = []
rule_ids = []

rule_files = list(rule_file)
for dirpath in directory:
rule_files.extend(list(Path(dirpath).rglob('*.toml')))

file_lookup = rule_loader.load_rule_files(verbose=False, paths=rule_files)
rules_from_files = rule_loader.load_rules(file_lookup=file_lookup).values() if file_lookup else []
if rule_file:
rules.load_files(Path(path) for path in rule_file)

# rule_loader.load_rules handles checks for duplicate rule IDs - this means rules loaded by ID are de-duped and
# rules loaded from files and directories are de-duped from each other, so this check is to ensure that there is
# no overlap between the two sets of rules
duplicates = [r.id for r in rules_from_files if r.id in rule_ids]
if duplicates:
client_error(f'Duplicate rules for rule IDs: {", ".join(duplicates)}')
for directory in directory:
rules.load_directory(Path(directory))

rules.extend(rules_from_files)
assert len(rules) > 0, "No rules found"

if replace_id:
from uuid import uuid4
for rule in rules:
rule.contents['rule_id'] = str(uuid4())
# if we need to replace the id, take each rule object and create a copy
# of it, with only the rule_id field changed
old_rules = rules
rules = RuleCollection()

for rule in old_rules:
new_data = dataclasses.replace(rule.contents.data, rule_id=str(uuid4()))
new_contents = dataclasses.replace(rule.contents, data=new_data)
rules.add_rule(TOMLRule(contents=new_contents))

Path(outfile).parent.mkdir(exist_ok=True)
package = Package(rules, '_', verbose=False)
Expand All @@ -252,29 +233,19 @@ def export_rules(rule_id, rule_file, directory, outfile, replace_id, stack_versi


@root.command('validate-rule')
@click.argument('rule-id', required=False)
@click.option('--rule-name', '-n')
@click.option('--path', '-p', type=click.Path(dir_okay=False))
@click.argument('path')
@click.pass_context
def validate_rule(ctx, rule_id, rule_name, path):
def validate_rule(ctx, path):
"""Check if a rule staged in rules dir validates against a schema."""
try:
rule = rule_loader.get_rule(rule_id, rule_name, path, verbose=False)
if not rule:
client_error('Rule not found!')

rule.validate(as_rule=True)
click.echo('Rule validation successful')
return rule
except jsonschema.ValidationError as e:
client_error(e.args[0], e, ctx=ctx)
rule = RuleCollection().load_file(Path(path))
click.echo('Rule validation successful')
return rule


@root.command('validate-all')
@click.option('--fail/--no-fail', default=True, help='Fail on first failure or process through all printing errors.')
def validate_all(fail):
"""Check if all rules validates against a schema."""
rule_loader.load_rules(verbose=True, error=fail)
RuleCollection.default()
click.echo('Rule validation successful')


Expand All @@ -292,7 +263,7 @@ def search_rules(query, columns, language, count, verbose=True, rules: Dict[str,
from eql.pipes import CountPipe

flattened_rules = []
rules = rules or rule_loader.load_rule_files(verbose=verbose)
rules = rules or {str(rule.path): rule for rule in RuleCollection.default()}

for file_name, rule_doc in rules.items():
flat = {"file": os.path.relpath(file_name)}
Expand Down
Loading