diff --git a/ssm-diff b/ssm-diff index 3a99532..66014cd 100755 --- a/ssm-diff +++ b/ssm-diff @@ -2,15 +2,27 @@ from __future__ import print_function import argparse +import logging import os +import sys from states import * +root = logging.getLogger() +root.setLevel(logging.INFO) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(name)s - %(message)s') +handler.setFormatter(formatter) +root.addHandler(handler) + def configure_endpoints(args): # configure() returns a DiffBase class (whose constructor may be wrapped in `partial` to pre-configure it) diff_class = DiffBase.get_plugin(args.engine).configure(args) - return storage.ParameterStore(args.profile, diff_class, paths=args.path), storage.YAMLFile(args.filename, paths=args.path) + return storage.ParameterStore(args.profile, diff_class, paths=args.paths, no_secure=args.no_secure), \ + storage.YAMLFile(args.filename, paths=args.paths, no_secure=args.no_secure, root_path=args.yaml_root) def init(args): @@ -39,18 +51,12 @@ def apply(args): def plan(args): """Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)""" remote, local = configure_endpoints(args) - diff = remote.dry_run(local.get()) - - if diff.differ: - print(DiffBase.describe_diff(diff.plan)) - else: - print("Remote state is up to date.") + print(DiffBase.describe_diff(remote.dry_run(local.get()))) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-f', help='local state yml file', action='store', dest='filename', default='parameters.yml') - parser.add_argument('--path', '-p', action='append', help='filter SSM path') + parser.add_argument('-f', help='local state yml file', action='store', dest='filename') parser.add_argument('--engine', '-e', help='diff engine to use when interacting with SSM', action='store', dest='engine', default='DiffResolver') parser.add_argument('--profile', help='AWS profile 
name', action='store', dest='profile') subparsers = parser.add_subparsers(dest='func', help='commands') @@ -70,12 +76,29 @@ if __name__ == "__main__": parser_apply.set_defaults(func=apply) args = parser.parse_args() - args.path = args.path if args.path else ['/'] - - if args.filename == 'parameters.yml': - if not args.profile: - if 'AWS_PROFILE' in os.environ: - args.filename = os.environ['AWS_PROFILE'] + '.yml' - else: - args.filename = args.profile + '.yml' + + args.no_secure = os.environ.get('SSM_NO_SECURE', 'false').lower() in ['true', '1'] + args.yaml_root = os.environ.get('SSM_YAML_ROOT', '/') + args.paths = os.environ.get('SSM_PATHS', None) + if args.paths is not None: + args.paths = args.paths.split(';:') + else: + # this defaults to '/' + args.paths = [args.yaml_root] + + # root filename + if args.filename is not None: + filename = args.filename + elif args.profile: + filename = args.profile + elif 'AWS_PROFILE' in os.environ: + filename = os.environ['AWS_PROFILE'] + else: + filename = 'parameters' + + # remove extension (will be restored by storage classes) + if filename[-4:] == '.yml': + filename = filename[:-4] + args.filename = filename + args.func(args) diff --git a/states/engine.py b/states/engine.py index 67b09fa..4550df2 100644 --- a/states/engine.py +++ b/states/engine.py @@ -1,5 +1,6 @@ import collections import logging +import re from functools import partial from termcolor import colored @@ -38,24 +39,24 @@ def configure(cls, args): @classmethod def _flatten(cls, d, current_path='', sep='/'): """Convert a nested dict structure into a "flattened" dict i.e. 
{"full/path": "value", ...}""" - items = [] - for k in d: + items = {} + for k, v in d.items(): new = current_path + sep + k if current_path else k - if isinstance(d[k], collections.MutableMapping): - items.extend(cls._flatten(d[k], new, sep=sep).items()) + if isinstance(v, collections.MutableMapping): + items.update(cls._flatten(v, new, sep=sep).items()) else: - items.append((sep + new, d[k])) - return dict(items) + items[sep + new] = v + return items @classmethod def _unflatten(cls, d, sep='/'): """Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure""" output = {} - for k in d: + for k, v in d.items(): add( obj=output, path=k, - value=d[k], + value=v, sep=sep, ) return output @@ -66,15 +67,18 @@ def describe_diff(cls, plan): description = "" for k, v in plan['add'].items(): # { key: new_value } - description += colored("+", 'green'), "{} = {}".format(k, v) + '\n' + description += colored("+", 'green') + "{} = {}".format(k, v) + '\n' for k in plan['delete']: # { key: old_value } - description += colored("-", 'red'), k + '\n' + description += colored("-", 'red') + k + '\n' for k, v in plan['change'].items(): # { key: {'old': value, 'new': value} } - description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, v['old'], v['new']) + '\n' + description += colored("~", 'yellow') + "{}:\n\t< {}\n\t> {}".format(k, v['old'], v['new']) + '\n' + + if description == "": + description = "No Changes Detected" return description diff --git a/states/helpers.py b/states/helpers.py index 046502c..a767982 100644 --- a/states/helpers.py +++ b/states/helpers.py @@ -5,24 +5,31 @@ def add(obj, path, value, sep='/'): """Add value to the `obj` dict at the specified path""" parts = path.strip(sep).split(sep) last = len(parts) - 1 + current = obj for index, part in enumerate(parts): if index == last: - obj[part] = value + current[part] = value else: - obj = obj.setdefault(part, {}) + current = current.setdefault(part, {}) + # convenience 
return, object is mutated + return obj def search(state, path): - result = state + """Get value in `state` at the specified path, returning {} if the key is absent""" + if path.strip("/") == '': + return state for p in path.strip("/").split("/"): - if result.clone(p): - result = result[p] - else: - result = {} - break - output = {} - add(output, path, result) - return output + if p not in state: + return {} + state = state[p] + return state + + +def filter(state, path): + if path.strip("/") == '': + return state + return add({}, path, search(state, path)) def merge(a, b): diff --git a/states/storage.py b/states/storage.py index e56655e..4c425ca 100644 --- a/states/storage.py +++ b/states/storage.py @@ -1,13 +1,16 @@ from __future__ import print_function +import logging +import re import sys +from copy import deepcopy import boto3 import termcolor import yaml from botocore.exceptions import ClientError, NoCredentialsError -from .helpers import merge, add, search +from .helpers import merge, add, filter, search def str_presenter(dumper, data): @@ -62,18 +65,34 @@ def to_yaml(cls, dumper, data): class YAMLFile(object): """Encodes/decodes a dictionary to/from a YAML file""" - def __init__(self, filename, paths=('/',)): - self.filename = filename + METADATA_CONFIG = 'ssm-diff:config' + METADATA_PATHS = 'ssm-diff:paths' + METADATA_ROOT = 'ssm:root' + METADATA_NO_SECURE = 'ssm:no-secure' + + def __init__(self, filename, paths=('/',), root_path='/', no_secure=False): + self.filename = '{}.yml'.format(filename) + self.root_path = root_path self.paths = paths + self.validate_paths() + self.no_secure = no_secure + + def validate_paths(self): + length = len(self.root_path) + for path in self.paths: + if path[:length] != self.root_path: + raise ValueError('Root path {} does not contain path {}'.format(self.root_path, path)) def get(self): try: output = {} with open(self.filename, 'rb') as f: local = yaml.safe_load(f.read()) + self.validate_config(local) + local = 
self.nest_root(local) for path in self.paths: if path.strip('/'): - output = merge(output, search(local, path)) + output = merge(output, filter(local, path)) else: return local return output @@ -87,7 +106,55 @@ def get(self): return dict() raise + def validate_config(self, local): + """YAML files may contain a special ssm:config tag that stores information about the file when it was generated. + This information can be used to ensure the file is compatible with future calls. For example, a file created + with a particular subpath (e.g. /my/deep/path) should not be used to overwrite the root path since this would + delete any keys not in the original scope. This method does that validation (with permissive defaults for + backwards compatibility).""" + config = local.pop(self.METADATA_CONFIG, {}) + + # strict requirement that the no_secure setting is equal + config_no_secure = config.get(self.METADATA_NO_SECURE, False) + if config_no_secure != self.no_secure: + raise ValueError("YAML file generated with no_secure={} but current class set to no_secure={}".format( + config_no_secure, self.no_secure, + )) + # strict requirement that root_path is equal + config_root = config.get(self.METADATA_ROOT, '/') + if config_root != self.root_path: + raise ValueError("YAML file generated with root_path={} but current class set to root_path={}".format( + config_root, self.root_path, + )) + # make sure all paths are subsets of file paths + config_paths = config.get(self.METADATA_PATHS, ['/']) + for path in self.paths: + for config_path in config_paths: + # if path is not found in a config path, it could look like we've deleted values + if path[:len(config_path)] == config_path: + break + else: + raise ValueError("Path {} was not included in this file when it was created.".format(path)) + + def unnest_root(self, state): + if self.root_path == '/': + return state + return search(state, self.root_path) + + def nest_root(self, state): + if self.root_path == '/': + return state + return 
add({}, self.root_path, state) + def save(self, state): + state = self.unnest_root(state) + # inject state information so we can validate the file on load + # colon is not allowed in SSM keys so this namespace cannot collide with keys at any depth + state[self.METADATA_CONFIG] = { + self.METADATA_PATHS: self.paths, + self.METADATA_ROOT: self.root_path, + self.METADATA_NO_SECURE: self.no_secure + } try: with open(self.filename, 'wb') as f: content = yaml.safe_dump(state, default_flow_style=False) @@ -99,12 +166,24 @@ def save(self, state): class ParameterStore(object): """Encodes/decodes a dict to/from the SSM Parameter Store""" - def __init__(self, profile, diff_class, paths=('/',)): + invalid_characters = r'[^a-zA-Z0-9\-_\./]' + + def __init__(self, profile, diff_class, paths=('/',), no_secure=False): + self.logger = logging.getLogger(self.__class__.__name__) if profile: boto3.setup_default_session(profile_name=profile) self.ssm = boto3.client('ssm') self.diff_class = diff_class self.paths = paths + self.parameter_filters = [] + if no_secure: + self.parameter_filters.append({ + 'Key': 'Type', + 'Option': 'Equals', + 'Values': [ + 'String', 'StringList', + ] + }) def clone(self): p = self.ssm.get_paginator('get_parameters_by_path') @@ -114,7 +193,9 @@ def clone(self): for page in p.paginate( Path=path, Recursive=True, - WithDecryption=True): + WithDecryption=True, + ParameterFilters=self.parameter_filters, + ): for param in page['Parameters']: add(obj=output, path=param['Name'], @@ -126,7 +207,11 @@ def clone(self): # noinspection PyMethodMayBeStatic def _read_param(self, value, ssm_type='String'): - return SecureTag(value) if ssm_type == 'SecureString' else str(value) + if ssm_type == 'SecureString': + value = SecureTag(value) + elif ssm_type == 'StringList': + value = value.split(',') + return value def pull(self, local): diff = self.diff_class( @@ -135,35 +220,75 @@ def pull(self, local): ) return diff.merge() + @classmethod + def coerce_state(cls, state, 
path='/', sep='/'): + errors = {} + for k, v in state.items(): + if re.search(cls.invalid_characters, k) is not None: + errors[path+sep+k] = 'Invalid Key' + continue + if isinstance(v, dict): + errors.update(cls.coerce_state(v, path=path + sep + k)) + elif isinstance(v, list): + list_errors = [] + for item in v: + if not isinstance(item, str): + list_errors.append('list items must be strings: {}'.format(repr(item))) + elif re.search(r'[,]', item) is not None: + list_errors.append("StringList is comma separated so items may not contain commas: {}".format(item)) + if list_errors: + errors[path+sep+k] = list_errors + elif isinstance(v, (str, SecureTag)): + continue + elif isinstance(v, (int, float, type(None))): + state[k] = str(v) + else: + errors[path+sep+k] = 'Cannot coerce type {}'.format(type(v)) + return errors + def dry_run(self, local): - return self.diff_class(self.clone(), local).plan + working = deepcopy(local) + errors = self.coerce_state(working) + if errors: + raise ValueError('Errors during dry run:\n{}'.format(errors)) + plan = self.diff_class(self.clone(), working).plan + return plan + + def prepare_value(self, value): + if isinstance(value, list): + ssm_type = 'StringList' + value = ','.join(value) + elif isinstance(value, SecureTag): + ssm_type = 'SecureString' + value = repr(value) + else: + ssm_type, value = 'String', str(value) + return ssm_type, value def push(self, local): plan = self.dry_run(local) # plan for k, v in plan['add'].items(): + self.logger.info('add: {}'.format(k)) # { key: new_value } - ssm_type = 'String' - if isinstance(v, list): - ssm_type = 'StringList' - if isinstance(v, SecureTag): - ssm_type = 'SecureString' + ssm_type, v = self.prepare_value(v) self.ssm.put_parameter( Name=k, - Value=repr(v) if type(v) == SecureTag else str(v), + Value=v, Type=ssm_type) - for k in plan['delete']: - # { key: old_value } - self.ssm.delete_parameter(Name=k) - - for k, delta in plan['change']: + for k, delta in plan['change'].items(): + self.logger.info('change: {}'.format(k)) # { key: {'old': value, 
'new': value} } - v = delta['new'] - ssm_type = 'SecureString' if isinstance(v, SecureTag) else 'String' + ssm_type, v = self.prepare_value(delta['new']) self.ssm.put_parameter( Name=k, - Value=repr(v) if type(v) == SecureTag else str(v), + Value=v, Overwrite=True, Type=ssm_type) + + for k in plan['delete']: + self.logger.info('delete: {}'.format(k)) + # { key: old_value } + self.ssm.delete_parameter(Name=k) diff --git a/states/tests.py b/states/tests.py index 3bef0b1..3d1addd 100644 --- a/states/tests.py +++ b/states/tests.py @@ -1,12 +1,14 @@ +import random +import string from unittest import TestCase, mock -from . import engine +from . import engine, storage -class FlatDictDiffer(TestCase): - +class DiffBaseFlatten(TestCase): + """Verifies the behavior of the _flatten and _unflatten methods""" def setUp(self) -> None: - self.obj = engine.DiffResolver({}, {}) + self.obj = engine.DiffBase({}, {}) def test_flatten_single(self): nested = { @@ -62,6 +64,7 @@ def test_flatten_nested_sep(self): class DiffResolverMerge(TestCase): + """Verifies that the `merge` method produces the expected output""" def test_add_remote(self): """Remote additions should be added to local""" @@ -236,3 +239,309 @@ def test_delete(self): }, diff.plan ) + + +class YAMLFileValidatePaths(TestCase): + """YAMLFile calls `validate_paths` in `__init__` to ensure the root and paths are compatible""" + def test_validate_paths_invalid(self): + with self.assertRaises(ValueError): + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/another/branch']) + + def test_validate_paths_valid_same(self): + self.assertIsInstance( + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/one/branch']), + storage.YAMLFile, + ) + + def test_validate_paths_valid_child(self): + self.assertIsInstance( + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/one/branch/child']), + storage.YAMLFile, + ) + + +class YAMLFileMetadata(TestCase): + """Verifies that exceptions 
are thrown if the metadata in the target file is incompatible with the class configuration""" + def test_get_methods(self): + """Make sure we use the methods mocked by other tests""" + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + with mock.patch('states.storage.open') as open_, mock.patch('states.storage.yaml') as yaml, \ + mock.patch.object(provider, 'validate_config'): + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + open_.assert_called_once_with( + filename + '.yml', 'rb' + ) + yaml.safe_load.assert_called_once_with( + open_.return_value.__enter__.return_value.read.return_value + ) + + def test_get_invalid_no_secure(self): + """Exception should be raised if the secure metadata in the file does not match the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_NO_SECURE: False + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_valid_no_secure(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_NO_SECURE: False + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=False) + + with mock.patch('states.storage.open') as open_, mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + + def test_get_valid_no_secure_true(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + 
storage.YAMLFile.METADATA_NO_SECURE: True + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + + def test_get_invalid_root(self): + """Exception should be raised if the root metadata in the file does not match the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/' + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + with mock.patch.object(storage.YAMLFile, 'validate_paths'): + provider = storage.YAMLFile(filename=filename, root_path='/another') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_valid_root(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/same' + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + with mock.patch.object(storage.YAMLFile, 'validate_paths'): + provider = storage.YAMLFile(filename=filename, root_path='/same') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml, \ + mock.patch.object(provider, 'nest_root'): + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_invalid_paths(self): + """Exception should be raised if the paths metadata is incompatible with the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: 
['/limited'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths='/') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_invalid_paths_mixed(self): + """A single invalid path should fail even in the presence of multiple matching paths""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/limited'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/', '/limited']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_invalid_paths_multiple(self): + """Multiple invalid paths should fail""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/limited'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/', '/another']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_valid_paths_same(self): + """The same path is valid""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for 
_ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_child(self): + """A descendant (child) of a path is valid since it's contained in the original""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/child']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_child_multiple(self): + """Multiple descendant (child) of a path is valid since it's contained in the original""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/child', '/another_child']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_default_nested(self): + """The default path is '/' so it should be valid for anything""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/child']) + + # handle open/yaml processing + with 
mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_default_root(self): + """The default path is '/' so it should be valid for anything""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + +class YAMLFileRoot(TestCase): + """Verify that the `root_path` config works as expected""" + def test_unnest_path(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + # must match root_path of object to pass checks + storage.YAMLFile.METADATA_ROOT: '/nested/path' + }, + 'key': 'value' + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, root_path='/nested/path', paths=['/nested/path']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + { + 'nested': { + 'path': { + 'key': 'value' + } + } + }, + provider.get(), + ) + + def test_nest_path(self): + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, root_path='/nested/path', paths=['/nested/path']) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + provider.save({ + 'nested': { + 'path': { + 'key': 'value' + } + } + }) + + yaml.safe_dump.assert_called_once_with( + { + 
storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/nested/path', + storage.YAMLFile.METADATA_PATHS: ['/nested/path'], + storage.YAMLFile.METADATA_NO_SECURE: False, + }, + 'key': 'value' + }, + # appears to replicate a default, but included in the current code + default_flow_style=False + )