diff --git a/CLI.md b/CLI.md
index 30910e57b68..baea6c8dab8 100644
--- a/CLI.md
+++ b/CLI.md
@@ -71,7 +71,7 @@ and will accept any valid rule in the following formats:
 #### `import-rules`
 
 ```console
-Usage: detection_rules import-rules [OPTIONS] [INFILE]...
+Usage: detection_rules import-rules [OPTIONS] [INPUT_FILE]...
 
   Import rules from json, toml, or Kibana exported rule file(s).
 
@@ -159,34 +159,39 @@ Options:
   --cloud-id TEXT
   -k, --kibana-url TEXT
 
-Usage: detection_rules kibana upload-rule [OPTIONS] TOML_FILES...
+Usage: detection_rules kibana upload-rule [OPTIONS]
 
   Upload a list of rule .toml files to Kibana.
 
 Options:
-  -r, --replace-id  Replace rule IDs with new IDs before export
-  -h, --help        Show this message and exit.
+  -f, --rule-file FILE
+  -d, --directory DIRECTORY  Recursively export rules from a directory
+  -id, --rule-id TEXT
+  -r, --replace-id           Replace rule IDs with new IDs before export
+  -h, --help                 Show this message and exit.
 ```
 
 Alternatively, rules can be exported into a consolidated ndjson file which can be imported in the Kibana security app directly.
 
 ```console
-Usage: detection_rules export-rules [OPTIONS] [RULE_ID]...
+Usage: detection_rules export-rules [OPTIONS]
 
   Export rule(s) into an importable ndjson file.
 
 Options:
-  -f, --rule-file FILE            Export specified rule files
+  -f, --rule-file FILE
   -d, --directory DIRECTORY       Recursively export rules from a directory
+  -id, --rule-id TEXT
   -o, --outfile FILE              Name of file for exported rules
   -r, --replace-id                Replace rule IDs with new IDs before export
-  --stack-version [7.8|7.9|7.10|7.11]
+  --stack-version [7.8|7.9|7.10|7.11|7.12]
                                   Downgrade a rule version to be compatible
                                   with older instances of Kibana
-  -s, --skip-unsupported          If `--stack-version` is passed, skip
-                                  rule types which are unsupported (an error
-                                  will be raised otherwise)
+  -s, --skip-unsupported          If `--stack-version` is passed, skip rule
+                                  types which are unsupported (an error will
+                                  be raised otherwise)
   -h, --help                      Show this message and exit.
 ```
diff --git a/detection_rules/cli_utils.py b/detection_rules/cli_utils.py
index 1004d4b284d..e59b626cc88 100644
--- a/detection_rules/cli_utils.py
+++ b/detection_rules/cli_utils.py
@@ -7,19 +7,95 @@
 import datetime
 import os
 from pathlib import Path
+from typing import List
 
 import click
 import kql
+import functools
 
 from . import ecs
 from .attack import matrix, tactics, build_threat_map_entry
 from .rule import TOMLRule, TOMLRuleContents
+from .rule_loader import RuleCollection, DEFAULT_RULES_DIR, dict_filter
 from .schemas import CurrentSchema
 from .utils import clear_caches, get_path
 
 RULES_DIR = get_path("rules")
 
 
+def single_collection(f):
+    """Add arguments to get a RuleCollection by file, directory or a list of IDs"""
+    from .misc import client_error
+
+    @click.option('--rule-file', '-f', multiple=False, required=False, type=click.Path(dir_okay=False))
+    @click.option('--rule-id', '-id', multiple=False, required=False)
+    @functools.wraps(f)
+    def get_collection(*args, **kwargs):
+        rule_name: List[str] = kwargs.pop("rule_name", [])
+        rule_id: List[str] = kwargs.pop("rule_id", [])
+        rule_file = kwargs.pop("rule_file", None)
+        directories: List[str] = kwargs.pop("directory", [])
+        rule_files: List[str] = [rule_file] if rule_file else []
+
+        rules = RuleCollection()
+
+        if bool(rule_name) + bool(rule_id) + bool(rule_files) != 1:
+            client_error('Required: exactly one of --rule-id or --rule-file')
+
+        rules.load_files(Path(p) for p in rule_files)
+        rules.load_directories(Path(d) for d in directories)
+
+        if rule_id:
+            rules.load_directory(DEFAULT_RULES_DIR, toml_filter=dict_filter(rule__rule_id=rule_id))
+
+        if len(rules) != 1:
+            client_error(f"Could not find rule with ID {rule_id}")
+
+        kwargs["rule"] = rules.rules[0]
+        return f(*args, **kwargs)
+
+    return get_collection
+
+
+def multi_collection(f):
+    """Add arguments to get a RuleCollection by file, directory or a list of IDs"""
+    from .misc import client_error
+
+    @click.option('--rule-file', '-f', multiple=True, type=click.Path(dir_okay=False), required=False)
+    @click.option('--directory', '-d', multiple=True, type=click.Path(file_okay=False), required=False,
+                  help='Recursively export rules from a directory')
+    @click.option('--rule-id', '-id', multiple=True, required=False)
+    @functools.wraps(f)
+    def get_collection(*args, **kwargs):
+        rule_name: List[str] = kwargs.pop("rule_name", [])
+        rule_id: List[str] = kwargs.pop("rule_id", [])
+        rule_files: List[str] = kwargs.pop("rule_file")
+        directories: List[str] = kwargs.pop("directory")
+
+        rules = RuleCollection()
+
+        if not (rule_name or rule_id or rule_files or directories):
+            client_error('Required: at least one of --rule-id, --rule-file, or --directory')
+
+        rules.load_files(Path(p) for p in rule_files)
+        rules.load_directories(Path(d) for d in directories)
+
+        if rule_id:
+            rules.load_directory(DEFAULT_RULES_DIR, toml_filter=dict_filter(rule__rule_id=rule_id))
+            found_ids = {rule.id for rule in rules}
+            missing = set(rule_id).difference(found_ids)
+
+            if missing:
+                client_error(f'Could not find rules with IDs: {", ".join(missing)}')
+
+        if len(rules) == 0:
+            client_error("No rules found")
+
+        kwargs["rules"] = rules
+        return f(*args, **kwargs)
+
+    return get_collection
+
+
 def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbose=False, **kwargs) -> TOMLRule:
     """Prompt loop to build a rule."""
     from .misc import schema_prompt
diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py
index 87929a55e70..9caa81bd11e 100644
--- a/detection_rules/devtools.py
+++ b/detection_rules/devtools.py
@@ -17,15 +17,16 @@
 import click
 from elasticsearch import Elasticsearch
 from eql import load_dump
-from kibana.connector import Kibana
+from kibana.connector import Kibana
 
 from . 
import rule_loader +from .cli_utils import single_collection from .eswrap import CollectEvents, add_range_to_dsl from .main import root from .misc import PYTHON_LICENSE, add_client, GithubClient, Manifest, client_error, getdefault from .packaging import PACKAGE_FILE, Package, manage_versions, RELEASE_DIR -from .rule import TOMLRule, TOMLRuleContents, BaseQueryRuleData -from .rule_loader import get_rule +from .rule import TOMLRule, BaseQueryRuleData +from .rule_loader import production_filter, RuleCollection from .utils import get_path, dict_hash RULES_DIR = get_path('rules') @@ -68,7 +69,7 @@ def update_lock_versions(rule_ids): if not click.confirm('Are you sure you want to update hashes without a version bump?'): return - rules = [r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_ids] + rules = RuleCollection.default().filter(lambda r: r.id in rule_ids) changed, new = manage_versions(rules, exclude_version_update=True, add_new=False, save_changes=True) if not changed: @@ -86,10 +87,12 @@ def kibana_diff(rule_id, repo, branch, threads): """Diff rules against their version represented in kibana if exists.""" from .misc import get_kibana_rules + rules = RuleCollection.default() + if rule_id: - rules = {r.id: r for r in rule_loader.load_rules(verbose=False).values() if r.id in rule_id} + rules = rules.filter(lambda r: r.id in rule_id) else: - rules = {r.id: r for r in rule_loader.get_production_rules()} + rules = rules.filter(production_filter) # add versions to the rules manage_versions(list(rules.values()), verbose=False) @@ -102,13 +105,13 @@ def kibana_diff(rule_id, repo, branch, threads): missing_from_kibana = list(set(repo_hashes).difference(set(kibana_hashes))) rule_diff = [] - for rid, rhash in repo_hashes.items(): - if rid in missing_from_kibana: + for rule_id, rule_hash in repo_hashes.items(): + if rule_id in missing_from_kibana: continue - if rhash != kibana_hashes[rid]: + if rule_hash != kibana_hashes[rule_id]: rule_diff.append( - f'versions - repo: {rules[rid].contents["version"]}, kibana: {kibana_rules[rid]["version"]} -> ' - f'{rid} - {rules[rid].name}' + f'versions - repo: {rules[rule_id].contents["version"]}, kibana: {kibana_rules[rule_id]["version"]} -> ' + f'{rule_id} - {rules[rule_id].name}' ) diff = { @@ -373,8 +376,7 @@ def event_search(query, index, language, date_range, count, max_results, verbose @test_group.command('rule-event-search') -@click.argument('rule-file', type=click.Path(dir_okay=False), required=False) -@click.option('--rule-id', '-id') +@single_collection @click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') @click.option('--count', '-c', is_flag=True, help='Return count of results only') @click.option('--max-results', '-m', type=click.IntRange(1, 1000), default=100, @@ -382,17 +384,9 @@ def event_search(query, index, language, date_range, count, max_results, verbose @click.option('--verbose', '-v', is_flag=True) @click.pass_context @add_client('elasticsearch') -def rule_event_search(ctx, rule_file, rule_id, date_range, count, max_results, verbose, +def rule_event_search(ctx, rule, date_range, count, max_results, verbose, elasticsearch_client: Elasticsearch = None): """Search using a rule file against an Elasticsearch instance.""" - rule: TOMLRule - - if rule_id: - rule = get_rule(rule_id, verbose=False) - elif rule_file: - rule = TOMLRule(path=rule_file, contents=TOMLRuleContents.from_dict(load_dump(rule_file))) - else: - client_error('Must specify a rule file or rule 
ID') if isinstance(rule.contents.data, BaseQueryRuleData): if verbose: @@ -431,18 +425,17 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun """Survey rule counts.""" from eql.table import Table from kibana.resources import Signal - from . import rule_loader from .main import search_rules survey_results = [] start_time, end_time = date_range if query: - rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)] - rules = rule_loader.load_rules(rule_loader.load_rule_files(paths=rule_paths, verbose=False), verbose=False) - rules = rules.values() + rules = RuleCollection() + paths = [Path(r['file']) for r in ctx.invoke(search_rules, query=query, verbose=False)] + rules.load_files(paths) else: - rules = rule_loader.load_rules(verbose=False).values() + rules = RuleCollection.default().filter(production_filter) click.echo(f'Running survey against {len(rules)} rules') click.echo(f'Saving detailed dump to: {dump_file}') diff --git a/detection_rules/eswrap.py b/detection_rules/eswrap.py index 3882e3786c4..c9a019d1d7e 100644 --- a/detection_rules/eswrap.py +++ b/detection_rules/eswrap.py @@ -7,8 +7,8 @@ import json import os import time -from contextlib import contextmanager from collections import defaultdict +from contextlib import contextmanager from pathlib import Path from typing import Union @@ -20,10 +20,9 @@ import kql from .main import root from .misc import add_params, client_error, elasticsearch_options -from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path from .rule import TOMLRule -from .rule_loader import get_rule, rta_mappings - +from .rule_loader import rta_mappings, RuleCollection +from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path COLLECTION_DIR = get_path('collections') MATCH_ALL = {'bool': {'filter': [{'match_all': {}}]}} @@ -88,7 +87,8 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr """Evaluate a rule against collected events and update mapping.""" from .utils import combine_sources, evaluate - rule = get_rule(rule_id, verbose=False) + rule = next((rule for rule in RuleCollection.default() if rule.id == rule_id), None) + assert rule is not None, f"Unable to find rule with ID {rule_id}" merged_events = combine_sources(*self.events.values()) filtered = evaluate(rule, merged_events) diff --git a/detection_rules/kbwrap.py b/detection_rules/kbwrap.py index 7501dce552f..13895b4b66d 100644 --- a/detection_rules/kbwrap.py +++ b/detection_rules/kbwrap.py @@ -4,13 +4,16 @@ # 2.0. 
"""Kibana cli commands.""" +import uuid + import click + import kql from kibana import Kibana, Signal, RuleResource - +from .cli_utils import multi_collection from .main import root from .misc import add_params, client_error, kibana_options -from .rule_loader import load_rule_files, load_rules +from .schemas import downgrade from .utils import format_command_options @@ -49,30 +52,28 @@ def kibana_group(ctx: click.Context, **kibana_kwargs): @kibana_group.command("upload-rule") -@click.argument("toml-files", nargs=-1, required=True) +@multi_collection @click.option('--replace-id', '-r', is_flag=True, help='Replace rule IDs with new IDs before export') @click.pass_context -def upload_rule(ctx, toml_files, replace_id): +def upload_rule(ctx, rules, replace_id): """Upload a list of rule .toml files to Kibana.""" - from .packaging import manage_versions kibana = ctx.obj['kibana'] - file_lookup = load_rule_files(paths=toml_files) - rules = list(load_rules(file_lookup=file_lookup).values()) - - # assign the versions from etc/versions.lock.json - # rules that have changed in hash get incremented, others stay as-is. - # rules that aren't in the lookup default to version 1 - manage_versions(rules, verbose=False) - api_payloads = [] for rule in rules: try: - payload = rule.get_payload(include_version=True, replace_id=replace_id, embed_metadata=True, - target_version=kibana.version) + payload = rule.contents.to_api_format() + payload.setdefault("meta", {}).update(rule.contents.metadata.to_dict()) + + if replace_id: + payload["rule_id"] = str(uuid.uuid4()) + + payload = downgrade(payload, target_version=kibana.version) + except ValueError as e: client_error(f'{e} in version:{kibana.version}, for rule: {rule.name}', e, ctx=ctx) + rule = RuleResource(payload) api_payloads.append(rule) diff --git a/detection_rules/main.py b/detection_rules/main.py index ddf75cf0058..3607ab91fbc 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -4,6 +4,7 @@ # 2.0. """CLI commands for detection_rules.""" +import dataclasses import glob import json import os @@ -11,19 +12,19 @@ import time from pathlib import Path from typing import Dict +from uuid import uuid4 import click -import jsonschema from . import rule_loader -from .cli_utils import rule_prompt -from .misc import client_error, nested_set, parse_config +from .cli_utils import rule_prompt, multi_collection +from .misc import nested_set, parse_config from .rule import TOMLRule from .rule_formatter import toml_write +from .rule_loader import RuleCollection from .schemas import CurrentSchema, available_versions from .utils import get_path, clear_caches, load_rule_contents - RULES_DIR = get_path('rules') @@ -59,15 +60,14 @@ def create_rule(path, config, required_only, rule_type): @click.pass_context def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): """Generate enriched indexes of rules, based on a KQL search, for indexing/importing into elasticsearch/kibana.""" - from . 
import rule_loader from .packaging import load_current_package_version, Package if query: rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)] - rules = rule_loader.load_rules(rule_loader.load_rule_files(paths=rule_paths, verbose=False), verbose=False) - rules = rules.values() + rules = RuleCollection() + rules.load_files(Path(p) for p in rule_paths) else: - rules = rule_loader.load_rules(verbose=False).values() + rules = RuleCollection.default() rule_count = len(rules) package = Package(rules, load_current_package_version(), verbose=False) @@ -88,12 +88,12 @@ def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): @root.command('import-rules') -@click.argument('infile', type=click.Path(dir_okay=False, exists=True), nargs=-1, required=False) +@click.argument('input-file', type=click.Path(dir_okay=False, exists=True), nargs=-1, required=False) @click.option('--directory', '-d', type=click.Path(file_okay=False, exists=True), help='Load files from a directory') -def import_rules(infile, directory): +def import_rules(input_file, directory): """Import rules from json, toml, or Kibana exported rule file(s).""" rule_files = glob.glob(os.path.join(directory, '**', '*.*'), recursive=True) if directory else [] - rule_files = sorted(set(rule_files + list(infile))) + rule_files = sorted(set(rule_files + list(input_file))) rule_contents = [] for rule_file in rule_files: @@ -117,22 +117,16 @@ def name_to_filename(name): def toml_lint(rule_file): """Cleanup files with some simple toml formatting.""" if rule_file: - rules = list(rule_loader.load_rules(rule_loader.load_rule_files(paths=[rule_file])).values()) + rules = RuleCollection() + rules.load_files(Path(p) for p in rule_file) else: - rules = list(rule_loader.load_rules().values()) - - # removed unneeded defaults - # TODO: we used to remove "unneeded" defaults, but this is a potentially tricky thing. - # we need to figure out if a default is Kibana-imposed or detection-rules imposed. 
- ideally, we can explicitly mention default in TOML if desired and have a concept
- of build-time defaults, so that defaults are filled in as late as possible
+    rules = RuleCollection.default()
 
     # re-save the rules to force TOML reformatting
     for rule in rules:
         rule.save_toml()
 
-    rule_loader.reset()
-    click.echo('Toml file linting complete')
+    click.echo('TOML file linting complete')
 
 
 @root.command('mass-update')
@@ -146,8 +140,10 @@ def toml_lint(rule_file):
 @click.pass_context
 def mass_update(ctx, query, metadata, language, field):
     """Update multiple rules based on eql results."""
+    rules = RuleCollection.default()
     results = ctx.invoke(search_rules, query=query, language=language, verbose=False)
-    rules = [rule_loader.get_rule(r['rule_id'], verbose=False) for r in results]
+    matching_ids = set(r["rule_id"] for r in results)
+    rules = rules.filter(lambda r: r.id in matching_ids)
 
     for rule in rules:
         for key, value in field:
@@ -161,41 +157,21 @@
 
 
 @root.command('view-rule')
-@click.argument('rule-id', required=False)
-@click.option('--rule-file', '-f', type=click.Path(dir_okay=False), help='Optionally view a rule from a specified file')
+@click.argument('rule-file')
 @click.option('--api-format/--rule-format', default=True, help='Print the rule in final api or rule format')
 @click.pass_context
-def view_rule(ctx, rule_id, rule_file, api_format, verbose=True):
+def view_rule(ctx, rule_file, api_format):
     """View an internal rule or specified rule file."""
-    rule = None
-
-    if rule_id:
-        rule = rule_loader.get_rule(rule_id, verbose=False)
-    elif rule_file:
-        contents = {k: v for k, v in load_rule_contents(rule_file, single_only=True)[0].items() if v}
+    rule = RuleCollection().load_file(Path(rule_file))
 
-        try:
-            rule = TOMLRule(rule_file, contents)
-        except jsonschema.ValidationError as e:
-            client_error(f'Rule: {rule_id or os.path.basename(rule_file)} failed validation', e, ctx=ctx)
+    if api_format:
+        click.echo(json.dumps(rule.contents.to_api_format(), indent=2, sort_keys=True))
     else:
-        client_error('Unknown rule!')
-
-    if not rule:
-        client_error('Unknown format!')
-
-    if verbose:
-        click.echo(toml_write(rule.rule_format()) if not api_format else
-                   json.dumps(rule.get_payload(), indent=2, sort_keys=True))
-
-    return rule
+        click.echo(toml_write(rule.contents.to_dict()))
 
 
 @root.command('export-rules')
-@click.argument('rule-id', nargs=-1, required=False)
-@click.option('--rule-file', '-f', multiple=True, type=click.Path(dir_okay=False), help='Export specified rule files')
-@click.option('--directory', '-d', multiple=True, type=click.Path(file_okay=False),
-              help='Recursively export rules from a directory')
+@multi_collection
 @click.option('--outfile', '-o', default=get_path('exports', f'{time.strftime("%Y%m%dT%H%M%SL")}.ndjson'),
               type=click.Path(dir_okay=False), help='Name of file for exported rules')
 @click.option('--replace-id', '-r', is_flag=True, help='Replace rule IDs with new IDs before export')
@@ -204,46 +180,22 @@ def view_rule(ctx, rule_id, rule_file, api_format, verbose=True):
 @click.option('--skip-unsupported', '-s', is_flag=True,
               help='If `--stack-version` is passed, skip rule types which are unsupported '
                    '(an error will be raised otherwise)')
-def export_rules(rule_id, rule_file, directory, outfile, replace_id, stack_version, skip_unsupported):
+def export_rules(rules, outfile, replace_id, stack_version, skip_unsupported):
     """Export rule(s) into an importable ndjson file."""
     from .packaging import Package
 
-    if not (rule_id or rule_file or directory):
-        client_error('Required: at least one of --rule-id, --rule-file, or --directory')
-
-    if rule_id:
-        all_rules = {r.id: r for r in rule_loader.load_rules(verbose=False).values()}
-        missing = [rid for rid in rule_id if rid not in all_rules]
-
-        if missing:
-            client_error(f'Unknown rules for rule IDs: {", ".join(missing)}')
-
-        rules = [r for r in all_rules.values() if r.id in rule_id]
-        rule_ids = [r.id for r in rules]
-    else:
-        rules = []
-        rule_ids = []
-
-    rule_files = list(rule_file)
-    for dirpath in directory:
-        rule_files.extend(list(Path(dirpath).rglob('*.toml')))
-
-    file_lookup = rule_loader.load_rule_files(verbose=False, paths=rule_files)
-    rules_from_files = rule_loader.load_rules(file_lookup=file_lookup).values() if file_lookup else []
-
-    # rule_loader.load_rules handles checks for duplicate rule IDs - this means rules loaded by ID are de-duped and
-    # rules loaded from files and directories are de-duped from each other, so this check is to ensure that there is
-    # no overlap between the two sets of rules
-    duplicates = [r.id for r in rules_from_files if r.id in rule_ids]
-    if duplicates:
-        client_error(f'Duplicate rules for rule IDs: {", ".join(duplicates)}')
-
-    rules.extend(rules_from_files)
+    assert len(rules) > 0, "No rules found"
 
     if replace_id:
-        from uuid import uuid4
-        for rule in rules:
-            rule.contents['rule_id'] = str(uuid4())
+        # if we need to replace the id, take each rule object and create a copy
+        # of it, with only the rule_id field changed
+        old_rules = rules
+        rules = RuleCollection()
+
+        for rule in old_rules:
+            new_data = dataclasses.replace(rule.contents.data, rule_id=str(uuid4()))
+            new_contents = dataclasses.replace(rule.contents, data=new_data)
+            rules.add_rule(TOMLRule(contents=new_contents))
 
     Path(outfile).parent.mkdir(exist_ok=True)
     package = Package(rules, '_', verbose=False)
@@ -252,29 +204,19 @@ def export_rules(rule_id, rule_file, directory, outfile, replace_id, stack_versi
 
 
 @root.command('validate-rule')
-@click.argument('rule-id', required=False)
-@click.option('--rule-name', '-n')
-@click.option('--path', '-p', type=click.Path(dir_okay=False))
+@click.argument('path')
 @click.pass_context
-def validate_rule(ctx, rule_id, rule_name, path):
+def validate_rule(ctx, path):
     """Check if a rule staged in rules dir validates against a schema."""
-    try:
-        rule = rule_loader.get_rule(rule_id, rule_name, path, verbose=False)
-        if not rule:
-            client_error('Rule not found!')
-
-        rule.validate(as_rule=True)
-        click.echo('Rule validation successful')
-        return rule
-    except jsonschema.ValidationError as e:
-        client_error(e.args[0], e, ctx=ctx)
+    rule = RuleCollection().load_file(Path(path))
+    click.echo('Rule validation successful')
+    return rule
 
 
 @root.command('validate-all')
-@click.option('--fail/--no-fail', default=True, help='Fail on first failure or process through all printing errors.')
-def validate_all(fail):
+def validate_all():
     """Check if all rules validates against a schema."""
-    rule_loader.load_rules(verbose=True, error=fail)
+    RuleCollection.default()
     click.echo('Rule validation successful')
 
 
@@ -292,7 +234,7 @@ def search_rules(query, columns, language, count, verbose=True, rules: Dict[str,
     from eql.pipes import CountPipe
 
     flattened_rules = []
-    rules = rules or rule_loader.load_rule_files(verbose=verbose)
+    rules = rules or {str(rule.path): rule for rule in RuleCollection.default()}
 
     for file_name, rule_doc in rules.items():
         flat = {"file": os.path.relpath(file_name)}
diff --git a/detection_rules/mappings.py b/detection_rules/mappings.py
index 
a13138fce60..fc64495c528 100644 --- a/detection_rules/mappings.py +++ b/detection_rules/mappings.py @@ -10,7 +10,6 @@ from .schemas import validate_rta_mapping from .utils import load_etc_dump, save_etc_dump, get_path - RTA_DIR = get_path("rta") diff --git a/detection_rules/mixins.py b/detection_rules/mixins.py index ad4ed9b45b1..2ef9d2ec11d 100644 --- a/detection_rules/mixins.py +++ b/detection_rules/mixins.py @@ -35,6 +35,10 @@ def __schema(cls: ClassT) -> Schema: """Get the marshmallow schema for the data class""" return marshmallow_dataclass.class_schema(cls)() + def get(self, key: str): + """Get a key from the query data without raising attribute errors.""" + return getattr(self, key, None) + @classmethod def from_dict(cls: Type[ClassT], obj: dict) -> ClassT: """Deserialize and validate a dataclass from a dict using marshmallow.""" diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index 13017aaf2f2..6b336dce6c1 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -17,10 +17,10 @@ import click import yaml -from . import rule_loader from .misc import JS_LICENSE, cached -from .rule import TOMLRule, BaseQueryRuleData, RULES_DIR, ThreatMapping +from .rule import TOMLRule, BaseQueryRuleData, ThreatMapping from .rule import downgrade_contents_from_rule +from .rule_loader import RuleCollection, DEFAULT_RULES_DIR from .schemas import CurrentSchema, definitions from .utils import Ndjson, get_path, get_etc_path, load_etc_dump, save_etc_dump @@ -325,7 +325,7 @@ def get_package_hash(self, as_api=True, verbose=True): @classmethod def from_config(cls, config: dict = None, update_version_lock: bool = False, verbose: bool = False) -> 'Package': """Load a rules package given a config.""" - all_rules = rule_loader.load_rules(verbose=False).values() + all_rules = RuleCollection.default() config = config or {} exclude_fields = config.pop('exclude_fields', {}) log_deprecated = config.pop('log_deprecated', False) @@ -335,21 +335,14 @@ def from_config(cls, config: dict = None, update_version_lock: bool = False, ver if log_deprecated: deprecated_rules = [r for r in all_rules if r.contents.metadata.maturity == 'deprecated'] - rules = list(filter(lambda rule: filter_rule(rule, rule_filter, exclude_fields), all_rules)) + rules = all_rules.filter(lambda r: filter_rule(r, rule_filter, exclude_fields)) if verbose: click.echo(f' - {len(all_rules) - len(rules)} rules excluded from package') - update = config.pop('update', {}) package = cls(rules, deprecated_rules=deprecated_rules, update_version_lock=update_version_lock, verbose=verbose, **config) - # Allow for some fields to be overwritten - if update.get('data', {}): - for rule in package.rules: - for sub_dict, values in update.items(): - rule.contents[sub_dict].update(values) - return package def generate_summary_and_changelog(self, changed_rule_ids, new_rule_ids, removed_rules): @@ -551,7 +544,7 @@ def create_bulk_index_body(self) -> Tuple[Ndjson, Ndjson]: status=status, package_version=self.name, flat_mitre=ThreatMapping.flatten(rule.contents.data.threat).to_dict(), - relative_path=str(rule.path.resolve().relative_to(RULES_DIR))) + relative_path=str(rule.path.resolve().relative_to(DEFAULT_RULES_DIR))) bulk_upload_docs.append(rule_doc) importable_rules_docs.append(rule_doc) diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 24c3aa0b9e8..cd52ce81184 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -16,11 +16,10 @@ from . 
import ecs, beats, utils
 from .mixins import MarshmallowDataclassMixin
 from .rule_formatter import toml_write, nested_normalize
-from .schemas import downgrade
 from .schemas import definitions
-from .utils import get_path, cached
+from .schemas import downgrade
+from .utils import cached
 
-RULES_DIR = get_path("rules")
 
 _META_SCHEMA_REQ_DEFAULTS = {}
@@ -421,7 +420,7 @@ def sha256(self) -> str:
 @dataclass
 class TOMLRule:
     contents: TOMLRuleContents = field(hash=True)
-    path: Path
+    path: Optional[Path] = None
     gh_pr: Any = field(hash=False, compare=False, default=None, repr=None)
 
     @property
@@ -437,6 +436,7 @@ def get_asset(self) -> dict:
         return {"id": self.id, "attributes": self.contents.to_api_format(), "type": definitions.SAVED_OBJECT_TYPE}
 
     def save_toml(self):
+        assert self.path is not None, f"Can't save rule {self.name} ({self.id}) without a path"
         converted = self.contents.to_dict()
         toml_write(converted, str(self.path.absolute()))
 
diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py
index f8efe6345c5..381a44f996c 100644
--- a/detection_rules/rule_loader.py
+++ b/detection_rules/rule_loader.py
@@ -4,132 +4,196 @@
 # 2.0.
 
 """Load rule metadata transform between rule and api formats."""
-import functools
-import glob
 import io
-import os
-import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Iterable
+from typing import Dict, List, Iterable, Callable, Optional
 
 import click
 import pytoml
 
 from .mappings import RtaMappings
-from .rule import RULES_DIR, TOMLRule, TOMLRuleContents, EQLRuleData, KQLRuleData
-from .schemas import CurrentSchema
+from .rule import TOMLRule, TOMLRuleContents
+from .schemas import CurrentSchema, definitions
 from .utils import get_path, cached
 
+DEFAULT_RULES_DIR = Path(get_path("rules"))
 RTA_DIR = get_path("rta")
 FILE_PATTERN = r'^([a-z0-9_])+\.(json|toml)$'
 
 
-def mock_loader(f):
-    """Mock rule loader."""
-    @functools.wraps(f)
-    def wrapped(*args, **kwargs):
-        try:
-            return f(*args, **kwargs)
-        finally:
-            load_rules.clear()
+def path_getter(value: str) -> Callable[[dict], bool]:
+    """Get the path from a Python object."""
+    path = value.replace("__", ".").split(".")
 
-    return wrapped
+    def callback(obj: dict):
+        for p in path:
+            if isinstance(obj, dict) and p in obj:
+                obj = obj[p]
+            else:
+                return None
+        return obj
 
-def reset():
-    """Clear all rule caches."""
-    load_rule_files.clear()
-    load_rules.clear()
-    get_rule.clear()
-    filter_rules.clear()
+    return callback
 
 
-@cached
-def load_rule_files(verbose=True, paths=None):
-    """Load the rule YAML files, but without parsing the EQL query portion."""
-    file_lookup = {}  # type: dict[str, dict]
+def dict_filter(_obj: Optional[dict] = None, **criteria) -> Callable[[dict], bool]:
+    """Get a callable that will return true if a dictionary matches a set of criteria.
 
-    if verbose:
-        print("Loading rules from {}".format(RULES_DIR))
+    * each key is a dotted (or __ delimited) path into a dictionary to check
+    * each value is a value or list of values to match
+    """
+    criteria.update(_obj or {})
+    checkers = [(path_getter(k), set(v) if isinstance(v, (list, set, tuple)) else {v}) for k, v in criteria.items()]
 
-    if paths is None:
-        paths = sorted(glob.glob(os.path.join(RULES_DIR, '**', '*.toml'), recursive=True))
+    def callback(obj: dict) -> bool:
+        for getter, expected in checkers:
+            target_values = getter(obj)
+            target_values = set(target_values) if isinstance(target_values, (list, set, tuple)) else {target_values}
 
-    for rule_file in paths:
-        try:
-            # use pytoml instead of toml because of annoying bugs
-            # https://github.com/uiri/toml/issues/152
-            # might also be worth looking at https://github.com/sdispater/tomlkit
-            with io.open(rule_file, "r", encoding="utf-8") as f:
-                file_lookup[rule_file] = pytoml.load(f)
-        except Exception:
-            print(u"Error loading {}".format(rule_file))
-            raise
+            if not expected.intersection(target_values):
+                return False
 
-    if verbose:
-        print("Loaded {} rules".format(len(file_lookup)))
-    return file_lookup
+        return True
+
+    return callback
 
 
-@cached
-def load_rules(file_lookup=None, verbose=True, error=True):
-    """Load all the rules from toml files."""
-    file_lookup = file_lookup or load_rule_files(verbose=verbose)
-
-    failed = False
-    rules: List[TOMLRule] = []
-    errors = []
-    queries = []
-    query_check_index = []
-    rule_ids = set()
-    rule_names = set()
-
-    for rule_file, rule_contents in file_lookup.items():
-        try:
-            contents = TOMLRuleContents.from_dict(rule_contents)
-            rule = TOMLRule(path=Path(rule_file), contents=contents)
-            if rule.id in rule_ids:
-                existing = next(r for r in rules if r.id == rule.id)
-                raise KeyError(f'{rule.path} has duplicate ID with \n{existing.path}')
+def metadata_filter(**metadata) -> Callable[[TOMLRule], bool]:
+    """Get a filter callback based off rule metadata"""
+    flt = dict_filter(metadata)
 
-            if rule.name in rule_names:
-                existing = next(r for r in rules if r.name == rule.name)
-                raise KeyError(f'{rule.path} has duplicate name with \n{existing.path}')
+    def callback(rule: TOMLRule) -> bool:
+        target_dict = rule.contents.metadata.to_dict()
+        return flt(target_dict)
 
-            if isinstance(contents.data, (KQLRuleData, EQLRuleData)):
-                duplicate_key = (contents.data.parsed_query, contents.data.type)
-                query_check_index.append(rule)
+    return callback
 
-                if duplicate_key in queries:
-                    existing = query_check_index[queries.index(duplicate_key)]
-                    raise KeyError(f'{rule.path} has duplicate query with \n{existing.path}')
-                queries.append(duplicate_key)
+production_filter = metadata_filter(maturity="production")
+deprecate_filter = metadata_filter(maturity="deprecated")
 
-            if not re.match(FILE_PATTERN, os.path.basename(rule.path)):
-                raise ValueError(f'{rule.path} does not meet rule name standard of {FILE_PATTERN}')
-            rules.append(rule)
-            rule_ids.add(rule.id)
-            rule_names.add(rule.name)
+
+class RuleCollection:
+    """Collection of rule objects."""
 
-        except Exception as e:
-            failed = True
-            err_msg = "Invalid rule file in {}\n{}".format(rule_file, click.style(str(e), fg='red'))
-            errors.append(err_msg)
-            if error:
-                if verbose:
-                    print(err_msg)
-                raise e
+    __default = None
+
+    def __init__(self, rules: Optional[List[TOMLRule]] = None):
+        self.id_map: Dict[definitions.UUIDString, TOMLRule] = {}
+        self.file_map: Dict[Path, TOMLRule] = {}
+        self.rules: List[TOMLRule] = []
+        self.frozen = False
+
+        self._toml_load_cache: Dict[Path, 
dict] = {} + + for rule in (rules or []): + self.add_rule(rule) - if failed: - if verbose: - for e in errors: - print(e) + def __len__(self): + """Get the total amount of rules in the collection.""" + return len(self.rules) - return OrderedDict([(rule.id, rule) for rule in sorted(rules, key=lambda r: r.name)]) + def __iter__(self): + """Iterate over all rules in the collection.""" + return iter(self.rules) + + def __contains__(self, rule: TOMLRule): + """Check if a rule is in the map by comparing IDs.""" + return rule.id in self.id_map + + def filter(self, cb: Callable[[TOMLRule], bool]) -> 'RuleCollection': + """Retrieve a filtered collection of rules.""" + filtered_collection = RuleCollection() + + for rule in filter(cb, self.rules): + filtered_collection.add_rule(rule) + + return filtered_collection + + def _deserialize_toml(self, path: Path) -> dict: + if path in self._toml_load_cache: + return self._toml_load_cache[path] + + # use pytoml instead of toml because of annoying bugs + # https://github.com/uiri/toml/issues/152 + # might also be worth looking at https://github.com/sdispater/tomlkit + with io.open(str(path.resolve()), "r", encoding="utf-8") as f: + toml_dict = pytoml.load(f) + self._toml_load_cache[path] = toml_dict + return toml_dict + + def _get_paths(self, directory: Path, recursive=True) -> List[Path]: + return sorted(directory.rglob('*.toml') if recursive else directory.glob('*.toml')) + + def add_rule(self, rule: TOMLRule): + assert not self.frozen, f"Unable to add rule {rule.name} {rule.id} to a frozen collection" + assert rule.id not in self.id_map, \ + f"Rule ID {rule.id} for {rule.name} collides with rule {self.id_map.get(rule.id).name}" + + if rule.path is not None: + rule.path = rule.path.resolve() + assert rule.path not in self.file_map, f"Rule file {rule.path} already loaded" + self.file_map[rule.path] = rule + + self.id_map[rule.id] = rule + self.rules.append(rule) + + def load_dict(self, obj: dict, path: Optional[Path] = None): + contents = TOMLRuleContents.from_dict(obj) + rule = TOMLRule(path=path, contents=contents) + self.add_rule(rule) + + return rule + + def load_file(self, path: Path) -> TOMLRule: + try: + path = path.resolve() + + # use the default rule loader as a cache. 
+ # if it already loaded the rule, then we can just use it from that + if self.__default is not None and self is not self.__default and path in self.__default.file_map: + rule = self.__default.file_map[path] + self.add_rule(rule) + return rule + + obj = self._deserialize_toml(path) + return self.load_dict(obj, path=path) + except Exception: + print(f"Error loading rule in {path}") + raise + + def load_files(self, paths: Iterable[Path]): + """Load multiple files into the collection.""" + for path in paths: + self.load_file(path) + + def load_directory(self, directory: Path, recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None): + paths = self._get_paths(directory, recursive=recursive) + if toml_filter is not None: + paths = [path for path in paths if toml_filter(self._deserialize_toml(path))] + + self.load_files(paths) + + def load_directories(self, directories: Iterable[Path], recursive=True, + toml_filter: Optional[Callable[[dict], bool]] = None): + for path in directories: + self.load_directory(path, recursive=recursive, toml_filter=toml_filter) + + def freeze(self): + """Freeze the rule collection and make it immutable going forward.""" + self.frozen = True + + @classmethod + def default(cls): + """Return the default rule collection, which retrieves from rules/.""" + if cls.__default is None: + collection = RuleCollection() + collection.load_directory(DEFAULT_RULES_DIR) + collection.freeze() + cls.__default = collection + + return cls.__default @cached @@ -151,7 +215,7 @@ def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rul modified_rules: List[TOMLRule] = [] errors: Dict[str, list] = {} - existing_rules = load_rules(verbose=False) + existing_rules = RuleCollection.default() pr_rules = [] if verbose: @@ -165,7 +229,7 @@ def download_worker(pr_info): rule = TOMLRule(rule_file.filename, raw_rule) rule.gh_pr = pull - if rule.id in existing_rules: + if rule in existing_rules: modified_rules.append(rule) else: new_rules.append(rule) @@ -191,57 +255,6 @@ def download_worker(pr_info): return new, modified, errors -@cached -def get_rule(rule_id=None, rule_name=None, file_name=None, verbose=True): - """Get a rule based on its id.""" - rules_lookup = load_rules(verbose=verbose) - if rule_id is not None: - return rules_lookup.get(rule_id) - - for rule in rules_lookup.values(): # type: TOMLRule - if rule.name == rule_name: - return rule - elif rule.path == file_name: - return rule - - -def get_rule_name(rule_id, verbose=True): - """Get the name of a rule given the rule id.""" - rule = get_rule(rule_id, verbose=verbose) - if rule: - return rule.name - - -def get_file_name(rule_id, verbose=True): - """Get the file path that corresponds to a rule.""" - rule = get_rule(rule_id, verbose=verbose) - if rule: - return rule.path - - -def get_rule_contents(rule_id, verbose=True): - """Get the full contents for a rule_id.""" - rule = get_rule(rule_id, verbose=verbose) - if rule: - return rule.contents - - -@cached -def filter_rules(rules: Iterable[TOMLRule], metadata_field: str, value) -> List[TOMLRule]: - """Filter rules based on the metadata.""" - return [rule for rule in rules if rule.contents.metadata.to_dict().get(metadata_field) == value] - - -def get_production_rules(verbose=False, include_deprecated=False) -> List[TOMLRule]: - """Get rules with a maturity of production.""" - from .packaging import filter_rule - - maturity = ['production'] - if include_deprecated: - maturity.append('deprecated') - return [rule for rule in load_rules(verbose=verbose).values() 
if filter_rule(rule, {'maturity': maturity})] - - @cached def get_non_required_defaults_by_type(rule_type: str) -> dict: """Get list of fields which are not required for a specified rule type.""" @@ -265,18 +278,13 @@ def find_unneeded_defaults_from_rule(toml_contents: dict) -> dict: __all__ = ( "FILE_PATTERN", - "load_rule_files", - "load_rules", - "load_rule_files", + "DEFAULT_RULES_DIR", "load_github_pr_rules", - "get_file_name", "get_non_required_defaults_by_type", - "get_production_rules", - "get_rule", - "filter_rules", + "RuleCollection", + "metadata_filter", + "production_filter", + "dict_filter", "find_unneeded_defaults_from_rule", - "get_rule_name", - "get_rule_contents", - "reset", "rta_mappings" ) diff --git a/tests/base.py b/tests/base.py index da6b0262930..2bcdce06fb8 100644 --- a/tests/base.py +++ b/tests/base.py @@ -7,8 +7,8 @@ import unittest -from detection_rules import rule_loader from detection_rules.rule import TOMLRule +from detection_rules.rule_loader import RuleCollection, production_filter class BaseRuleTest(unittest.TestCase): @@ -16,10 +16,9 @@ class BaseRuleTest(unittest.TestCase): @classmethod def setUpClass(cls): - cls.rule_files = rule_loader.load_rule_files(verbose=False) - cls.rule_lookup = rule_loader.load_rules(verbose=False) - cls.rules = cls.rule_lookup.values() - cls.production_rules = rule_loader.get_production_rules() + cls.all_rules = RuleCollection.default() + cls.rule_lookup = {rule.id: rule for rule in cls.all_rules} + cls.production_rules = cls.all_rules.filter(production_filter) @staticmethod def rule_str(rule: TOMLRule, trailer=' ->'): diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index eb4c8567dad..0ae9eb8cdb5 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -4,23 +4,18 @@ # 2.0. 
"""Test that all rules have valid metadata and syntax.""" -import json import os import re -import sys from collections import defaultdict from pathlib import Path import eql -import jsonschema -import pytoml -import toml import kql from detection_rules import attack, beats, ecs from detection_rules.packaging import load_versions -from detection_rules.rule import TOMLRule, BaseQueryRuleData -from detection_rules.rule_loader import FILE_PATTERN, find_unneeded_defaults_from_rule +from detection_rules.rule import BaseQueryRuleData +from detection_rules.rule_loader import FILE_PATTERN from detection_rules.utils import get_path, load_etc_dump from rta import get_ttp_names from .base import BaseRuleTest @@ -31,19 +26,7 @@ class TestValidRules(BaseRuleTest): def test_schema_and_dupes(self): """Ensure that every rule matches the schema and there are no duplicates.""" - self.assertGreaterEqual(len(self.rule_files), 1, 'No rules were loaded from rules directory!') - - def test_all_rule_files(self): - """Ensure that every rule file can be loaded and validate against schema.""" - for file_name, contents in self.rule_files.items(): - try: - TOMLRule(file_name, contents) - except (pytoml.TomlError, toml.TomlDecodeError) as e: - print("TOML error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr) - raise e - except jsonschema.ValidationError as e: - print("Schema error when parsing rule file \"{}\"".format(os.path.basename(file_name)), file=sys.stderr) - raise e + self.assertGreaterEqual(len(self.all_rules), 1, 'No rules were loaded from rules directory!') def test_file_names(self): """Test that the file names meet the requirement.""" @@ -54,37 +37,21 @@ def test_file_names(self): self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'), f'Incorrect pattern for verifying rule names: {file_pattern}') - for rule_file in self.rule_files.keys(): - self.assertIsNotNone(re.match(file_pattern, os.path.basename(rule_file)), - f'Invalid file name for {rule_file}') + for rule in self.all_rules: + file_name = str(rule.path.name) + self.assertIsNotNone(re.match(file_pattern, file_name), f'Invalid file name for {rule.path}') def test_all_rule_queries_optimized(self): """Ensure that every rule query is in optimized form.""" - for file_name, contents in self.rule_files.items(): - rule = TOMLRule(file_name, contents) - - if contents["rule"].get("langauge") == "kql": - source = contents["rule"]["query"] + for rule in self.production_rules: + if rule.contents.data.get("language") == "kql": + source = rule.contents.data.query tree = kql.parse(source, optimize=False) optimized = tree.optimize(recursive=True) err_message = f'\n{self.rule_str(rule)} Query not optimized for rule\n' \ f'Expected: {optimized}\nActual: {source}' self.assertEqual(tree, optimized, err_message) - def test_no_unrequired_defaults(self): - """Test that values that are not required in the schema are not set with default values.""" - rules_with_hits = {} - - for file_name, contents in self.rule_files.items(): - default_matches = find_unneeded_defaults_from_rule(contents) - - if default_matches: - rules_with_hits[f'{contents["rule"]["rule_id"]} - {contents["rule"]["name"]}'] = default_matches - - error_msg = f'The following rules have unnecessary default values set: ' \ - f'\n{json.dumps(rules_with_hits, indent=2)}' - self.assertDictEqual(rules_with_hits, {}, error_msg) - def test_production_rules_have_rta(self): """Ensure that all production rules have RTAs.""" mappings = 
load_etc_dump('rule-mapping.yml') @@ -103,9 +70,9 @@ def test_production_rules_have_rta(self): def test_duplicate_file_names(self): """Test that no file names are duplicated.""" name_map = defaultdict(list) - for file_path in self.rule_files: - base_name = os.path.basename(file_path) - name_map[base_name].append(file_path) + + for rule in self.all_rules: + name_map[rule.path.name].append(rule.path.name) duplicates = {name: paths for name, paths in name_map.items() if len(paths) > 1} if duplicates: @@ -121,7 +88,7 @@ def test_technique_deprecations(self): revoked = list(attack.revoked) deprecated = list(attack.deprecated) - for rule in self.rules: + for rule in self.all_rules: revoked_techniques = {} threat_mapping = rule.contents.data.threat @@ -138,7 +105,7 @@ def test_technique_deprecations(self): def test_tactic_to_technique_correlations(self): """Ensure rule threat info is properly related to a single tactic and technique.""" - for rule in self.rules: + for rule in self.all_rules: threat_mapping = rule.contents.data.threat or [] if threat_mapping: for entry in threat_mapping: @@ -196,7 +163,7 @@ def test_tactic_to_technique_correlations(self): def test_duplicated_tactics(self): """Check that a tactic is only defined once.""" - for rule in self.rules: + for rule in self.all_rules: threat_mapping = rule.contents.data.threat tactics = [t.tactic.name for t in threat_mapping or []] duplicates = sorted(set(t for t in tactics if tactics.count(t) > 1)) @@ -222,7 +189,7 @@ def normalize(s): ] expected_case = {normalize(t): t for t in expected_tags} - for rule in self.rules: + for rule in self.all_rules: rule_tags = rule.contents.data.tags if rule_tags: @@ -254,7 +221,7 @@ def test_required_tags(self): 'winlogbeat-*': {'all': ['Windows']} } - for rule in self.rules: + for rule in self.all_rules: rule_tags = rule.contents.data.tags error_msg = f'{self.rule_str(rule)} Missing tags:\nActual tags: {", ".join(rule_tags)}' @@ -292,7 +259,7 @@ def test_primary_tactic_as_tag(self): invalid = [] tactics = set(tactics) - for rule in self.rules: + for rule in self.all_rules: rule_tags = rule.contents.data.tags if 'Continuous Monitoring' in rule_tags or rule.contents.data.type == 'machine_learning': @@ -340,7 +307,7 @@ class TestRuleTimelines(BaseRuleTest): def test_timeline_has_title(self): """Ensure rules with timelines have a corresponding title.""" - for rule in self.rules: + for rule in self.all_rules: timeline_id = rule.contents.data.timeline_id timeline_title = rule.contents.data.timeline_title @@ -366,8 +333,8 @@ def test_rule_file_names_by_tactic(self): """Test to ensure rule files have the primary tactic prepended to the filename.""" bad_name_rules = [] - for rule in self.rules: - rule_path = Path(rule.path).resolve() + for rule in self.all_rules: + rule_path = rule.path.resolve() filename = rule_path.name if rule_path.parent.name == 'ml': @@ -394,7 +361,7 @@ class TestRuleMetadata(BaseRuleTest): def test_ecs_and_beats_opt_in_not_latest_only(self): """Test that explicitly defined opt-in validation is not only the latest versions to avoid stale tests.""" - for rule in self.rules: + for rule in self.all_rules: beats_version = rule.contents.metadata.beats_version ecs_versions = rule.contents.metadata.ecs_versions or [] latest_beats = str(beats.get_max_version()) @@ -413,7 +380,7 @@ def test_updated_date_newer_than_creation(self): """Test that the updated_date is newer than the creation date.""" invalid = [] - for rule in self.rules: + for rule in self.all_rules: created = 
rule.contents.metadata.creation_date.split('/') updated = rule.contents.metadata.updated_date.split('/') if updated < created: @@ -430,7 +397,7 @@ def test_deprecated_rules(self): deprecations = load_etc_dump('deprecated_rules.json') deprecated_rules = {} - for rule in self.rules: + for rule in self.all_rules: meta = rule.contents.metadata maturity = meta.maturity @@ -470,7 +437,7 @@ def test_event_override(self): """Test that rules have defined an timestamp_override if needed.""" missing = [] - for rule in self.rules: + for rule in self.all_rules: required = False if isinstance(rule.contents.data, BaseQueryRuleData) and 'endgame-*' in rule.contents.data.index: @@ -495,7 +462,7 @@ def test_required_lookback(self): long_indexes = {'logs-endpoint.events.*'} missing = [] - for rule in self.rules: + for rule in self.all_rules: contents = rule.contents if isinstance(contents.data, BaseQueryRuleData): diff --git a/tests/test_mappings.py b/tests/test_mappings.py index c0effcd74fd..48f749be7c1 100644 --- a/tests/test_mappings.py +++ b/tests/test_mappings.py @@ -5,20 +5,18 @@ """Test that all rules appropriately match against expected data sets.""" import copy -import unittest import warnings from detection_rules.rule import KQLRuleData from . import get_data_files, get_fp_data_files -from detection_rules import rule_loader from detection_rules.utils import combine_sources, evaluate, load_etc_dump +from .base import BaseRuleTest -class TestMappings(unittest.TestCase): +class TestMappings(BaseRuleTest): """Test that all rules appropriately match against expected data sets.""" FP_FILES = get_fp_data_files() - RULES = rule_loader.load_rules().values() def evaluate(self, documents, rule, expected, msg): """KQL engine to evaluate.""" @@ -31,7 +29,7 @@ def test_true_positives(self): mismatched_ecs = [] mappings = load_etc_dump('rule-mapping.yml') - for rule in rule_loader.get_production_rules(): + for rule in self.production_rules: if isinstance(rule.contents.data, KQLRuleData): if rule.id not in mappings: continue @@ -64,7 +62,7 @@ def test_true_positives(self): def test_false_positives(self): """Test that expected results return against false positives.""" - for rule in rule_loader.get_production_rules(): + for rule in self.production_rules: if isinstance(rule.contents.data, KQLRuleData): for fp_name, merged_data in get_fp_data_files().items(): msg = 'Unexpected FP match for: {} - {}, against: {}'.format(rule.id, rule.name, fp_name) diff --git a/tests/test_packages.py b/tests/test_packages.py index 6c717a2c751..7ec2225aba9 100644 --- a/tests/test_packages.py +++ b/tests/test_packages.py @@ -9,12 +9,13 @@ from detection_rules import rule_loader from detection_rules.packaging import PACKAGE_FILE, Package - +from detection_rules.rule_loader import RuleCollection +from tests.base import BaseRuleTest package_configs = Package.load_configs() -class TestPackages(unittest.TestCase): +class TestPackages(BaseRuleTest): """Test package building and saving.""" @staticmethod @@ -52,10 +53,9 @@ def test_package_loader_default_configs(self): """Test configs in etc/packages.yml.""" Package.from_config(package_configs) - @rule_loader.mock_loader def test_package_summary(self): """Test the generation of the package summary.""" - rules = rule_loader.get_production_rules() + rules = self.production_rules package = Package(rules, 'test-package') changed_rule_ids, new_rule_ids, deprecated_rule_ids = package.bump_versions(save_changes=False) package.generate_summary_and_changelog(changed_rule_ids, new_rule_ids, 
deprecated_rule_ids) @@ -90,11 +90,10 @@ def test_package_summary(self): # self.assertEqual(0, len(new_rules), 'Package version bumping is improperly detecting new rules') # self.assertEqual(2, package.rules[0].contents['version'], 'Package version not bumping on changes') - @rule_loader.mock_loader def test_rule_versioning(self): """Test that all rules are properly versioned and tracked""" self.maxDiff = None - rules = rule_loader.load_rules().values() + rules = RuleCollection.default() original_hashes = [] post_bump_hashes = []
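
Reviewer note: a minimal sketch of how the new `RuleCollection` loader in `detection_rules/rule_loader.py` (introduced above) is meant to be driven; the directory path and rule ID below are hypothetical placeholders, not values from this change.

```python
from pathlib import Path

from detection_rules.rule_loader import RuleCollection, dict_filter, production_filter

# The cached default collection loads everything under rules/ once and is frozen,
# so code that needs to mutate rules should build its own collection instead.
all_rules = RuleCollection.default()
print(f"loaded {len(all_rules)} rules")

# filter() takes a callback over TOMLRule objects; production_filter keeps
# rules whose metadata has maturity == "production".
production_rules = all_rules.filter(production_filter)

# An ad-hoc collection scoped to one directory, pre-filtered on the raw TOML dict
# before schema validation; rule__rule_id is a dunder/dotted path into that dict.
scoped = RuleCollection()
scoped.load_directory(
    Path("rules/windows"),  # hypothetical directory
    toml_filter=dict_filter(rule__rule_id="00000000-0000-0000-0000-000000000000"),  # placeholder ID
)
```

Because `default()` freezes its collection, per-command code paths start from a fresh `RuleCollection` and only fall back to the frozen default as a read-through cache (see `load_file` above).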
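Likewise, a hedged sketch of how a command would opt into the shared `--rule-id`/`--rule-file`/`--directory` loading options via the `multi_collection` decorator from `detection_rules/cli_utils.py`; the `print-rule-names` command is invented purely for illustration and is not part of this change.

```python
import click

from detection_rules.cli_utils import multi_collection
from detection_rules.main import root
from detection_rules.rule_loader import RuleCollection


@root.command('print-rule-names')  # hypothetical command, for illustration only
@multi_collection
def print_rule_names(rules: RuleCollection):
    """Print the name of every rule selected via --rule-id, --rule-file, or --directory."""
    for rule in rules:
        click.echo(rule.name)
```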