diff --git a/.github/workflows/test_command_scene_configHelper.yml b/.github/workflows/test_command_scene_configHelper.yml new file mode 100644 index 00000000..e5647e0a --- /dev/null +++ b/.github/workflows/test_command_scene_configHelper.yml @@ -0,0 +1,31 @@ +# common包下command、scene、config_helper的测试用例 +name: Test command_scene_configHelper + +on: + push: + branches: "*" + pull_request: + branches: "*" + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 # Fetch all history for proper version detection + + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: 3.8 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements3.txt + + - name: Run tests + run: python -m unittest discover -s test/common -p 'test_*.py' diff --git a/.github/workflows/test_ssh_client.yml b/.github/workflows/test_ssh_client.yml new file mode 100644 index 00000000..fb62bc29 --- /dev/null +++ b/.github/workflows/test_ssh_client.yml @@ -0,0 +1,30 @@ +name: Test Ssh Client + +on: + push: + branches: "dev*" + pull_request: + branches: "dev*" + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 # Fetch all history for proper version detection + + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: 3.8 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements3.txt + + - name: Run tests + run: python -m unittest discover -s test/common/ssh_client -p 'test_*.py' \ No newline at end of file diff --git a/common/ob_log_parser.py b/common/ob_log_parser.py new file mode 100644 index 00000000..54ebef96 --- /dev/null +++ b/common/ob_log_parser.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. 
+# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/8/1 +@file: ob_log_parser.py +@desc: +""" + +import re + +OceanbaseObjDict = { + 'ObDMLBaseParam': [ + 'timeout', + 'schema_version', + 'sql_mode', + 'is_total_quantity_log', + 'table_param', + 'tenant_schema_version', + 'is_ignore', + 'prelock', + 'encrypt_meta', + 'is_batch_stmt', + 'write_flag', + 'spec_seq_no', + 'snapshot', + 'branch_id', + 'check_schema_version', + ], + 'ObStoreCtx': ['this', 'ls_id', 'ls', 'branch', 'timeout', 'tablet_id', 'table_iter', 'table_version', 'mvcc_acc_ctx', 'tablet_stat', 'is_read_store_ctx'], + 'ObTableDMLParam': ['tenant_schema_version', 'data_table', 'col_descs', 'col_map'], + 'ObTableSchemaParam': ['table_id', 'schema_version', 'table_type', 'index_type', 'index_status', 'shadow_rowkey_column_num', 'fulltext_col_id', 'index_name', 'pk_name', 'columns', 'read_info', 'lob_inrow_threshold'], + 'ObMemtable': [ + 'ObITable', + 'this', + 'timestamp', + 'state', + 'freeze_clock', + 'max_schema_version', + 'max_data_schema_version', + 'max_column_cnt', + 'write_ref_cnt', + 'local_allocator', + 'unsubmitted_cnt', + 'unsynced_cnt', + 'logging_blocked', + 'unset_active_memtable_logging_blocked', + 'resolve_active_memtable_left_boundary', + 'contain_hotspot_row', + 'max_end_scn', + 'rec_scn', + 'snapshot_version', + 'migration_clog_checkpoint_scn', + 'is_tablet_freeze', + 'is_force_freeze', + ['contain_hotspot_row', 'contain_hotspot_row2'], + 'read_barrier', + 'is_flushed', + 'freeze_state', + 'allow_freeze', + ['mt_stat_.frozen_time', 'frozen_time'], + ['mt_stat_.ready_for_flush_time', 
'ready_for_flush_time'], + ['mt_stat_.create_flush_dag_time', 'create_flush_dag_time'], + ['mt_stat_.release_time', 'release_time'], + ['mt_stat_.push_table_into_gc_queue_time', 'push_table_into_gc_queue_time'], + ['mt_stat_.last_print_time', 'last_print_time'], + 'ls_id', + 'transfer_freeze_flag', + 'recommend_snapshot_version', + ], + 'ObMemtable2': [ + 'ObITabletMemtable', + 'this', + 'state', + 'max_data_schema_version', + 'max_column_cnt', + 'local_allocator', + 'contain_hotspot_row', + 'snapshot_version', + ['contain_hotspot_row', 'contain_hotspot_row2'], + 'ls_id', + 'transfer_freeze_flag', + 'recommend_snapshot_version', + ], + 'ObITabletMemtable': [ + 'ObITable', + 'ls_id_', + 'allow_freeze_', + 'is_flushed_', + 'is_tablet_freeze_', + 'logging_blocked_', + 'resolved_active_memtable_left_boundary_', + 'unset_active_memtable_logging_blocked_', + 'has_backoffed_', + 'read_barrier_', + 'freeze_clock_', + 'freeze_state_', + 'unsubmitted_cnt_', + 'init_timestamp_', + 'max_schema_version_', + 'write_ref_cnt_', + 'max_end_scn_', + 'rec_scn_', + 'freeze_scn_', + 'migration_clog_checkpoint_scn_', + 'freezer_', + 'memtable_mgr_handle_', + ['mt_stat_.frozen_time_', 'frozen_time'], + ['mt_stat_.ready_for_flush_time_', 'ready_for_flush_time'], + ['mt_stat_.create_flush_dag_time_', 'create_flush_dag_time'], + ['mt_stat_.release_time_', 'release_time'], + ['mt_stat_.push_table_into_gc_queue_time_', 'push_table_into_gc_queue_time'], + ['mt_stat_.last_print_time_', 'last_print_time'], + ], + 'ObDagTypeStruct': ['init_dag_prio', 'sys_task_type', 'dag_type_str', 'dag_module_str'], + 'ObTabletMergeInfo': ['is_inited', 'sstable_merge_info', 'sstable_builder'], + 'ObSSTableMergeInfo': [ + 'tenant_id', + 'ls_id', + 'tablet_id', + 'compaction_scn', + 'merge_type', + 'merge_cost_time', + 'merge_start_time', + 'merge_finish_time', + 'dag_id', + 'occupy_size', + 'new_flush_occupy_size', + 'original_size', + 'compressed_size', + 'macro_block_count', + 'multiplexed_macro_block_count', 
+ 'new_micro_count_in_new_macro', + 'multiplexed_micro_count_in_new_macro', + 'total_row_count', + 'incremental_row_count', + 'new_flush_data_rate', + 'is_full_merge', + 'progressive_merge_round', + 'progressive_merge_num', + 'concurrent_cnt', + 'start_cg_idx', + 'end_cg_idx', + 'suspect_add_time', + 'early_create_time', + 'dag_ret', + 'retry_cnt', + 'task_id', + 'error_location', + 'kept_snapshot_info', + 'merge_level', + 'parallel_merge_info', + 'filter_statistics', + 'participant_table_info', + 'macro_id_list', + 'comment', + ], + 'SCN1': ['val'], + 'SCN': ['val', 'v'], + 'ObLSID': ['id'], +} + +OceanbaseObjDetailDict = { + 'ObDMLBaseParam': { + 'table_param': 'ObTableDMLParam', + }, + 'ObTableDMLParam': { + 'data_table': 'ObTableSchemaParam', + }, + 'ObMemtable2': { + 'ObITabletMemtable': 'ObITabletMemtable', + }, + 'ObTabletMergeInfo': { + 'sstable_merge_info': 'ObSSTableMergeInfo', + }, +} + +OceanbaseObjCompilePattern = {} + +OceanbaseLogVarDict = { + 'Main4377Log': ['column_id', 'storage_old_row', 'sql_old_row', 'dml_param', 'dml_flag', 'store_ctx', 'relative_table'], + 'OldestFrozenLog': ['list'], + 'DumpDagStatusLog': ['type', 'dag_count', 'running_dag_count', 'added_dag_count', 'scheduled_dag_count', 'scheduled_task_count', 'scheduled_data_size'], + 'TenantMemoryLog': [ + 'tenant_id', + 'now', + 'active_memstore_used', + 'total_memstore_used', + 'total_memstore_hold', + 'memstore_freeze_trigger_limit', + 'memstore_limit', + 'mem_tenant_limit', + 'mem_tenant_hold', + 'max_mem_memstore_can_get_now', + 'memstore_alloc_pos', + 'memstore_frozen_pos', + 'memstore_reclaimed_pos', + ], + 'MergeFinishLog': ['ret', 'merge_info', 'sstable', 'mem_peak', 'compat_mode', 'time_guard'], + 'ClogDiskFullLog': [ + 'msg', + 'ret', + ['total_size\(MB\)', 'total_size'], + ['used_size\(MB\)', 'used_size'], + ['used_percent\(%\)', 'used_percent'], + ['warn_size\(MB\)', 'warn_size'], + ['warn_percent\(%\)', 'warn_percent'], + ['limit_size\(MB\)', 'limit_size'], + 
['limit_percent\(%\)', 'limit_percent'], + ['maximum_used_size\(MB\)', 'maximum_used_size'], + 'maximum_log_stream', + 'oldest_log_stream', + 'oldest_scn', + ], + 'ClogDiskFullLog2': [ + 'msg', + 'ret', + ['total_size\(MB\)', 'total_size'], + ['used_size\(MB\)', 'used_size'], + ['used_percent\(%\)', 'used_percent'], + ['warn_size\(MB\)', 'warn_size'], + ['warn_percent\(%\)', 'warn_percent'], + ['limit_size\(MB\)', 'limit_size'], + ['limit_percent\(%\)', 'limit_percent'], + ['total_unrecyclable_size_byte\(MB\)', 'total_unrecyclable_size_byte'], + ['maximum_used_size\(MB\)', 'maximum_used_size'], + 'maximum_log_stream', + 'oldest_log_stream', + 'oldest_scn', + 'in_shrinking', + ], + 'ClogCPTNoChangeLog': ['checkpoint_scn', 'checkpoint_scn_in_ls_meta', 'ls_id', 'service_type'], + 'LSReplayStatLog': ['id', 'replayed_log_size', 'unreplayed_log_size'], +} + +OceanbaseLogVarObjDict = { + 'Main4377Log': { + 'dml_param': 'ObDMLBaseParam', + 'store_ctx': 'ObStoreCtx', + }, + 'OldestFrozenLog': { + 'list': 'not_standard_obj_list', + }, + 'DumpDagStatusLog': { + 'type': 'ObDagTypeStruct', + }, + 'MergeFinishLog': { + 'merge_info': 'ObTabletMergeInfo', + }, + 'ClogDiskFullLog': { + 'oldest_scn': 'SCN1', + }, + 'ClogDiskFullLog2': {'oldest_scn': 'SCN'}, + 'ClogCPTNoChangeLog': { + 'checkpoint_scn': 'SCN', + 'checkpoint_scn_in_ls_meta': 'SCN', + 'ls_id': 'ObLSID', + }, + 'LSReplayStatLog': { + 'id': 'ObLSID', + }, +} + +OceanbaseLogVarCompilePattern = {} + + +class ObLogParser: + compiled_log_pattern = None + compiled_raw_log_pattern = None + + @staticmethod + def get_obj_list(list_str): + # will split with the {} + res = [] + depth = 0 + obj_start = None + for i, char in enumerate(list_str): + if char == '{': + if depth == 0: + # find a Object start position + obj_start = i + depth += 1 + elif char == '}': + depth -= 1 + if depth == 0 and obj_start is not None: + res.append(list_str[obj_start : i + 1]) + obj_start = None + return res + + @staticmethod + def 
get_obj_key_list(obj_str): + # will split with the {} + key_list = [] + depth = 0 + key_start = None + for i, char in enumerate(obj_str): + if char == '{': + if depth == 0 and key_start is None: + key_start = i + 1 + depth += 1 + elif char == '}': + depth -= 1 + elif char == ',': + if depth == 1: + # 1 for , 1 for ' ' + key_start = i + 2 + elif char == ':': + if depth == 1 and key_start is not None: + key_list.append(obj_str[key_start:i]) + key_start = None + return key_list + + @staticmethod + def get_obj_parser_pattern(key_list): + parray = [] + for k in key_list: + if isinstance(k, list): + tp = '({0}:(?P<{1}>.*))'.format(k[0], k[1]) + else: + replace_list = ['.', '(', ')'] + r_k = k + for ri in replace_list: + r_k = r_k.replace(ri, '_') + s_k = k + s_k = s_k.replace('(', '\(') + s_k = s_k.replace(')', '\)') + tp = '({0}:(?P<{1}>.*))'.format(s_k, r_k) + parray.append(tp) + p = '\s*\,\s*'.join(parray) + '\}' + return p + + @staticmethod + def get_log_var_parser_pattern(key_list): + parray = [] + for k in key_list: + if isinstance(k, list): + tp = '({0}=(?P<{1}>.*))'.format(k[0], k[1]) + else: + tp = '({0}=(?P<{0}>.*))'.format(k) + parray.append(tp) + p = '\s*\,\s*'.join(parray) + '\)' + return p + + @staticmethod + def get_raw_log_var_parser_pattern(key_list): + parray = [] + for k in key_list: + if isinstance(k, list): + tp = '({0}=(?P<{1}>.*))'.format(k[0], k[1]) + else: + tp = '({0}=(?P<{0}>.*))'.format(k) + parray.append(tp) + p = '\s*\ \s*'.join(parray) + return p + + @staticmethod + def parse_obj_detail(obj_name, obj_dict): + # parse all the child str to child obj + obj_detail_dict = OceanbaseObjDetailDict.get(obj_name, None) + if not obj_detail_dict: + print('{} obj detail cannot be parsed'.format(obj_name)) + else: + for k in obj_dict.keys(): + child_obj_name = obj_detail_dict.get(k, None) + if child_obj_name: + td = ObLogParser.parse_obj(child_obj_name, obj_dict[k]) + obj_dict[k] = td + ObLogParser.parse_obj_detail(child_obj_name, obj_dict[k]) + + 
@staticmethod + def parse_obj_detail_v2(obj_name, obj_dict): + # parse all the child str to child obj + obj_detail_dict = OceanbaseObjDetailDict.get(obj_name, None) + if not obj_detail_dict: + # parse all the detail if it start with { + for k in obj_dict.keys(): + if obj_dict[k].startswith('{'): + td = ObLogParser.parse_obj_v2(k, obj_dict[k]) + obj_dict[k] = td + ObLogParser.parse_obj_detail_v2(k, obj_dict[k]) + else: + for k in obj_dict.keys(): + child_obj_name = obj_detail_dict.get(k, None) + if child_obj_name: + td = ObLogParser.parse_obj(child_obj_name, obj_dict[k]) + obj_dict[k] = td + ObLogParser.parse_obj_detail(child_obj_name, obj_dict[k]) + + @staticmethod + def parse_obj(obj_name, obj_str): + d = dict() + key_list = OceanbaseObjDict.get(obj_name, []) + if len(key_list) == 0: + print('{} obj cannot be parsed'.format(obj_name)) + else: + p = OceanbaseObjCompilePattern.get(obj_name, None) + if p is None: + tp = ObLogParser.get_obj_parser_pattern(key_list) + OceanbaseObjCompilePattern[obj_name] = re.compile(tp) + p = OceanbaseObjCompilePattern[obj_name] + m = p.finditer(obj_str) + for i in m: + d.update(i.groupdict()) + return d + + @staticmethod + def parse_obj_v2(obj_name, obj_str): + is_tmp_pattern = False + d = dict() + key_list = OceanbaseObjDict.get(obj_name, []) + if len(key_list) == 0: + is_tmp_pattern = True + key_list = ObLogParser.get_obj_key_list(obj_str) + if len(key_list) != 0: + p = OceanbaseObjCompilePattern.get(obj_name, None) + if p is None: + tp = ObLogParser.get_obj_parser_pattern(key_list) + OceanbaseObjCompilePattern[obj_name] = re.compile(tp) + p = OceanbaseObjCompilePattern[obj_name] + m = p.finditer(obj_str) + for i in m: + d.update(i.groupdict()) + if is_tmp_pattern: + OceanbaseObjCompilePattern[obj_name] = None + return d + + @staticmethod + def parse_log_vars_detail(log_name, var_dict): + var_obj_dict = OceanbaseLogVarObjDict.get(log_name, None) + if not var_obj_dict: + print('{} vars detail cannot be parsed'.format(log_name)) + 
else: + for k in var_dict.keys(): + var_obj_name = var_obj_dict.get(k, None) + if var_obj_name == "not_standard_obj_list": + tp_obj_list = ObLogParser.get_obj_list(var_dict[k]) + var_dict[k] = tp_obj_list + elif var_obj_name: + td = ObLogParser.parse_obj(var_obj_name, var_dict[k]) + var_dict[k] = td + ObLogParser.parse_obj_detail(var_obj_name, var_dict[k]) + + @staticmethod + def parse_log_vars_detail_v2(log_name, var_dict): + var_obj_dict = OceanbaseLogVarObjDict.get(log_name, None) + if not var_obj_dict: + # get obj list + for k in var_dict.keys(): + if var_dict[k].startswith('{'): + td = ObLogParser.parse_obj_v2(k, var_dict[k]) + var_dict[k] = td + ObLogParser.parse_obj_detail_v2(k, var_dict[k]) + else: + for k in var_dict.keys(): + var_obj_name = var_obj_dict.get(k, None) + if var_obj_name == "not_standard_obj_list": + tp_obj_list = ObLogParser.get_obj_list(var_dict[k]) + var_dict[k] = tp_obj_list + elif var_obj_name: + td = ObLogParser.parse_obj(var_obj_name, var_dict[k]) + var_dict[k] = td + ObLogParser.parse_obj_detail(var_obj_name, var_dict[k]) + + @staticmethod + def parse_raw_log_vars(log_name, var_str): + d = dict() + key_list = OceanbaseLogVarDict.get(log_name, []) + if len(key_list) == 0: + print('{} lob vars cannot be parsed'.format(log_name)) + else: + p = OceanbaseLogVarCompilePattern.get(log_name, None) + if p is None: + tp = ObLogParser.get_raw_log_var_parser_pattern(key_list) + OceanbaseLogVarCompilePattern[log_name] = re.compile(tp) + p = OceanbaseLogVarCompilePattern[log_name] + m = p.finditer(var_str) + for i in m: + d.update(i.groupdict()) + return d + + @staticmethod + def parse_normal_log_vars(log_name, var_str): + d = dict() + key_list = OceanbaseLogVarDict.get(log_name, []) + if len(key_list) == 0: + print('{} lob vars cannot be parsed'.format(log_name)) + else: + p = OceanbaseLogVarCompilePattern.get(log_name, None) + if p is None: + tp = ObLogParser.get_log_var_parser_pattern(key_list) + OceanbaseLogVarCompilePattern[log_name] = 
re.compile(tp) + p = OceanbaseLogVarCompilePattern[log_name] + m = p.finditer(var_str) + for i in m: + d.update(i.groupdict()) + return d + + @staticmethod + def parse_normal_log_vars_v2(var_str): + d = dict() + log_name = 'log_vars_v2' + p = OceanbaseLogVarCompilePattern.get(log_name, None) + if p is None: + tp = r'(\w+)=(.*?)(?=\s\w+=|$)' + OceanbaseLogVarCompilePattern[log_name] = re.compile(tp) + p = OceanbaseLogVarCompilePattern[log_name] + m = p.findall(var_str) + for i in m: + key = i[0] + val = i[1].strip(',') + d[key] = val + return d + + @staticmethod + def parse_log_vars(log_name, var_str, log_type=1): + d = dict() + if log_type == 1: + d = ObLogParser.parse_normal_log_vars(log_name, var_str) + if log_type == 2: + # raw log + d = ObLogParser.parse_raw_log_vars(log_name, var_str) + return d + + @staticmethod + def parse_raw_print_log(line): + # parse a log that produced by raw print + d = dict() + if ObLogParser.compiled_raw_log_pattern is None: + msg = "(?P\[.*\])" + vars = "(?P.*)" + parray = [msg, vars] + p = '\s*'.join(parray) + ObLogParser.compiled_raw_log_pattern = re.compile(p) + m = ObLogParser.compiled_raw_log_pattern.finditer(line) + for i in m: + d.update(i.groupdict()) + return d + + @staticmethod + def parse_log_vars_v2(log_name, var_str, log_type=1): + d = dict() + if log_type == 1: + d = ObLogParser.parse_normal_log_vars_v2(var_str) + if log_type == 2: + # raw log + d = ObLogParser.parse_raw_log_vars(log_name, var_str) + return d + + @staticmethod + def parse_log(line): + # parse a normal log, get all the element + # get raw print log if it is not a normal log. + d = dict() + # 1, means normal log + # 2, means raw print log + log_type = 1 + if ObLogParser.compiled_log_pattern is None: + date_time = "\[(?P\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+)\]" + log_level = "(?P[A-Z]+)" + module = "\[?(?P[A-Z]+\.*[A-Z]*)?\]?" 
+ func = "(?P[a-zA-Z_0-9]+\(?\)?)" + file_no = "\((?P[a-zA-Z0-9_\.\-]+):(?P[0-9]+)\)" + thread_no = "\[(?P[0-9]+)\]" + thread_name = "\[(?P[A-Za-z]+[0-9_A-Za-z]*)?\]" + tenant_id = "\[T(?P[0-9]+)\]" + trace_id = "\[(?P[A-Za-z\-0-9]+)\]" + lt = "\[lt=(?P[0-9]+)\]" + errcode = "(\[errcode=\-?)?(?P[0-9]+)?(\])?" + msg = "(?P[A-Za-z\s\,\.\[\]\!\_]+)?" + variables = "\(?(?P.*)?\)?$" + parray = [date_time, log_level, module, func, file_no, thread_no, thread_name, tenant_id, trace_id, lt, errcode, msg, variables] + p = '\s*'.join(parray) + ObLogParser.compiled_log_pattern = re.compile(p) + m = ObLogParser.compiled_log_pattern.finditer(line) + for i in m: + d.update(i.groupdict()) + if not d: + log_type = 2 + d = ObLogParser.parse_raw_print_log(line) + if d: + d['log_type'] = log_type + return d diff --git a/common/ssh_client/kubernetes_client.py b/common/ssh_client/kubernetes_client.py index 251ab839..5103571d 100644 --- a/common/ssh_client/kubernetes_client.py +++ b/common/ssh_client/kubernetes_client.py @@ -42,11 +42,14 @@ def __init__(self, context=None, node=None): def exec_cmd(self, cmd): exec_command = ['/bin/sh', '-c', cmd] self.stdio.verbose("KubernetesClient exec_cmd: {0}".format(cmd)) - resp = stream(self.client.connect_get_namespaced_pod_exec, self.pod_name, self.namespace, command=exec_command, stderr=True, stdin=False, stdout=True, tty=False, container=self.container_name) - self.stdio.verbose("KubernetesClient exec_cmd.resp: {0}".format(resp)) - if "init system (PID 1). Can't operate." in resp: - return "KubernetesClient can't get the resp by {0}".format(cmd) - return resp + try: + resp = stream(self.client.connect_get_namespaced_pod_exec, self.pod_name, self.namespace, command=exec_command, stderr=True, stdin=False, stdout=True, tty=False, container=self.container_name) + self.stdio.verbose("KubernetesClient exec_cmd.resp: {0}".format(resp)) + if "init system (PID 1). Can't operate." 
in resp: + return "KubernetesClient can't get the resp by {0}".format(cmd) + return resp + except Exception as e: + return f"KubernetesClient can't get the resp by {cmd}: {str(e)}" def download(self, remote_path, local_path): return self.__download_file_from_pod(self.namespace, self.pod_name, self.container_name, remote_path, local_path) diff --git a/config.py b/config.py index bd32cbbe..c3aa9ce1 100644 --- a/config.py +++ b/config.py @@ -262,22 +262,7 @@ def get_node_config(self, type, node_ip, config_item): class InnerConfigManager(Manager): - def __init__(self, stdio=None, inner_config_change_map=None): - if inner_config_change_map is None: - inner_config_change_map = {} + def __init__(self, stdio=None): inner_config_abs_path = os.path.abspath(INNER_CONFIG_FILE) super().__init__(inner_config_abs_path, stdio=stdio) self.config = self.load_config_with_defaults(DEFAULT_INNER_CONFIG) - if inner_config_change_map != {}: - self.config = self._change_inner_config(self.config, inner_config_change_map) - - def _change_inner_config(self, conf_map, change_conf_map): - for key, value in change_conf_map.items(): - if key in conf_map: - if isinstance(value, dict): - self._change_inner_config(conf_map[key], value) - else: - conf_map[key] = value - else: - conf_map[key] = value - return conf_map diff --git a/core.py b/core.py index 98e17cf7..4aa6cf92 100644 --- a/core.py +++ b/core.py @@ -51,7 +51,6 @@ from handler.gather.gather_tabledump import GatherTableDumpHandler from handler.gather.gather_parameters import GatherParametersHandler from handler.gather.gather_variables import GatherVariablesHandler -from stdio import SafeStdio from telemetry.telemetry import telemetry from update.update import UpdateHandler from colorama import Fore, Style @@ -61,11 +60,11 @@ from common.tool import TimeUtils -class ObdiagHome(SafeStdio): +class ObdiagHome(object): def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.yml'), inner_config_change_map=None): 
self._optimize_manager = None - self.stdio = stdio + self.stdio = None self._stdio_func = None self.cmds = [] self.options = Values() @@ -118,7 +117,7 @@ def set_stdio(self, stdio): def _print(msg, *arg, **kwarg): sep = kwarg['sep'] if 'sep' in kwarg else None end = kwarg['end'] if 'end' in kwarg else None - return stdio.print(msg, sep='' if sep is None else sep, end='\n' if end is None else end) + return print(msg, sep='' if sep is None else sep, end='\n' if end is None else end) self.stdio = stdio self._stdio_func = {} @@ -323,38 +322,31 @@ def analyze_fuction(self, function_type, opt): return False def check(self, opts): - try: - result_map = {"data": {}} - config = self.config_manager - if not config: - self._call_stdio('error', 'No such custum config') - return False - else: - self.stdio.print("check start ...") - self.set_context('check', 'check', config) - obproxy_check_handler = None - observer_check_handler = None - if self.context.obproxy_config.get("servers") is not None and len(self.context.obproxy_config.get("servers")) > 0: - obproxy_check_handler = CheckHandler(self.context, check_target_type="obproxy") - obproxy_check_handler.handle() - report = obproxy_check_handler.execute() - result_map["data"]["obproxy"] = report.report_tobeMap() - if self.context.cluster_config.get("servers") is not None and len(self.context.cluster_config.get("servers")) > 0: - observer_check_handler = CheckHandler(self.context, check_target_type="observer") - observer_check_handler.handle() - report = observer_check_handler.execute() - result_map["data"]["observer"] = report.report_tobeMap() - if obproxy_check_handler is not None: - obproxy_report_path = os.path.expanduser(obproxy_check_handler.report.get_report_path()) - if os.path.exists(obproxy_report_path): - self.stdio.print("Check obproxy finished. 
For more details, please run cmd '" + Fore.YELLOW + " cat {0} ".format(obproxy_check_handler.report.get_report_path()) + Style.RESET_ALL + "'") - if observer_check_handler is not None: - observer_report_path = os.path.expanduser(observer_check_handler.report.get_report_path()) - if os.path.exists(observer_report_path): - self.stdio.print("Check observer finished. For more details, please run cmd'" + Fore.YELLOW + " cat {0} ".format(observer_check_handler.report.get_report_path()) + Style.RESET_ALL + "'") - except Exception as e: - self.stdio.error("check Exception: {0}".format(e)) - self.stdio.verbose(traceback.format_exc()) + config = self.config_manager + if not config: + self._call_stdio('error', 'No such custum config') + return False + else: + self.stdio.print("check start ...") + self.set_context('check', 'check', config) + obproxy_check_handler = None + observer_check_handler = None + if self.context.obproxy_config.get("servers") is not None and len(self.context.obproxy_config.get("servers")) > 0: + obproxy_check_handler = CheckHandler(self.context, check_target_type="obproxy") + obproxy_check_handler.handle() + obproxy_check_handler.execute() + if self.context.cluster_config.get("servers") is not None and len(self.context.cluster_config.get("servers")) > 0: + observer_check_handler = CheckHandler(self.context, check_target_type="observer") + observer_check_handler.handle() + observer_check_handler.execute() + if obproxy_check_handler is not None: + obproxy_report_path = os.path.expanduser(obproxy_check_handler.report.get_report_path()) + if os.path.exists(obproxy_report_path): + self.stdio.print("Check obproxy finished. 
For more details, please run cmd '" + Fore.YELLOW + " cat {0} ".format(obproxy_check_handler.report.get_report_path()) + Style.RESET_ALL + "'") + if observer_check_handler is not None: + observer_report_path = os.path.expanduser(observer_check_handler.report.get_report_path()) + if os.path.exists(observer_report_path): + self.stdio.print("Check observer finished. For more details, please run cmd'" + Fore.YELLOW + " cat {0} ".format(observer_check_handler.report.get_report_path()) + Style.RESET_ALL + "'") def check_list(self, opts): config = self.config_manager diff --git a/diag_cmd.py b/diag_cmd.py index 27f44527..7803fbca 100644 --- a/diag_cmd.py +++ b/diag_cmd.py @@ -22,6 +22,7 @@ import sys import textwrap import re +import json from uuid import uuid1 as uuid, UUID from optparse import OptionParser, BadOptionError, Option, IndentedHelpFormatter from core import ObdiagHome @@ -247,6 +248,7 @@ def parse_command(self): def do_command(self): self.parse_command() + trace_id = uuid() ret = False try: log_directory = os.path.join(os.path.expanduser("~"), ".obdiag", "log") @@ -260,6 +262,12 @@ def do_command(self): ROOT_IO.verbose('opts: %s' % self.opts) config_path = os.path.expanduser('~/.obdiag/config.yml') custom_config = Util.get_option(self.opts, 'c') + if custom_config: + if os.path.exists(os.path.abspath(custom_config)): + config_path = custom_config + else: + ROOT_IO.error('The option you provided with -c: {0} is a non-existent configuration file path.'.format(custom_config)) + return obdiag = ObdiagHome(stdio=ROOT_IO, config_path=custom_config, inner_config_change_map=self.inner_config_change_map) obdiag.set_options(self.opts) obdiag.set_cmds(self.cmds) @@ -272,7 +280,6 @@ def do_command(self): if self.has_trace: ROOT_IO.print('Trace ID: %s' % self.trace_id) ROOT_IO.print('If you want to view detailed obdiag logs, please run: {0} display-trace {1}'.format(obdiag_bin, self.trace_id)) - ROOT_IO.just_json(ret or ObdiagResult(code=ObdiagResult.SERVER_ERROR_CODE, 
data={"err_info": "The return value of the command is not ObdiagResult. Please contact thebase community."})) except NotImplementedError: ROOT_IO.exception('command \'%s\' is not implemented' % self.prev_cmd) except SystemExit: @@ -868,8 +875,61 @@ def __init__(self): super(ObdiagRCARunCommand, self).__init__('run', 'root cause analysis') self.parser.add_option('--scene', type='string', help="rca scene name. The argument is required.") self.parser.add_option('--store_dir', type='string', help='the dir to store rca result, current dir by default.', default='./rca/') - self.parser.add_option('--input_parameters', type='string', help='input parameters of scene') + self.parser.add_option('--input_parameters', action='callback', type='string', callback=self._input_parameters_scene, help='input parameters of scene') self.parser.add_option('-c', type='string', help='obdiag custom config', default=os.path.expanduser('~/.obdiag/config.yml')) + self.scene_input_param_map = {} + + def _input_parameters_scene(self, option, opt_str, value, parser): + """ + input parameters of scene + """ + try: + # input_parameters option is json format + try: + self.scene_input_param_map = json.loads(value) + return + except Exception as e: + # raise Exception("Failed to parse input_parameters. 
Please check the option is json:{0}".format(value)) + ROOT_IO.verbose("input_parameters option {0} is not json.".format(value)) + + # input_parameters option is key=val format + key, val = value.split('=', 1) + if key is None or key == "": + return + m = self._input_parameters_scene_set(key, val) + + def _scene_input_param(param_map, scene_param_map): + for scene_param_map_key, scene_param_map_value in scene_param_map.items(): + if scene_param_map_key in param_map: + if isinstance(scene_param_map_value, dict): + _scene_input_param(param_map[scene_param_map_key], scene_param_map_value) + else: + param_map[scene_param_map_key] = scene_param_map_value + else: + param_map[scene_param_map_key] = scene_param_map_value + return param_map + + self.scene_input_param_map = _scene_input_param(self.scene_input_param_map, m) + except Exception as e: + raise Exception("Key or val ({1}) is illegal: {0}".format(e, value)) + + def _input_parameters_scene_set(self, key, val): + def recursion(param_map, key, val): + if key is None or key == "": + raise Exception("key is None") + if val is None or val == "": + raise Exception("val is None") + if key.startswith(".") or key.endswith("."): + raise Exception("Key starts or ends '.'") + if "." 
in key: + map_key = key.split(".")[0] + param_map[map_key] = recursion({}, key[len(map_key) + 1 :], val) + return param_map + else: + param_map[key] = val + return param_map + + return recursion({}, key, val) def init(self, cmd, args): super(ObdiagRCARunCommand, self).init(cmd, args) @@ -877,6 +937,7 @@ def init(self, cmd, args): return self def _do_command(self, obdiag): + Util.set_option(self.opts, 'input_parameters', self.scene_input_param_map) return obdiag.rca_run(self.opts) diff --git a/handler/rca/rca_handler.py b/handler/rca/rca_handler.py index 993bd173..4699c37d 100644 --- a/handler/rca/rca_handler.py +++ b/handler/rca/rca_handler.py @@ -113,7 +113,6 @@ def __init__(self, context): all_scenes_info, all_scenes_item = rca_list.get_all_scenes() self.context.set_variable("rca_deep_limit", len(all_scenes_info)) self.all_scenes = all_scenes_item - self.rca_scene_parameters = None self.rca_scene = None self.cluster = self.context.get_variable("ob_cluster") self.nodes = self.context.get_variable("observer_nodes") @@ -122,15 +121,7 @@ def __init__(self, context): # init input parameters self.report = None self.tasks = None - rca_scene_parameters = Util.get_option(self.options, "input_parameters", "") - if rca_scene_parameters != "": - try: - rca_scene_parameters = json.loads(rca_scene_parameters) - except Exception as e: - raise Exception("Failed to parse input_parameters. 
Please check the option is json:{0}".format(rca_scene_parameters)) - else: - rca_scene_parameters = {} - self.context.set_variable("input_parameters", rca_scene_parameters) + self.context.set_variable("input_parameters", Util.get_option(self.options, "input_parameters")) self.store_dir = Util.get_option(self.options, "store_dir", "./rca/") self.context.set_variable("store_dir", self.store_dir) self.stdio.verbose( diff --git a/handler/rca/scene/ddl_disk_full_scene.py b/handler/rca/scene/ddl_disk_full_scene.py index 31da106a..71c5d90d 100644 --- a/handler/rca/scene/ddl_disk_full_scene.py +++ b/handler/rca/scene/ddl_disk_full_scene.py @@ -132,7 +132,11 @@ def execute(self): ## if the action is add_index sql = "select table_id from oceanbase.__all_virtual_table_history where tenant_id = '{0}' and data_table_id = '{1}' and table_name like '%{2}%';".format(self.tenant_id, self.table_id, self.index_name) self.verbose("execute_sql is {0}".format(sql)) - self.index_table_id = self.ob_connector.execute_sql_return_cursor_dictionary(sql).fetchall()[0]["table_id"] + sql_tables_data = self.ob_connector.execute_sql_return_cursor_dictionary(sql).fetchall() + if len(sql_tables_data) == 0: + self.stdio.error("can not find index table id by index name: {0}. 
Please check the index name.".format(self.index_name)) + return + self.index_table_id = sql_tables_data[0]["table_id"] self.verbose("index_table_id is {0}".format(self.index_table_id)) self.record.add_record("index_table_id is {0}".format(self.index_table_id)) diff --git a/rpm/build.sh b/rpm/build.sh index 5ea86c23..d56d3c4b 100755 --- a/rpm/build.sh +++ b/rpm/build.sh @@ -2,7 +2,7 @@ python_bin='python' W_DIR=`pwd` -VERSION=${VERSION:-'2.3.0'} +VERSION=${VERSION:-'2.4.0'} function python_version() diff --git a/rpm/oceanbase-diagnostic-tool.spec b/rpm/oceanbase-diagnostic-tool.spec index 0cb38620..c3621909 100644 --- a/rpm/oceanbase-diagnostic-tool.spec +++ b/rpm/oceanbase-diagnostic-tool.spec @@ -1,5 +1,5 @@ Name: oceanbase-diagnostic-tool -Version:2.3.0 +Version:2.4.0 Release: %(echo $RELEASE)%{?dist} Summary: oceanbase diagnostic tool program Group: Development/Tools diff --git a/test/common/ssh_client/test_docker_client.py b/test/common/ssh_client/test_docker_client.py new file mode 100644 index 00000000..f261f25c --- /dev/null +++ b/test/common/ssh_client/test_docker_client.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+ +""" +@time: 2024/07/28 +@file: test_docker_client.py +@desc: +""" + +import unittest +from unittest.mock import patch, MagicMock, call +from docker import DockerClient as DockerClientSDK +from common.ssh_client.docker_client import DockerClient +from context import HandlerContext +from common.obdiag_exception import OBDIAGShellCmdException + + +class TestDockerClient(unittest.TestCase): + + @patch('common.ssh_client.docker_client.docker.from_env') + def setUp(self, mock_docker_from_env): + """ + Configures the mock Docker client and sets up test parameters in a testing environment. + + Parameters: + - mock_docker_from_env: A Mock object to simulate creating a Docker client from an environment. + + Returns: + No direct return value, but sets up various mock objects and contexts used during testing. + + Explanation: + This function is primarily for setting up initialization and mock object configurations before tests run, ensuring controlled test execution. + """ + + # Use MagicMock to simulate a Docker client to avoid actual network operations during tests. + mock_docker_from_env.return_value = MagicMock(spec_set=DockerClientSDK) + + # Initialize a HandlerContext object to simulate the runtime environment. + self.context = HandlerContext() + + # Define a node dictionary containing a container name, which will be used during tests. + self.node_with_container_name = {'container_name': 'test_container'} + + # Define an empty node dictionary for scenarios where no container name is specified. + self.node_without_container_name = {} + + # Create a DockerClient object with the context and node configuration. + self.docker_client = DockerClient(self.context, {}) + + # Set the node attribute of the DockerClient object to simulate node information. + self.docker_client.node = {"container_name": "test_container"} + + # Set the container name attribute of the DockerClient object for scenarios where a container name is specified. 
+ self.docker_client.container_name = "test_container" + + # Use MagicMock to simulate stdio to avoid actual input/output operations. + self.docker_client.stdio = MagicMock() + + # Use MagicMock to simulate the Docker client object to avoid actual Docker API calls. + self.docker_client.client = MagicMock() + + @patch('common.ssh_client.docker_client.docker.from_env') + def test_init_with_valid_node(self, mock_docker_from_env): + """ + Test the __init__ method with a valid node response. + + This test case ensures that the __init__ method initializes the object correctly when provided with a valid node response. + It first mocks the creation of a Docker client from an environment, then verifies if the mocked object's method was called correctly, + and checks if the properties of the initialized object match expectations. + + Parameters: + - mock_docker_from_env: A mock object used to simulate the creation of a Docker client. + """ + + # Mock returning a DockerClientSDK type object + mock_docker_from_env.return_value = MagicMock(spec_set=DockerClientSDK) + + # Call the function under test + docker_client = DockerClient(self.context, self.node_with_container_name) + + # Verify that the method of the mock object was called once + mock_docker_from_env.assert_called_once() + + # Verify that the container_name attribute of the docker_client object is set correctly + self.assertEqual(docker_client.container_name, 'test_container') + + # Verify that the client attribute of the docker_client object is of type DockerClientSDK + self.assertIsInstance(docker_client.client, DockerClientSDK) + + @patch('common.ssh_client.docker_client.docker.from_env') + def test_init_without_container_name(self, mock_docker_from_env): + """ + Test the initialization of DockerClient when no container name is provided. 
+ + This test case aims to verify that when initializing the DockerClient without a container name, + the client can correctly create a Docker client instance using the provided environment, + and that the container_name attribute is correctly set to None. + + Parameters: + - mock_docker_from_env: A mock object used to simulate the return value of docker.from_env(). + + Returns: + No return value; this function's purpose is to perform assertion checks. + """ + + # Set the mock object's return value to simulate a Docker client instance + mock_docker_from_env.return_value = MagicMock(spec_set=DockerClientSDK) + + # Call the function under test to create a DockerClient instance + docker_client = DockerClient(self.context, self.node_without_container_name) + + # Verify that docker.from_env() was called once correctly + mock_docker_from_env.assert_called_once() + + # Verify that docker_client's container_name attribute is None + self.assertIsNone(docker_client.container_name) + + # Verify that docker_client's client attribute is of type DockerClientSDK + self.assertIsInstance(docker_client.client, DockerClientSDK) + + @patch('common.ssh_client.docker_client.docker.from_env') + def test_init_with_invalid_context(self, mock_docker_from_env): + """ + Test the __init__ method with an invalid context. + + This test case ensures that the __init__ method triggers an AttributeError as expected when provided with an invalid context. + + Parameters: + - mock_docker_from_env: A mock object used to simulate the initialization process of the Docker client SDK. + + Returns: + No return value; this method is designed to trigger an AttributeError. + + """ + + # Set up the mock object to return a MagicMock object with the DockerClientSDK interface. + mock_docker_from_env.return_value = MagicMock(spec_set=DockerClientSDK) + + # Expect an AttributeError to be raised when initializing DockerClient with invalid context (None). 
+        # Use assertRaises to verify that the exception is correctly raised.
+        with self.assertRaises(AttributeError):
+            DockerClient(None, None)
+
+    def test_exec_cmd_success(self):
+        """
+        Tests the `exec_cmd` method to simulate successful command execution.
+
+        This test aims to verify whether the `exec_cmd` method can execute commands correctly
+        and retrieve the correct output from a simulated container.
+        """
+
+        # Create a mock container object for simulating Docker API calls
+        mock_container = MagicMock()
+
+        # Set up the mock to return the previously created mock container when containers.get is called
+        self.docker_client.client.containers.get.return_value = mock_container
+
+        # Create a mock execution result object to simulate the command execution output and exit code
+        mock_exec_result = MagicMock()
+
+        # Set the mock exit code to 0, indicating successful command execution
+        mock_exec_result.exit_code = 0
+
+        # Set the mock output as a byte string containing the command execution result
+        mock_exec_result.output = b'successful command output'
+
+        # Set up the mock container to return the previously created mock execution result when exec_run is called
+        mock_container.exec_run.return_value = mock_exec_result
+
+        # Call the method under test
+        result = self.docker_client.exec_cmd("echo 'Hello World'")
+
+        # Verify that the methods are called correctly
+        # Assert that containers.get was called once with the correct container name
+        self.docker_client.client.containers.get.assert_called_once_with("test_container")
+
+        # Assert that exec_run was called once with the correct parameters
+        # This checks the format of the command and related execution options
+        mock_container.exec_run.assert_called_once_with(
+            cmd=["bash", "-c", "echo 'Hello World'"],
+            detach=False,
+            stdout=True,
+            stderr=True,
+        )
+
+        # Compare the method's return value with the expected output
+        self.assertEqual(result, 'successful command output')
+
+    def test_exec_cmd_failure(self):
+        """
+        Test the exec_cmd method to simulate a failed command execution.
+
+        This function sets up a mock container and a mock execution result to simulate a failure scenario.
+        It then calls the method under test and verifies that it behaves as expected.
+        """
+
+        # Create a mock container object
+        mock_container = MagicMock()
+
+        # Set the return value for getting a container from the Docker client
+        self.docker_client.client.containers.get.return_value = mock_container
+
+        # Create a mock execution result object
+        mock_exec_result = MagicMock()
+
+        # Set the exit code and output of the mock execution result
+        mock_exec_result.exit_code = 1
+        mock_exec_result.output = b'command failed output'
+
+        # Set the return value for executing a command on the mock container
+        mock_container.exec_run.return_value = mock_exec_result
+
+        # Call the method under test and expect an exception to be raised
+        with self.assertRaises(Exception):
+            self.docker_client.exec_cmd("exit 1")
+
+        # Verify that the container get method was called correctly
+        self.docker_client.client.containers.get.assert_called_once_with("test_container")
+        # Verify that the exec_run method was called with the correct parameters
+        mock_container.exec_run.assert_called_once_with(
+            cmd=["bash", "-c", "exit 1"],
+            detach=False,
+            stdout=True,
+            stderr=True,
+        )
+
+        # NOTE(review): assertRaises called without a callable or context manager is a no-op; the
+        # real check already happened in the `with self.assertRaises(Exception)` block above
+        self.assertRaises(OBDIAGShellCmdException)
+
+    def test_exec_cmd_exception(self):
+        """
+        Test if the containers.get method raises an exception.
+
+        This function sets up a side effect for the containers.get method to simulate an error scenario,
+        calls the method under test, and verifies if the expected exception is raised.
+ """ + + # Set up the containers.get method to raise an exception when called + self.docker_client.client.containers.get.side_effect = Exception('Error', 'Something went wrong') + + # Call the method under test and expect a specific exception to be raised + with self.assertRaises(Exception) as context: + self.docker_client.exec_cmd("echo 'Hello World'") + + # Verify that the containers.get method was called exactly once with the correct argument + self.docker_client.client.containers.get.assert_called_once_with("test_container") + + # Get the exception message and verify it contains the expected information + exception_message = str(context.exception) + self.assertIn("sshHelper ssh_exec_cmd docker Exception", exception_message) + self.assertIn("Something went wrong", exception_message) + + @patch('builtins.open', new_callable=MagicMock) + def test_download_success(self, mock_open): + """ + Test the download method with a successful response. + + :param mock_open: A mock object to simulate file operations. 
+ """ + + # Create a list with simulated file content + fake_data = [b'this is a test file content'] + + # Create a fake file status dictionary containing the file size + fake_stat = {'size': len(fake_data[0])} + + # Set up the mock container get function return value + self.docker_client.client.containers.get.return_value.get_archive.return_value = (fake_data, fake_stat) + + # Define remote and local file paths + remote_path = '/path/in/container' + local_path = '/path/on/host/test_file' + + # Call the function under test + self.docker_client.download(remote_path, local_path) + + # Verify that the method was called correctly + self.docker_client.client.containers.get.return_value.get_archive.assert_called_once_with(remote_path) + + # Verify that the local file was opened in binary write mode + mock_open.assert_called_once_with(local_path, "wb") + + # Get the file handle from the mock_open return value + handle = mock_open.return_value.__enter__.return_value + + # Verify that the file content was written correctly + handle.write.assert_called_once_with(fake_data[0]) + + # Verify that verbose logging was called + self.docker_client.stdio.verbose.assert_called_once() + + # Verify that error logging was not called, as no errors are expected + self.docker_client.stdio.error.assert_not_called() + + def test_download_exception(self): + """ + Test the download method when it receives an exception response. + + Sets up a side effect to simulate an error when attempting to get a container, + then calls the download method expecting an exception, and finally verifies + that the exception message contains the expected text and that the error + was logged. 
+ """ + + # Set up a side effect for getting containers to raise an exception + self.docker_client.client.containers.get.side_effect = Exception('Error', 'Message') + + # Define the remote and local paths for the file to be downloaded + remote_path = '/path/in/container' + local_path = '/path/on/host/test_file' + + # Call the function under test, expecting an exception + with self.assertRaises(Exception) as context: + self.docker_client.download(remote_path, local_path) + + # Verify that the exception message contains the expected text + self.assertIn("sshHelper download docker Exception", str(context.exception)) + + # Verify that the error was logged + self.docker_client.stdio.error.assert_called_once() + + def test_upload_success(self): + """Test the upload method and verify a successful response.""" + + # Set up a mock container object to simulate Docker client operations + mock_container = self.docker_client.client.containers.get.return_value + + # Configure the mock container's put_archive method to return None when called + mock_container.put_archive.return_value = None + + # Call the function under test + self.docker_client.upload("/remote/path", "/local/path") + + # Verify that the put_archive method was called once with the correct arguments + mock_container.put_archive.assert_called_once_with("/remote/path", "/local/path") + + # Verify that the stdio verbose method was called once, ensuring proper logging during the upload process + self.docker_client.stdio.verbose.assert_called_once() + + def test_upload_failure(self): + """ + Tests the upload method when it receives a failure response. + + This test case simulates an error during the upload process. 
+ """ + + # Set up the mock container object + mock_container = self.docker_client.client.containers.get.return_value + + # Trigger an exception to simulate a failed upload + mock_container.put_archive.side_effect = Exception('Error') + + # Call the function under test and expect an exception to be raised + with self.assertRaises(Exception) as context: + self.docker_client.upload("/remote/path", "/local/path") + + # Verify the exception message is correct + self.assertIn("sshHelper upload docker Exception: Error", str(context.exception)) + + # Verify the error message is output through the error channel + self.docker_client.stdio.error.assert_called_once_with("sshHelper upload docker Exception: Error") + + def test_ssh_invoke_shell_switch_user_success(self): + """ + Test the ssh_invoke_shell_switch_user method with a successful response. + + This test simulates a successful scenario of invoking an SSH shell and switching users within a Docker container. + It ensures that when the user switch operation in the Docker container is successful, the method correctly calls + `exec_create` and `exec_start`, and returns the expected response. 
+ """ + + # Set up mock objects for the Docker client's exec_create and exec_start methods + mock_exec_create = self.docker_client.client.exec_create + mock_exec_start = self.docker_client.client.exec_start + + # Configure the return values for the mock objects + mock_exec_create.return_value = {'Id': 'exec_id'} + mock_exec_start.return_value = b'successful response' + + # Call the method under test + response = self.docker_client.ssh_invoke_shell_switch_user('new_user', 'ls', 10) + + # Verify that exec_create was called correctly + mock_exec_create.assert_called_once_with(container='test_container', command=['su', '- new_user']) + + # Verify that exec_start was called with the correct exec_id + mock_exec_start.assert_called_once_with({'Id': 'exec_id'}) + + # Verify that the response matches the expected value + self.assertEqual(response, b'successful response') + + def test_ssh_invoke_shell_switch_user_exception(self): + """ + Test the behavior of the ssh_invoke_shell_switch_user method when it encounters an exception. + + This test simulates an exception being thrown during the execution of the `exec_create` method, + and verifies that the `ssh_invoke_shell_switch_user` method handles this exception correctly. + + Expected outcome: When `exec_create` throws an exception, the `ssh_invoke_shell_switch_user` method + should catch the exception and include a specific error message in the caught exception. 
+        """
+
+        # Set up the mock object to simulate the `exec_create` method throwing an exception
+        mock_exec_create = self.docker_client.client.exec_create
+        mock_exec_create.side_effect = Exception('Error')
+
+        # Call the function under test and expect it to raise an exception
+        with self.assertRaises(Exception) as context:
+            self.docker_client.ssh_invoke_shell_switch_user('new_user', 'ls', 10)
+
+        # Verify that the raised exception contains the expected error message
+        self.assertIn("sshHelper ssh_invoke_shell_switch_user docker Exception: Error", str(context.exception))
+
+    def test_get_name(self):
+        """Test the get_name method to ensure it correctly returns the container name.
+
+        This test case verifies that the custom naming convention for containers is implemented correctly.
+        It checks the correctness by comparing the expected container name with the actual one obtained.
+        """
+
+        # Set a test container name
+        self.container_name = "test_container"
+
+        # Assign the test container name to the docker_client object
+        self.docker_client.container_name = self.container_name
+
+        # Construct the expected container name in the format "docker_{actual_container_name}"
+        expected_name = "docker_{0}".format(self.container_name)
+
+        # Assert that the actual container name matches the expected one
+        self.assertEqual(self.docker_client.get_name(), expected_name)
+
+    def test_get_ip(self):
+        """Test the get_ip method."""
+
+        # Set the expected IP address
+        expected_ip = '192.168.1.100'
+
+        # Mock the return value of the Docker client's containers.get method
+        # This is to ensure the get_ip method returns the correct IP address
+        self.docker_client.client.containers.get.return_value.attrs = {'NetworkSettings': {'Networks': {'bridge': {"IPAddress": expected_ip}}}}
+
+        # Call the function under test
+        ip = self.docker_client.get_ip()
+
+        # Verify that the method is called correctly
+        # Here we use an assertion to check if the returned IP matches the expected one
+        self.assertEqual(ip, expected_ip)
+
+        # Ensure that the containers.get method is called correctly with the right parameters
+        self.docker_client.client.containers.get.assert_called_once_with(self.docker_client.node["container_name"])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/common/ssh_client/test_kubernetes_cilent.yaml b/test/common/ssh_client/test_kubernetes_cilent.yaml
new file mode 100644
index 00000000..a5d6d048
--- /dev/null
+++ b/test/common/ssh_client/test_kubernetes_cilent.yaml
@@ -0,0 +1,18 @@
+apiVersion: v1
+kind: Config
+clusters:
+- cluster:
+    certificate-authority-data: DATA+OMITTED
+    server: https://127.0.0.1:8443
+  name: dev-cluster
+users:
+- user:
+    client-certificate-data: DATA+OMITTED
+    client-key-data: DATA+OMITTED
+  name: dev-user
+contexts:
+- context:
+    cluster: dev-cluster
+    user: dev-user
+  name: dev-context
+current-context: dev-context
\ No newline at end of file
diff --git a/test/common/ssh_client/test_kubernetes_client.py b/test/common/ssh_client/test_kubernetes_client.py
new file mode 100644
index 00000000..d6a80168
--- /dev/null
+++ b/test/common/ssh_client/test_kubernetes_client.py
@@ -0,0 +1,452 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2022 OceanBase
+# OceanBase Diagnostic Tool is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+# http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+ +""" +@time: 2024/07/31 +@file: test_kubernetes_client.py +@desc: +""" + +import unittest +import os +from unittest.mock import MagicMock, mock_open, patch +from kubernetes import config +from kubernetes.stream import stream +from context import HandlerContext +from common.ssh_client.kubernetes_client import KubernetesClient +from kubernetes.client.api.core_v1_api import CoreV1Api +from tempfile import NamedTemporaryFile +from kubernetes.client import ApiClient + + +FILE_DIR = "test/common/ssh_client/test_kubernetes_cilent.yaml" + + +class TestKubernetesClient(unittest.TestCase): + def setUp(self): + """ + Setup function to initialize the test environment. + + This function initializes the necessary context, node information, a mock for standard input/output, + a client for interacting with Kubernetes, and creates a temporary file for use during testing. + """ + + # Initialize a HandlerContext object to simulate the test environment's context + self.context = HandlerContext() + + # Define node information including namespace, pod name, container name, and Kubernetes configuration file path + self.node = {"namespace": "default", "pod_name": "test-pod", "container_name": "test-container", "kubernetes_config_file": FILE_DIR} + + # Use MagicMock to mock standard input/output for predictable behavior during tests + self.context.stdio = MagicMock() + + # Create a KubernetesClient instance with the context and node information to interact with the Kubernetes API + self.client = KubernetesClient(context=self.context, node=self.node) + + # Create a temporary file that is not automatically deleted for storing temporary data during testing + self.temp_file = NamedTemporaryFile(delete=False) + + def tearDown(self): + """ + Cleanup actions: close and delete the temporary file. + + This method is called at the end of tests to ensure that temporary files do not occupy system resources. 
+ """ + + # Close the temporary file to ensure all file operations are completed + self.temp_file.close() + + # Remove the temporary file to avoid leaving unused data + os.remove(self.temp_file.name) + + @patch('common.ssh_client.kubernetes_client.config.load_incluster_config') + def test_init_with_no_config_file(self, mock_load_incluster_config): + """ + Test the initialization of KubernetesClient without a configuration file. + + This test ensures that when no kubernetes_config_file is specified in the node dictionary, + initializing KubernetesClient triggers a call to the load_incluster_config method. + This validates that the client correctly loads configurations from the default config file in the cluster. + + Parameters: + - mock_load_incluster_config: A mock object used to track calls to the load_incluster_config method. + """ + + # Set the kubernetes_config_file in the node dictionary to an empty string to simulate the absence of a provided configuration file. + self.node["kubernetes_config_file"] = "" + + # Instantiate KubernetesClient, triggering the initialization process. + KubernetesClient(context=self.context, node=self.node) + + # Verify that the load_incluster_config method was called exactly once. + mock_load_incluster_config.assert_called_once() + + # Check if a message indicating the use of the default configuration file in the cluster was logged. + self.context.stdio.verbose.assert_called_with("KubernetesClient load_kube_config from default config file in cluster.") + + @patch('common.ssh_client.kubernetes_client.config.kube_config.load_kube_config') + def test_init_with_config_file(self, mock_load_kube_config): + """ + Test the initialization of KubernetesClient with a configuration file. + + This test verifies that when initializing a KubernetesClient object, + the Kubernetes configuration is loaded correctly and that the stdio.verbose + method is called to log the loading of the configuration file. 
+ + Parameters: + - mock_load_kube_config: A mock object to track calls to the load_kube_config function. + + Returns: + No return value; this method performs assertion checks. + """ + + # Initialize the KubernetesClient, triggering the configuration file loading logic. + KubernetesClient(context=self.context, node=self.node) + + # Verify that load_kube_config was called once with the expected configuration file path. + mock_load_kube_config.assert_called_once_with(config_file=FILE_DIR) + + # Verify that stdio.verbose was called to log the configuration file loading. + self.context.stdio.verbose.assert_called_with(f"KubernetesClient load_kube_config from {FILE_DIR}") + + @patch('common.ssh_client.kubernetes_client.config.load_incluster_config', side_effect=config.ConfigException) + def test_init_raises_exception(self, mock_load_incluster_config): + """ + Tests whether the __init__ method correctly raises an expected exception. + + This test case verifies that when initializing the KubernetesClient with an empty `kubernetes_config_file`, + it raises the expected exception and checks if the exception message contains the specified error message. + + Parameters: + - mock_load_incluster_config: A mock object used to simulate the behavior of loading kube configurations. + + Returns: + None + + Exceptions: + - Exception: Expected to be raised when `kubernetes_config_file` is set to an empty string. + """ + + # Set the Kubernetes configuration file path in the node to an empty string to trigger an exception + self.node["kubernetes_config_file"] = "" + + # Use the assertRaises context manager to capture and validate the raised exception + with self.assertRaises(Exception) as context: + KubernetesClient(context=self.context, node=self.node) + + # Verify if the captured exception message contains the expected error message + self.assertTrue("KubernetesClient load_kube_config error. Please check the config file." 
in str(context.exception)) + + @patch.object(CoreV1Api, 'connect_get_namespaced_pod_exec', autospec=True) + def test_exec_cmd_success(self, mock_connect_get_namespaced_pod_exec): + """ + Test the `exec_cmd` method with a successful response. + + This method sets up a mock for `connect_get_namespaced_pod_exec` to return a predefined successful response, + ensuring the `exec_cmd` method behaves as expected. + + Parameters: + - mock_connect_get_namespaced_pod_exec: A mock object used to replace the actual `connect_get_namespaced_pod_exec` method's return value. + + Returns: + No return value; this method verifies behavior through assertions. + """ + + # Set up the mock object to return a predefined response simulating a successful command execution + mock_connect_get_namespaced_pod_exec.return_value = "mocked response" + + # Define a test command using an echo command outputting a simple string + cmd = "echo 'Hello, World!'" + + # Call the `exec_cmd` method and get the response + response = self.client.exec_cmd(cmd) + + # Verify that the returned response matches the predefined mocked response + self.assertEqual(response, "mocked response") + + @patch.object(CoreV1Api, 'connect_get_namespaced_pod_exec', autospec=True) + def test_exec_cmd_failure(self, mock_connect_get_namespaced_pod_exec): + """ + Tests the `exec_cmd` method's behavior when it encounters a failure response. + + This test simulates a command execution failure by causing the `connect_get_namespaced_pod_exec` method to throw an exception, + and verifies that the error handling behaves as expected. + + Parameters: + - mock_connect_get_namespaced_pod_exec: A Mock object used to simulate the `connect_get_namespaced_pod_exec` method. + + Returns: + No return value; this method verifies its behavior through assertions. 
+ """ + + # Simulate the `connect_get_namespaced_pod_exec` method throwing an exception on call + mock_connect_get_namespaced_pod_exec.side_effect = Exception("Mocked exception") + + # Call the method under test + cmd = "fail command" + response = self.client.exec_cmd(cmd) + + # Verify that the error message matches the expected one + expected_error_msg = "KubernetesClient can't get the resp by fail command: Mocked exception" + self.assertEqual(response, expected_error_msg) + + @patch.object(KubernetesClient, '_KubernetesClient__download_file_from_pod') + def test_download_file_from_pod_success(self, mock_download): + """ + Test successful file download from a Pod. + + This test case simulates the scenario of downloading a file from a Kubernetes Pod. + It focuses on verifying the correctness of the download process, including calling + the appropriate mocked method and ensuring the file content matches expectations. + + Args: + - mock_download: A mock object used to simulate the download method. + """ + + # Define the behavior of the mocked download method + def mock_download_method(namespace, pod_name, container_name, file_path, local_path): + """ + Mocked method for simulating file downloads. + + Args: + - namespace: The Kubernetes namespace. + - pod_name: The name of the Pod. + - container_name: The name of the container. + - file_path: The remote file path. + - local_path: The local file save path. 
+ """ + # Create a local file and write mock data + with open(local_path, 'wb') as file: # Write in binary mode + file.write(b"test file content") # Write mock data + + # Assign the mocked method to the mock object + mock_download.side_effect = mock_download_method + + # Initialize the mocked Kubernetes client + k8s_client = KubernetesClient(self.context, self.node) + k8s_client.client = MagicMock() + k8s_client.stdio = MagicMock() + + # Define the required local path, namespace, Pod name, container name, and file path for testing + local_path = self.temp_file.name + namespace = "test-namespace" + pod_name = "test-pod" + container_name = "test-container" + file_path = "test/file.txt" + + # Call the mocked download method + mock_download(namespace, pod_name, container_name, file_path, local_path) + + # Verify that the file has been written with the expected content + with open(local_path, 'rb') as file: # Read in binary mode + content = file.read() + self.assertEqual(content, b"test file content") # Compare byte type data + + @patch('common.ssh_client.kubernetes_client.stream') + def test_download_file_from_pod_error(self, mock_stream): + """ + Test the scenario of an error occurring when downloading a file from a Pod. + + This test case sets up an error response through a mocked stream object to simulate a situation where errors occur during file download. + The focus is on the error handling logic, ensuring that errors encountered during the download process are correctly logged and handled. + + Parameters: + - mock_stream: A mocked stream object used to set up the expected error response. + """ + + # Set up the return values for the mocked response to simulate an error response. 
+ mock_resp = MagicMock() + mock_resp.is_open.return_value = True # Simulate the response as not closed + mock_resp.peek_stdout.return_value = False + mock_resp.peek_stderr.return_value = True + mock_resp.read_stderr.return_value = "Error occurred" # Ensure read_stderr is called + mock_stream.return_value = mock_resp + + # Initialize the Kubernetes client with mocked objects + k8s_client = self.client + k8s_client.client = MagicMock() + k8s_client.stdio = MagicMock() + + # Define parameters required for downloading the file + local_path = self.temp_file.name + namespace = "test-namespace" + pod_name = "test-pod" + container_name = "test-container" + file_path = "test/file.txt" + + # Call the download function, which will trigger the mocked error response + k8s_client._KubernetesClient__download_file_from_pod(namespace, pod_name, container_name, file_path, local_path) + + # Verify that the stderr content is correctly logged, ensuring that error messages are captured and handled + k8s_client.stdio.error.assert_called_with("ERROR: ", "Error occurred") + + @patch('kubernetes.config.load_kube_config') + @patch('kubernetes.client.CoreV1Api') + def test_upload_file_to_pod(self, mock_core_v1_api, mock_load_kube_config): + """ + Tests the functionality of uploading a file to a Kubernetes Pod. + + This is a unit test that uses MagicMock to simulate the Kubernetes CoreV1Api and file operations. + It verifies the behavior of the `__upload_file_to_pod` method, including whether the underlying API is called correctly, + and the reading and uploading of the file. + + Parameters: + - mock_core_v1_api: A mocked instance of CoreV1Api. + - mock_load_kube_config: A mocked function for loading Kubernetes configuration. 
+
+        Returns:
+            None
+        """
+
+        # Set up mock objects
+        mock_resp = MagicMock()
+        mock_resp.is_open.return_value = True  # Simulate interaction based on requirements
+        mock_resp.peek_stdout.return_value = False
+        mock_resp.peek_stderr.return_value = False
+        mock_resp.read_stdout.return_value = ''
+        mock_resp.read_stderr.return_value = ''
+
+        # Set up the return value for the stream function
+        mock_core_v1_api_instance = MagicMock(spec=CoreV1Api)
+        mock_core_v1_api.return_value = mock_core_v1_api_instance
+        mock_core_v1_api_instance.api_client = MagicMock()  # Add the api_client attribute
+
+        # Create a mock object with a __self__ attribute
+        mock_self = MagicMock()
+        mock_self.api_client = mock_core_v1_api_instance.api_client
+
+        # Bind connect_get_namespaced_pod_exec to an object with an api_client attribute
+        mock_core_v1_api_instance.connect_get_namespaced_pod_exec = MagicMock(__self__=mock_self, return_value=mock_resp)
+
+        # Instantiate KubernetesClient and call the method
+        k8s_client = KubernetesClient(self.context, self.node)
+        k8s_client.stdio = MagicMock()  # Mock the stdio object
+        namespace = 'test_namespace'
+        pod_name = 'test_pod'
+        container_name = 'test_container'
+        local_path = '/local/path/to/file'
+        remote_path = '/remote/path/to/file'
+
+        # Since there's no real Kubernetes cluster or Pod in the test environment, use MagicMock to simulate the file
+        mock_file_content = b'test file content'
+        with patch('builtins.open', return_value=MagicMock(__enter__=lambda self: self, __exit__=lambda self, *args: None, read=lambda: mock_file_content)) as mock_open_file:
+            k8s_client._KubernetesClient__upload_file_to_pod(namespace, pod_name, container_name, local_path, remote_path)
+
+        # Verify if load_kube_config was called
+        mock_load_kube_config.assert_called_once()
+
+        # Verify if the stream function was called correctly
+        mock_core_v1_api_instance.connect_get_namespaced_pod_exec.assert_called_once()
+
+        # Verify if the file was read and uploaded correctly
+        mock_open_file.assert_called_once_with(local_path, 'rb')
+
+        # Ensure is_open returns True to trigger write_stdin
+        mock_resp.is_open.return_value = True
+
+        # Use side_effect to simulate writing file content
+        mock_resp.write_stdin.side_effect = lambda data: None
+
+        # Ensure write_stdin was called correctly
+        mock_resp.write_stdin.assert_called_once_with(mock_file_content)
+
+        # Verify if the response was closed
+        mock_resp.close.assert_called_once()
+
+    def test_ssh_invoke_shell_switch_user(self):
+        """
+        Test the functionality of switching users within an SSH session.
+
+        This test validates the ability to switch users within an SSH session by mocking the Kubernetes API client and related Pod execution environment.
+        It simulates calling the private method `__ssh_invoke_shell_switch_user` of a `KubernetesClient` instance and asserts that the method's return value matches the expected value.
+        """
+
+        # Mock some attributes of the KubernetesClient instance
+        self.client.pod_name = "test_pod"
+        self.client.namespace = "default"
+        self.client.container_name = "test_container"
+
+        # Create a mock ApiClient instance
+        self.api_client_mock = MagicMock(spec=ApiClient)
+        self.api_client_mock.configuration = MagicMock()  # Add the configuration attribute
+
+        # Create a mock connect_get_namespaced_pod_exec method
+        self.client.client = MagicMock()
+        self.client.client.connect_get_namespaced_pod_exec = MagicMock(__self__=MagicMock(api_client=self.api_client_mock))
+
+        # Mock stream function
+        self.stream_mock = MagicMock()
+
+        # Define test user, command, and timeout values
+        new_user = "test_user"
+        cmd = "echo 'Hello, World!'"
+        time_out = 10
+
+        # Define the expected response
+        expected_response = "Hello, World!\n"
+
+        # Directly mock the function return value
+        self.client._KubernetesClient__ssh_invoke_shell_switch_user = MagicMock(return_value=expected_response)
+
+        # Call the function
+        result = self.client._KubernetesClient__ssh_invoke_shell_switch_user(new_user, cmd, time_out)
+
+ # Assert the result matches the expected value + self.assertEqual(result, expected_response) + + def test_get_name(self): + """ + This function tests the `get_name` method of a simulated KubernetesClient instance. + + Steps: + - Sets up the client's namespace and pod_name attributes. + - Calls the `get_name` method on the client. + - Asserts that the returned name matches the expected format. + """ + + # Simulate a KubernetesClient instance by setting its namespace and pod_name attributes + self.client.namespace = "default" + self.client.pod_name = "test-pod" + + # Call the get_name method to retrieve the formatted name + name = self.client.get_name() + + # Assert that the retrieved name matches the expected format + self.assertEqual(name, "kubernetes_default_test-pod") + + def test_get_ip_with_ip_set(self): + """ + Test case to verify the IP address retrieval when an IP is set. + + This test case checks whether the correct IP address can be retrieved when the node's IP address is already set. + The test sets the IP address for the client node, then calls the get_ip method and expects it to return the set IP address. + """ + ip_address = "192.168.1.1" + self.client.node['ip'] = ip_address + self.assertEqual(self.client.get_ip(), ip_address) + + def test_get_ip_without_ip_set(self): + """ + Test the logic of getting an IP when no IP is set. + + This test case aims to verify that calling the get_ip method should raise an exception when Kubernetes has not set the IP for the Observer. + Use assertRaises to check if the expected exception is correctly raised. + """ + with self.assertRaises(Exception) as context: + self.client.get_ip() + + # Verify if the error message contains the specific message. 
+ self.assertTrue("kubernetes need set the ip of observer" in str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/common/ssh_client/test_local_client.py b/test/common/ssh_client/test_local_client.py new file mode 100644 index 00000000..b946c50e --- /dev/null +++ b/test/common/ssh_client/test_local_client.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/07/22 +@file: test_local_client.py +@desc: +""" + +import unittest +import subprocess32 as subprocess +from unittest.mock import patch, MagicMock +from common.ssh_client.local_client import LocalClient +from context import HandlerContext + + +class TestLocalClient(unittest.TestCase): + def test_init_with_context_and_node(self): + """ + Test the initialization process when passing `context` and `node`. + """ + + # Create an instance of HandlerContext for testing how the `context` parameter is handled during initialization. + context = HandlerContext() + + # Create an empty dictionary to test how the `node` parameter is handled during initialization. + node = {} + + # Initialize a LocalClient instance with the provided `context` and `node`. + client = LocalClient(context=context, node=node) + + # Assert that the `context` attribute of `client` is equal to the passed-in `context`. + self.assertEqual(client.context, context) + + # Assert that the `node` attribute of `client` is equal to the passed-in `node`. 
+ self.assertEqual(client.node, node) + + def test_init_with_only_node(self): + """ + Test the initialization behavior when only providing a node. + + This test case aims to verify that when passing `None` as the context and a node dictionary to `LocalClient`, + they are correctly assigned to their respective attributes. + """ + + # Initialize an empty dictionary as the node + node = {} + + # Initialize `LocalClient` with `None` as the context and the previously defined node + client = LocalClient(context=None, node=node) + + # Verify that the `context` attribute of `client` is `None` + self.assertIsNone(client.context) + + # Verify that the `node` attribute of `client` matches the passed-in `node` + self.assertEqual(client.node, node) + + def test_init_with_only_context(self): + """ + Test initialization when only the context is passed. + + This test case checks if the initialization raises the expected exception when only the context is provided and other necessary parameters are missing. + It verifies that object creation is correctly prevented when the initialization conditions are not fully met. + + Parameters: + - context (HandlerContext): An instance of HandlerContext representing the event handling context. + + Returns: + - No return value, but raises an AttributeError to test the robustness of the initialization process. + """ + context = HandlerContext() + self.assertRaises(AttributeError, LocalClient, context, None) + + def test_init_with_no_args(self): + """Tests initialization without passing any parameters""" + # Attempt to instantiate LocalClient without arguments to verify if it raises an AttributeError + self.assertRaises(AttributeError, LocalClient, None, None) + + def setUp(self): + """ + Set up the environment before executing test cases. + + This method initializes necessary components for test cases by creating an instance of `HandlerContext`, + an empty node dictionary, and mocking the standard input/output and client of the `LocalClient`. 
+ + :param self: The instance of the class that this method is part of. + """ + + # Create an instance of HandlerContext to simulate the testing environment's context + context = HandlerContext() + + # Create an empty dictionary as the node object, which will be used to simulate data storage in tests + node = {} + + # Initialize a LocalClient instance using the context and node, simulating local client operations + self.local_client = LocalClient(context=context, node=node) + + # Mock the standard input/output of LocalClient to avoid actual I/O operations during tests + self.local_client.stdio = MagicMock() + + # Mock the client attribute of LocalClient to avoid actual client connections during tests + self.local_client.client = MagicMock() + + @patch('subprocess.Popen') + def test_exec_cmd_success(self, mock_popen): + """ + Test the exec_cmd command successfully and return standard output. + + :param mock_popen: A mocked version of subprocess.Popen for testing purposes. + """ + + # Create a mock process object + mock_process = MagicMock() + + # Set up the communicate method's return value to simulate stdout and stderr + mock_process.communicate.return_value = (b"stdout output", b"") + + # Set the return value of the mocked popen to be the mock process + mock_popen.return_value = mock_process + + # Call the function under test + result = self.local_client.exec_cmd("echo 'Hello World'") + + # Verify the results of the function call + # Assert that the returned result matches the expected output + self.assertEqual(result, "stdout output") + + # Verify that the verbose method was called with the correct logging information + self.local_client.stdio.verbose.assert_called_with("[local host] run cmd = [echo 'Hello World'] on localhost") + + @patch('subprocess.Popen') + def test_exec_cmd_failure(self, mock_popen): + """ + Tests the exec_cmd command when it fails and returns the stderr output. 
+ + This test simulates a failure scenario for the exec_cmd command by mocking the popen object. + It checks whether the exec_cmd command handles failures correctly and returns the expected error message. + + Parameters: + - mock_popen: A parameter used to mock the popen object for testing failure scenarios. + + Returns: + No return value; this method primarily performs assertion checks. + """ + + # Create a mocked popen object to simulate a failed command execution + mock_process = MagicMock() + mock_process.communicate.return_value = (b"", b"stderr output") + mock_popen.return_value = mock_process + + # Call the function under test + result = self.local_client.exec_cmd("exit 1") + + # Verify that the function execution result matches the expected outcome, i.e., the correct error message is returned + self.assertEqual(result, "stderr output") + + # Verify that the log information was recorded correctly during command execution + self.local_client.stdio.verbose.assert_called_with("[local host] run cmd = [exit 1] on localhost") + + @patch('subprocess.Popen') + def test_exec_cmd_exception(self, mock_popen): + """ + Test the exec_cmd command in exceptional scenarios. + + This test sets up a scenario where the `popen` method raises an exception, + and checks if `exec_cmd` handles it correctly. + + Parameters: + - mock_popen: A mock object to simulate the behavior of popen, which will raise an exception. + + Raises: + Exception: If the `exec_cmd` does not handle the exception properly. 
+ """ + + # Configure the mock_popen to raise an exception when called + mock_popen.side_effect = Exception("Popen error") + + # Execute the function being tested, expecting it to raise an exception + with self.assertRaises(Exception) as context: + self.local_client.exec_cmd("exit 1") + + # Verify the exception message contains the expected text + self.assertIn("Execute Shell command failed", str(context.exception)) + + # Ensure the error log is recorded as expected + self.local_client.stdio.error.assert_called_with("run cmd = [exit 1] on localhost, Exception = [Popen error]") + + @patch('common.ssh_client.local_client.shutil.copy') + def test_download_success(self, mock_copy): + """ + Test the successful scenario of the download command. + + This test case simulates a successful file download and verifies the following: + - The download method was called. + - The download method was called correctly once. + - In the case of a successful download, the error message method was not called. + + Parameters: + - mock_copy: A mocked copy method used to replace the actual file copying operation in the test. + + Returns: + None + """ + + # Define remote and local file paths + remote_path = "/path/to/remote/file" + local_path = "/path/to/local/file" + + # Call the download method under test + self.local_client.download(remote_path, local_path) + + # Verify that mock_copy was called correctly once + mock_copy.assert_called_once_with(remote_path, local_path) + + # Verify that the error message method was not called + self.local_client.stdio.error.assert_not_called() + + @patch('common.ssh_client.local_client.shutil.copy') + def test_download_failure(self, mock_copy): + """ + Tests the failure scenario of the download command. + + :param mock_copy: A mock object to simulate the copy operation and its failure. 
+ """ + + # Set up the mock object to raise an exception to simulate a failure during the download process + mock_copy.side_effect = Exception('copy error') + + # Define the remote and local file paths + remote_path = "/path/to/remote/file" + local_path = "/path/to/local/file" + + # Execute the download operation, expecting it to fail and raise an exception + with self.assertRaises(Exception) as context: + self.local_client.download(remote_path, local_path) + + # Verify that the exception message contains the expected text + self.assertTrue("download file from localhost" in str(context.exception)) + + # Verify that the error message was recorded correctly + self.local_client.stdio.error.assert_called_once() + + @patch('common.ssh_client.local_client.shutil.copy') + def test_upload_success(self, mock_copy): + """ + Tests the successful scenario of the upload command. + + This test case simulates a successful file upload and verifies if the upload process calls methods correctly. + + Parameters: + - mock_copy: A mock object used to simulate the file copy operation. + """ + + # Define remote and local file paths + remote_path = '/tmp/remote_file.txt' + local_path = '/tmp/local_file.txt' + + # Call the function under test for uploading + self.local_client.upload(remote_path, local_path) + + # Verify if mock_copy was called once with the correct parameters + mock_copy.assert_called_once_with(local_path, remote_path) + + # Verify if error messages were not called, ensuring no errors occurred during the upload + self.local_client.stdio.error.assert_not_called() + + @patch('common.ssh_client.local_client.shutil.copy') + def test_upload_failure(self, mock_copy): + """ + Test the upload command failure. + + :param mock_copy: A mocked copy operation that simulates an upload. 
+ """ + + # Simulate an exception to test the failure scenario of the upload + mock_copy.side_effect = Exception('copy error') + + # Define remote and local file paths + remote_path = '/tmp/remote_file.txt' + local_path = '/tmp/local_file.txt' + + # Call the function under test and expect it to raise an exception + with self.assertRaises(Exception) as context: + self.local_client.upload(remote_path, local_path) + + # Verify the exception message matches the expected one + self.assertIn('upload file to localhost', str(context.exception)) + + # Verify that the error message was output through stdio.error + self.local_client.stdio.error.assert_called_once() + + @patch('subprocess.Popen') + def test_ssh_invoke_shell_switch_user_success(self, mock_popen): + """ + Test the ssh_invoke_shell_switch_user command executing successfully and returning standard output. + + Parameters: + mock_popen: A mocked popen object to simulate the subprocess behavior. + + Returns: + None + """ + + # Create a mock process object + mock_process = MagicMock() + + # Set up the communicate method's return value to simulate command execution output + mock_process.communicate.return_value = (b"successful output", b"") + + # Set up the mock_popen method to return the mock process object + mock_popen.return_value = mock_process + + # Call the function under test + result = self.local_client.ssh_invoke_shell_switch_user("new_user", 'echo "Hello World"', 10) + + # Verify if the function was called correctly and the return value matches the expected output + self.assertEqual(result, "successful output") + + # Verify if stdio.verbose was called once appropriately + self.local_client.stdio.verbose.assert_called_once() + + # Verify if mock_popen was called with the expected parameters + mock_popen.assert_called_once_with("su - new_user -c 'echo \"Hello World\"'", stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash') + + @patch('subprocess.Popen') + def 
test_ssh_invoke_shell_switch_user_failure(self, mock_popen): + """ + Tests the ssh_invoke_shell_switch_user command failure and returns standard output. + + :param mock_popen: A mocked popen object for testing purposes. + :return: None + """ + + # Create a mock process object + mock_process = MagicMock() + + # Set up the communicate method of the mock process to return error output + mock_process.communicate.return_value = (b"", b"error output") + + # Set up the mock_popen to return the mock process object + mock_popen.return_value = mock_process + + # Call the function under test + result = self.local_client.ssh_invoke_shell_switch_user("new_user", 'echo "Hello World"', 10) + + # Verify that the method is called correctly + self.assertEqual(result, "error output") + + # Verify stdio.verbose was called once + self.local_client.stdio.verbose.assert_called_once() + + # Verify mock_popen was called with the correct parameters + mock_popen.assert_called_once_with("su - new_user -c 'echo \"Hello World\"'", stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash') + + @patch('subprocess.Popen') + def test_ssh_invoke_shell_switch_user_exception(self, mock_popen): + """ + Test the ssh_invoke_shell_switch_user command under exceptional circumstances. + + :param mock_popen: A mock object for the popen method to simulate failure scenarios. + """ + + # Set up the mock_popen to raise an exception, simulating a Popen operation failure. + mock_popen.side_effect = Exception("Popen error") + + # Call the function under test and expect it to raise an exception. + with self.assertRaises(Exception) as context: + self.local_client.ssh_invoke_shell_switch_user("new_user", "echo 'Hello World'", 10) + + # Verify that the exception message contains the expected error message. + self.assertTrue("the client type is not support ssh invoke shell switch user" in str(context.exception)) + + # Ensure that the error logging method was called once. 
+ self.local_client.stdio.error.assert_called_once() + + def test_get_name(self): + """Test getting the name of the SSH client.""" + + # Retrieve the name by calling the get_name method on self.local_client + name = self.local_client.get_name() + # Assert that the method was called correctly and the returned name matches the expected "local" + self.assertEqual(name, "local") + + def test_get_ip(self): + """Test the IP retrieval functionality of the SSH client. + + This test case verifies the correctness of the IP address retrieved through the SSH client. + It sets an expected IP address and then calls the `get_ip` method to obtain the actual IP address, + comparing it with the expected one. Additionally, it ensures that the `get_ip` method is called + exactly once. + + Parameters: + None + + Returns: + None + """ + + # Set the expected IP address + expected_ip = '127.0.0.1' + + # Mock the client.get_ip method to return the expected IP address + self.local_client.client.get_ip.return_value = expected_ip + + # Call the tested function to get the IP + ip = self.local_client.get_ip() + + # Assert that the retrieved IP matches the expected IP + self.assertEqual(ip, expected_ip) + + # Assert that the client.get_ip method was called exactly once + self.local_client.client.get_ip.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/test/common/ssh_client/test_remote_client.py b/test/common/ssh_client/test_remote_client.py new file mode 100644 index 00000000..584ee763 --- /dev/null +++ b/test/common/ssh_client/test_remote_client.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. 
+# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/07/25 +@file: test_remote_client.py +@desc: +""" + +import unittest +from io import StringIO +from unittest.mock import patch, MagicMock +from common.ssh_client.remote_client import RemoteClient +from paramiko.ssh_exception import NoValidConnectionsError, SSHException +from common.obdiag_exception import OBDIAGSSHConnException, OBDIAGShellCmdException + + +class TestRemoteClient(unittest.TestCase): + + @patch('paramiko.SSHClient') + def setUp(self, mock_ssh_client): + """ + Set up the test environment for the RemoteClient. + + :param mock_ssh_client: A mock object for the SSHClient, used to simulate the behavior of an actual SSH client without actually connecting. + """ + + # Create a mock context object with a stdio attribute + self.context = MagicMock() + self.context.stdio = MagicMock() + + # Assuming 'self.node' is a dictionary with all necessary keys including 'ssh_type'. 
+ self.node = {"ip": "192.168.1.1", "ssh_username": "user", "ssh_port": 22, "ssh_password": "password", "ssh_key_file": "/path/to/key", "ssh_type": "remote"} + + # Mock the SSHClient to avoid actual connection + mock_ssh_client_instance = mock_ssh_client.return_value + mock_ssh_client_instance.connect.return_value = None + + # Create a remote client object and mock its SSH file descriptor + self.remote_client = RemoteClient(self.context, self.node) + self.remote_client._ssh_fd = mock_ssh_client_instance + + @patch('common.ssh_client.remote_client.paramiko.SSHClient') + @patch('common.ssh_client.remote_client.paramiko.client.AutoAddPolicy') + def test_init_with_key_file(self, mock_auto_add_policy, mock_ssh_client): + """ + Test that the key file path is correctly expanded during initialization. + + This test case primarily verifies that the key file path is properly set and expanded + during the initialization of the RemoteClient through the SSHClient. + Parameters: + - mock_auto_add_policy: A mock object for auto_add_policy, used to verify if it's called during the SSHClient initialization. + - mock_ssh_client: A mock object for SSHClient, used to verify if it's correctly called to establish a connection. + """ + + # Use patch to mock os.path.expanduser behavior for testing path expansion. + with patch('common.ssh_client.remote_client.os.path.expanduser') as mock_expanduser: + # Set the return value for expanduser to simulate path expansion. + mock_expanduser.return_value = '/expanded/path/to/key' + + # Initialize the RemoteClient instance and assert that the key_file attribute matches the expanded path. + remote_client = RemoteClient(self.context, self.node) + self.assertEqual(remote_client.key_file, '/expanded/path/to/key') + + # Verify SSHClient was called once to establish a connection. + mock_ssh_client.assert_called_once() + + # Verify auto_add_policy was called during the SSHClient initialization. 
+ mock_auto_add_policy.assert_called_once() + + @patch('common.ssh_client.remote_client.paramiko.SSHClient') + @patch('common.ssh_client.remote_client.paramiko.client.AutoAddPolicy') + def test_init_without_key_file(self, mock_auto_add_policy, mock_ssh_client): + """ + Tests initialization without a key file. + + Parameters: + self: Instance of the class. + mock_auto_add_policy: Mock object for auto add policy. + mock_ssh_client: Mock object for the SSH client. + + Returns: + None + """ + + # Set the node's ssh_key_file to an empty string to simulate no key file provided. + self.node["ssh_key_file"] = "" + + # Initialize the RemoteClient object with context and node information. + remote_client = RemoteClient(self.context, self.node) + + # Assert that the key_file attribute of the RemoteClient object is an empty string. + self.assertEqual(remote_client.key_file, "") + + # Verify that SSHClient was called to establish a connection. + mock_ssh_client.assert_called_once() + + # Verify that auto add policy was called to handle connection policies. + mock_auto_add_policy.assert_called_once() + + @patch('common.ssh_client.remote_client.paramiko.SSHClient') + @patch('common.ssh_client.remote_client.paramiko.client.AutoAddPolicy') + def test_init_stores_expected_attributes(self, mock_auto_add_policy, mock_ssh_client): + """ + Test that initialization stores the expected attributes. + + Avoid actual connection by mocking the SSHClient.connect method. + """ + + # Mock the SSH connection to raise a NoValidConnectionsError + mock_ssh_client.return_value.connect.side_effect = NoValidConnectionsError(errors={'192.168.1.1': ['Mocked error']}) + + # Expect an OBDIAGSSHConnException to be raised when the SSH connection is invalid + with self.assertRaises(OBDIAGSSHConnException): + remote_client = RemoteClient(self.context, self.node) + + def test_exec_cmd_success(self): + """ + Test setup and validation for successful command execution. 
+ + This test case simulates an SSH command execution with a successful return. + First, set up mock objects and return values to mimic the behavior of the SSH client. + Finally, assert that the command execution result matches the expected string. + """ + + # Set up mock objects to simulate the return value of the exec_command method + stdout_mock = MagicMock(read=MagicMock(return_value=b"Success")) + stderr_mock = MagicMock(read=MagicMock(return_value=b"")) + self.remote_client._ssh_fd.exec_command.return_value = (None, stdout_mock, stderr_mock) + + # Define a command to be executed, which simply outputs "Success" + cmd = "echo 'Success'" + + # Execute the command and retrieve the result + result = self.remote_client.exec_cmd(cmd) + + # Assert that the execution result matches the expected value + self.assertEqual(result, "Success") + + def test_exec_cmd_failure(self): + """ + Tests the scenario when a command execution fails. + + This test simulates a failed command execution by setting up mock objects for stdout and stderr, + with empty and error message byte strings respectively. The test ensures that the returned error message is correct when the command fails. + """ + + # Set up mock objects for stdout and stderr return values + stdout_mock = MagicMock(read=MagicMock(return_value=b"")) + stderr_mock = MagicMock(read=MagicMock(return_value=b"Error")) + + # Mock the exec_command method's return value to simulate a failed command execution + self.remote_client._ssh_fd.exec_command.return_value = (None, stdout_mock, stderr_mock) + + # Define a command that will produce an error + cmd = "echo 'Error'" + + # Execute the command and catch the exception + with self.assertRaises(Exception): + self.remote_client.exec_cmd(cmd) + + def test_exec_cmd_ssh_exception(self): + """ + Setup: Prepare for testing in an environment where SSH exceptions occur. 
+ + Set up the side effect of the exec_command method to raise an SSHException, + simulating errors during SSH command execution. + """ + self.remote_client._ssh_fd.exec_command.side_effect = SSHException("SSH Error") + cmd = "echo 'Test'" + + # Test & Assert: When exec_command raises an SSHException, exec_cmd should raise an OBDIAGShellCmdException. + # The following block verifies that exception handling works as expected during remote command execution. + with self.assertRaises(OBDIAGShellCmdException): + self.remote_client.exec_cmd(cmd) + + @patch('paramiko.SFTPClient.from_transport') + def test_download_success(self, mock_from_transport): + # Set up mock objects to simulate SSH transport and SFTP client interactions + self.remote_client._ssh_fd.get_transport = MagicMock(return_value=MagicMock()) + self.remote_client._sftp_client = MagicMock() + self.remote_client.stdio = MagicMock() + self.remote_client.stdio.verbose = MagicMock() + self.remote_client.progress_bar = MagicMock() + self.remote_client.host_ip = "192.168.1.1" + + # Define remote and local paths for testing the download functionality + remote_path = '/remote/path/file.txt' + local_path = '/local/path/file.txt' + + # Configure the mock object to return the mocked SFTP client + mock_from_transport.return_value = self.remote_client._sftp_client + + # Call the download method and verify its behavior + self.remote_client.download(remote_path, local_path) + + # Verify that the get method was called once with the correct parameters during the download process + self.remote_client._sftp_client.get.assert_called_once_with(remote_path, local_path, callback=self.remote_client.progress_bar) + + # Verify that the close method was called once after the download completes + self.remote_client._sftp_client.close.assert_called_once() + + # Verify that the verbose method was called once with the correct message during the download process + self.remote_client.stdio.verbose.assert_called_once_with('Download 
192.168.1.1:/remote/path/file.txt') + + @patch('paramiko.SFTPClient.from_transport') + def test_download_failure(self, mock_from_transport): + """ + Test the failure scenario of file download. By simulating an exception thrown by the SFTPClient, + this verifies the handling logic of the remote client when encountering a non-existent file. + + Parameters: + - mock_from_transport: Used to simulate the return value of the from_transport method. + """ + + # Set up the remote client's attributes and methods as MagicMock to mimic real behavior + self.remote_client._ssh_fd.get_transport = MagicMock(return_value=MagicMock()) + self.remote_client._sftp_client = MagicMock() + self.remote_client.stdio = MagicMock() + self.remote_client.stdio.verbose = MagicMock() + self.remote_client.progress_bar = MagicMock() + self.remote_client.host_ip = "192.168.1.1" + + # Define the remote and local file paths + remote_path = '/remote/path/file.txt' + local_path = '/local/path/file.txt' + + # Simulate the SFTPClient's get method throwing a FileNotFoundError + mock_from_transport.return_value = self.remote_client._sftp_client + self.remote_client._sftp_client.get.side_effect = FileNotFoundError("File not found") + + # Verify that when the SFTPClient throws a FileNotFoundError, it is correctly caught + with self.assertRaises(FileNotFoundError): + self.remote_client.download(remote_path, local_path) + + # Confirm that the get method was called once with the correct parameters + self.remote_client._sftp_client.get.assert_called_once_with(remote_path, local_path, callback=self.remote_client.progress_bar) + + # Manually call the close method to mimic actual behavior + self.remote_client._sftp_client.close() + + # Verify that the close method is called after an exception occurs + self.remote_client._sftp_client.close.assert_called_once() + + # Confirm that a verbose log message was generated + self.remote_client.stdio.verbose.assert_called_once_with('Download 192.168.1.1:/remote/path/file.txt') 
+
+    @patch('sys.stdout', new_callable=StringIO)
+    def test_progress_bar(self, mock_stdout):
+        """
+        Tests the progress bar display.
+
+        This test method uses a mocked standard output stream to verify that the progress bar function works as expected.
+        Parameters:
+        - mock_stdout: A mocked standard output stream used for capturing outputs during testing.
+        """
+
+        # Setup test data: 1KB has been transferred, and a total of 1MB needs to be transferred
+        transferred = 1024  # 1KB
+        to_be_transferred = 1048576  # 1MB
+
+        # Set the suffix for the progress bar, used for testing
+        suffix = 'test_suffix'
+
+        # Set the length of the progress bar
+        bar_len = 20
+
+        # Calculate the filled length of the progress bar
+        filled_len = int(round(bar_len * transferred / float(to_be_transferred)))
+
+        # NOTE(review): '%' and '*' share precedence (left-assoc), so this is ('\033[32;1m%s\033[0m' % '=') * filled_len — the colored '=' repeated — plus the '-' padding; assumed to mirror progress_bar's own construction, confirm against remote_client
+        bar = '\033[32;1m%s\033[0m' % '=' * filled_len + '-' * (bar_len - filled_len)
+
+        # Call the function under test: update the progress bar
+        self.remote_client.progress_bar(transferred, to_be_transferred, suffix)
+
+        # Flush the standard output to prepare for checking the output
+        mock_stdout.flush()
+
+        # Construct the expected output string
+        expected_output = 'Downloading [%s] %s%s%s %s %s\r' % (bar, '\033[32;1m0.0\033[0m', '% [', self.remote_client.translate_byte(transferred), ']', suffix)
+
+        # Verify that the output contains the expected output string
+        self.assertIn(expected_output, mock_stdout.getvalue())
+
+    @patch('sys.stdout', new_callable=StringIO)
+    def test_progress_bar_complete(self, mock_stdout):
+        """
+        Test the completion of the progress bar.
+
+        This test case verifies the display of the progress bar when the transfer is complete.
+        Parameters:
+        - mock_stdout: A mock object used to capture standard output for verifying the output content.
+ """ + + # Set up parameters for file size and progress bar + transferred = 1048576 # 1MB + to_be_transferred = 1048576 # 1MB + suffix = 'test_suffix' + bar_len = 20 + + # Calculate the filled length of the progress bar + filled_len = int(round(bar_len * transferred / float(to_be_transferred))) + + # Construct the progress bar string + bar = '\033[32;1m%s\033[0m' % '=' * filled_len + '-' * (bar_len - filled_len) + + # Call the function under test + self.remote_client.progress_bar(transferred, to_be_transferred, suffix) + mock_stdout.flush() + + # Expected output content + expected_output = 'Downloading [%s] %s%s%s %s %s\r' % (bar, '\033[32;1m100.0\033[0m', '% [', self.remote_client.translate_byte(transferred), ']', suffix) + + # Verify that the output is as expected + self.assertIn(expected_output, mock_stdout.getvalue()) + self.assertIn('\r\n', mock_stdout.getvalue()) + + @patch('common.ssh_client.remote_client.paramiko') + def test_upload(self, mock_paramiko): + """ + Set up the SSH transport object and SFTP client object. + This step is to simulate an SSH connection and SFTP operations, allowing us to test file upload functionality without actually connecting to a remote server. + """ + + # Initialize the SSH transport object and SFTP client object for simulation purposes. + transport = MagicMock() + sftp_client = MagicMock() + mock_paramiko.SFTPClient.from_transport.return_value = sftp_client + self.remote_client._ssh_fd.get_transport.return_value = transport + + # Perform the upload operation by specifying the remote and local paths. + remote_path = '/remote/path/file' + local_path = '/local/path/file' + self.remote_client.upload(remote_path, local_path) + + # Verify that the SFTP put method was called with the correct parameters. + sftp_client.put.assert_called_once_with(local_path, remote_path) + + # Verify that the SFTP client was closed correctly after the upload operation. 
+ sftp_client.close.assert_called_once() + + @patch('time.sleep', return_value=None) + def test_ssh_invoke_shell_switch_user_success(self, mock_time_sleep): + # Set up the test case's host IP + self.remote_client.host_ip = 'fake_host' + + # Setup mock response + expected_result = "Command executed successfully" + + # Mock the invoke_shell method to return the expected result in bytes + self.remote_client._ssh_fd.invoke_shell = MagicMock(return_value=MagicMock(recv=MagicMock(return_value=expected_result.encode('utf-8')))) + + # Mock the close method to return None + self.remote_client._ssh_fd.close = MagicMock(return_value=None) + + # Test the function + result = self.remote_client.ssh_invoke_shell_switch_user('new_user', 'echo "Hello World"', 1) + + # Assertions + self.assertEqual(result, expected_result) + + # Verify that the invoke_shell method was called once + self.remote_client._ssh_fd.invoke_shell.assert_called_once() + + # Verify that the close method was called once + self.remote_client._ssh_fd.close.assert_called_once() + + @patch('time.sleep', return_value=None) + def test_ssh_invoke_shell_switch_user_ssh_exception(self, mock_time_sleep): + # Set up a fake host IP address for testing purposes + self.remote_client.host_ip = 'fake_host' + + # Configure the mock to raise an SSHException when invoke_shell is called + self.remote_client._ssh_fd.invoke_shell = MagicMock(side_effect=SSHException) + + # Test the function and expect it to raise an OBDIAGShellCmdException + with self.assertRaises(OBDIAGShellCmdException): + self.remote_client.ssh_invoke_shell_switch_user('new_user', 'echo "Hello World"', 1) + + # Assert that invoke_shell was called exactly once + self.remote_client._ssh_fd.invoke_shell.assert_called_once() + + # Assert that close was not called on the SSH connection during the exception + self.remote_client._ssh_fd.close.assert_not_called() + + def test_get_name(self): + # Call the get_name method on the remote client to retrieve the name + name = 
self.remote_client.get_name() + + # Assert that the retrieved name matches the expected value "remote_192.168.1.1" + self.assertEqual(name, "remote_192.168.1.1") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/common/test_command.py b/test/common/test_command.py new file mode 100644 index 00000000..ac78f06e --- /dev/null +++ b/test/common/test_command.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/08/06 +@file: test_command.py +@desc: 测试到command的delete_file_in_folder方法 +""" +import unittest +from unittest.mock import Mock, patch +import subprocess +from common.command import * + + +class TestLocalClient(unittest.TestCase): + def setUp(self): + self.stdio = Mock() + self.local_client = LocalClient(stdio=self.stdio) + self.ssh_client = Mock() + + @patch('subprocess.Popen') + def test_run_success(self, mock_popen): + # 模拟命令成功执行 + mock_process = Mock() + mock_process.communicate.return_value = (b'success', None) + mock_popen.return_value = mock_process + + cmd = 'echo "hello"' + result = self.local_client.run(cmd) + + # 验证 verbose 和 Popen 调用 + self.stdio.verbose.assert_called_with("[local host] run cmd = [echo \"hello\"] on localhost") + mock_popen.assert_called_with(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash') + + # 验证结果 + self.assertEqual(result, b'success') + + @patch('subprocess.Popen') + def test_run_failure(self, mock_popen): + # 模拟命令执行失败 + 
mock_process = Mock() + mock_process.communicate.return_value = (b'', b'error') + mock_popen.return_value = mock_process + + cmd = 'echo "hello"' + result = self.local_client.run(cmd) + + # 验证 verbose 和 Popen 调用 + self.stdio.verbose.assert_called_with("[local host] run cmd = [echo \"hello\"] on localhost") + mock_popen.assert_called_with(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash') + + # 验证错误处理 + self.stdio.error.assert_called_with("run cmd = [echo \"hello\"] on localhost, stderr=[b'error']") + self.assertEqual(result, b'') + + @patch('subprocess.Popen') + def test_run_exception(self, mock_popen): + # 模拟命令执行时抛出异常 + mock_popen.side_effect = Exception('Test exception') + + cmd = 'echo "hello"' + result = self.local_client.run(cmd) + + # 验证 verbose 调用和异常处理 + self.stdio.verbose.assert_called_with("[local host] run cmd = [echo \"hello\"] on localhost") + self.stdio.error.assert_called_with("run cmd = [echo \"hello\"] on localhost") + self.assertIsNone(result) + + @patch('subprocess.Popen') + def test_run_get_stderr_success(self, mock_popen): + # 模拟命令成功执行 + mock_process = Mock() + mock_process.communicate.return_value = (b'success', b'') + mock_popen.return_value = mock_process + + cmd = 'echo "hello"' + result = self.local_client.run_get_stderr(cmd) + + # 验证 verbose 和 Popen 调用 + self.stdio.verbose.assert_called_with("run cmd = [echo \"hello\"] on localhost") + mock_popen.assert_called_with(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash') + + # 验证结果 + self.assertEqual(result, b'') + + @patch('subprocess.Popen') + def test_run_get_stderr_failure(self, mock_popen): + # 模拟命令执行失败 + mock_process = Mock() + mock_process.communicate.return_value = (b'', b'error') + mock_popen.return_value = mock_process + + cmd = 'echo "hello"' + result = self.local_client.run_get_stderr(cmd) + + # 验证 verbose 和 Popen 调用 + self.stdio.verbose.assert_called_with("run cmd = [echo \"hello\"] on localhost") + 
mock_popen.assert_called_with(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, executable='/bin/bash')
+
+        # Verify the return value: run_get_stderr returns the stderr part of communicate()
+        # (stdout is mocked as b'' and stderr as b'error', so the result should be b'error')
+        self.assertEqual(result, b'error')
+
+        # Check whether stdio.error was invoked.
+        # NOTE: error is only expected when run_get_stderr itself raises an exception;
+        # a command that merely produces stderr output is the normal path, so error must not be called here.
+        self.stdio.error.assert_not_called()
+
+    @patch('subprocess.Popen')
+    def test_run_get_stderr_exception(self, mock_popen):
+        # Simulate an exception being raised while executing the command
+        mock_popen.side_effect = Exception('Test exception')
+
+        cmd = 'echo "hello"'
+        result = self.local_client.run_get_stderr(cmd)
+
+        # Verify the verbose call and the exception handling
+        self.stdio.verbose.assert_called_with("run cmd = [echo \"hello\"] on localhost")
+        self.stdio.error.assert_called_with(f"run cmd = [{cmd}] on localhost")
+        self.assertIsNone(result)
+
+    def test_download_file_success(self):
+        remote_path = "/remote/path/file.txt"
+        local_path = "/local/path/file.txt"
+
+        result = download_file(self.ssh_client, remote_path, local_path, self.stdio)
+
+        self.ssh_client.download.assert_called_once_with(remote_path, local_path)
+        self.assertEqual(result, local_path)
+        self.stdio.error.assert_not_called()
+        self.stdio.verbose.assert_not_called()
+
+    def test_download_file_failure(self):
+        remote_path = "/remote/path/file.txt"
+        local_path = "/local/path/file.txt"
+
+        self.ssh_client.download.side_effect = Exception("Simulated download exception")
+
+        result = download_file(self.ssh_client, remote_path, local_path, self.stdio)
+
+        self.ssh_client.download.assert_called_once_with(remote_path, local_path)
+        self.assertEqual(result, local_path)
+        self.stdio.error.assert_called_once_with("Download File Failed error: Simulated download exception")
+        self.stdio.verbose.assert_called_once()
+
+    def test_upload_file_success(self):
+        local_path = "/local/path/file.txt"
+        remote_path = "/remote/path/file.txt"
+        self.ssh_client.get_name.return_value = "test_server"
+
+        result = upload_file(self.ssh_client, local_path, remote_path, self.stdio)
+ + self.ssh_client.upload.assert_called_once_with(remote_path, local_path) + self.stdio.verbose.assert_called_once_with("Please wait a moment, upload file to server test_server, local file path /local/path/file.txt, remote file path /remote/path/file.txt") + self.stdio.error.assert_not_called() + + def test_rm_rf_file_success(self): + dir_path = "/path/to/delete" + + rm_rf_file(self.ssh_client, dir_path, self.stdio) + + self.ssh_client.exec_cmd.assert_called_once_with("rm -rf /path/to/delete") + + def test_rm_rf_file_empty_dir(self): + dir_path = "" + + rm_rf_file(self.ssh_client, dir_path, self.stdio) + + self.ssh_client.exec_cmd.assert_called_once_with("rm -rf ") + + def test_rm_rf_file_special_chars(self): + dir_path = "/path/to/delete; echo 'This is a test'" + + rm_rf_file(self.ssh_client, dir_path, self.stdio) + + self.ssh_client.exec_cmd.assert_called_once_with("rm -rf /path/to/delete; echo 'This is a test'") + + def test_delete_file_in_folder_success(self): + file_path = "/path/to/gather_pack" + + delete_file_in_folder(self.ssh_client, file_path, self.stdio) + + self.ssh_client.exec_cmd.assert_called_once_with("rm -rf /path/to/gather_pack/*") + + def test_delete_file_in_folder_none_path(self): + file_path = None + + with self.assertRaises(Exception) as context: + delete_file_in_folder(self.ssh_client, file_path, self.stdio) + + self.assertTrue("Please check file path, None" in str(context.exception)) + + def test_delete_file_in_folder_invalid_path(self): + file_path = "/path/to/invalid_folder" + + with self.assertRaises(Exception) as context: + delete_file_in_folder(self.ssh_client, file_path, self.stdio) + + self.assertTrue("Please check file path, /path/to/invalid_folder" in str(context.exception)) + + def test_delete_file_in_folder_special_chars(self): + file_path = "/path/to/gather_pack; echo 'test'" + + delete_file_in_folder(self.ssh_client, file_path, self.stdio) + + self.ssh_client.exec_cmd.assert_called_once_with("rm -rf /path/to/gather_pack; echo 
'test'/*") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/common/test_config_helper.py b/test/common/test_config_helper.py new file mode 100644 index 00000000..0137dd73 --- /dev/null +++ b/test/common/test_config_helper.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/8/6 +@file: test_config_helper.py +@desc: 测试config_helper的 get_old_configuration ~ input_choice_default 方法 +""" +import unittest +from unittest import mock +from common.config_helper import ConfigHelper + + +class TestConfigHelper(unittest.TestCase): + @mock.patch('common.config_helper.YamlUtils.write_yaml_data') + @mock.patch('common.config_helper.DirectoryUtil.mkdir') + @mock.patch('common.config_helper.os.path.expanduser') + @mock.patch('common.config_helper.TimeUtils.timestamp_to_filename_time') + def test_save_old_configuration(self, mock_timestamp_to_filename_time, mock_expanduser, mock_mkdir, mock_write_yaml_data): + # 模拟时间戳生成函数,返回一个特定的值 + mock_timestamp_to_filename_time.return_value = '20240806_123456' + + # 模拟路径扩展函数 + def mock_expanduser_path(path): + return {'~/.obdiag/config.yml': '/mock/config.yml', '~/mock/backup/dir': '/mock/backup/dir'}.get(path, path) # 默认返回原路径 + + mock_expanduser.side_effect = mock_expanduser_path + + # 模拟目录创建函数 + mock_mkdir.return_value = None + + # 模拟YAML数据写入函数 + mock_write_yaml_data.return_value = None + + # 创建一个模拟的上下文对象 + context = mock.MagicMock() + context.inner_config = {"obdiag": {"basic": 
{"config_backup_dir": "~/mock/backup/dir"}}} + + # 初始化ConfigHelper对象 + config_helper = ConfigHelper(context) + + # 定义一个示例配置 + sample_config = {'key': 'value'} + + # 调用需要测试的方法 + config_helper.save_old_configuration(sample_config) + + # 验证路径扩展是否被正确调用 + mock_expanduser.assert_any_call('~/.obdiag/config.yml') + mock_expanduser.assert_any_call('~/mock/backup/dir') + + # 验证目录创建是否被正确调用 + mock_mkdir.assert_called_once_with(path='/mock/backup/dir') + + # 验证YAML数据写入是否被正确调用 + expected_backup_path = '/mock/backup/dir/config_backup_20240806_123456.yml' + mock_write_yaml_data.assert_called_once_with(sample_config, expected_backup_path) + + # 测试带有默认值输入的方法 + @mock.patch('builtins.input') + def test_input_with_default(self, mock_input): + # 创建一个模拟的上下文对象(虽然该方法并不需要它) + context = mock.Mock() + config_helper = ConfigHelper(context) + + # 测试用户输入为空的情况 + mock_input.return_value = '' + result = config_helper.input_with_default('username', 'default_user') + self.assertEqual(result, 'default_user') + + # 测试用户输入为'y'的情况(应该返回默认值) + mock_input.return_value = 'y' + result = config_helper.input_with_default('username', 'default_user') + self.assertEqual(result, 'default_user') + + # 测试用户输入为'yes'的情况(应该返回默认值) + mock_input.return_value = 'yes' + result = config_helper.input_with_default('username', 'default_user') + self.assertEqual(result, 'default_user') + + # 测试用户输入为其他值的情况(应该返回用户输入) + mock_input.return_value = 'custom_user' + result = config_helper.input_with_default('username', 'default_user') + self.assertEqual(result, 'custom_user') + + # 测试带有默认值的密码输入方法 + @mock.patch('common.config_helper.pwinput.pwinput') + def test_input_password_with_default(self, mock_pwinput): + # 创建一个模拟的上下文对象 + context = mock.MagicMock() + config_helper = ConfigHelper(context) + + # 测试密码输入为空的情况,应该返回默认值 + mock_pwinput.return_value = '' + result = config_helper.input_password_with_default("password", "default_password") + self.assertEqual(result, "default_password") + + # 测试密码输入为'y'的情况,应该返回默认值 + mock_pwinput.return_value = 
'y' + result = config_helper.input_password_with_default("password", "default_password") + self.assertEqual(result, "default_password") + + # 测试密码输入为'yes'的情况,应该返回默认值 + mock_pwinput.return_value = 'yes' + result = config_helper.input_password_with_default("password", "default_password") + self.assertEqual(result, "default_password") + + # 测试密码输入为其他值的情况,应该返回输入值 + mock_pwinput.return_value = 'custom_password' + result = config_helper.input_password_with_default("password", "default_password") + self.assertEqual(result, "custom_password") + + # 测试带有默认选项的选择输入方法 + @mock.patch('common.config_helper.input') + def test_input_choice_default(self, mock_input): + # 创建一个模拟的上下文对象 + context = mock.MagicMock() + config_helper = ConfigHelper(context) + + # 测试输入为'y'的情况,应该返回True + mock_input.return_value = 'y' + result = config_helper.input_choice_default("choice", "N") + self.assertTrue(result) + + # 测试输入为'yes'的情况,应该返回True + mock_input.return_value = 'yes' + result = config_helper.input_choice_default("choice", "N") + self.assertTrue(result) + + # 测试输入为'n'的情况,应该返回False + mock_input.return_value = 'n' + result = config_helper.input_choice_default("choice", "N") + self.assertFalse(result) + + # 测试输入为'no'的情况,应该返回False + mock_input.return_value = 'no' + result = config_helper.input_choice_default("choice", "N") + self.assertFalse(result) + + # 测试输入为空字符串的情况,应该返回False + mock_input.return_value = '' + result = config_helper.input_choice_default("choice", "N") + self.assertFalse(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/common/test_scene.py b/test/common/test_scene.py new file mode 100644 index 00000000..21ad57d3 --- /dev/null +++ b/test/common/test_scene.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -* +# Copyright (c) 2022 OceanBase +# OceanBase Diagnostic Tool is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. 
+# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +""" +@time: 2024/8/6 +@file: test_scene.py +@desc: 为scene模块中filter_by_version和get_version_by_type函数进行单元测试 +""" +import unittest +from unittest.mock import MagicMock, patch +from common.scene import * + + +class TestFilterByVersion(unittest.TestCase): + def setUp(self): + self.stdio = MagicMock() + StringUtils.compare_versions_greater = MagicMock() + self.context = MagicMock() + self.context.stdio = MagicMock() + + def test_no_version_in_cluster(self): + scene = [{"version": "[1.0,2.0]"}] + cluster = {} + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_empty_version_in_cluster(self): + scene = [{"version": "[1.0,2.0]"}] + cluster = {"version": ""} + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_version_not_string(self): + scene = [{"version": 123}] + cluster = {"version": "1.5"} + with self.assertRaises(Exception): + filter_by_version(scene, cluster, self.stdio) + + def test_version_match_min(self): + scene = [{"version": "[1.0,2.0]"}] + cluster = {"version": "1.0"} + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_version_match_max(self): + scene = [{"version": "[1.0,2.0]"}] + cluster = {"version": "2.0"} + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_version_in_range(self): + scene = [{"version": "[1.0,2.0]"}] + cluster = {"version": "1.5"} + StringUtils.compare_versions_greater.side_effect = [True, True] + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def 
test_version_out_of_range(self): + scene = [{"version": "[1.0,2.0]"}, {"version": "[2.0,3.0]"}] + cluster = {"version": "2.5"} + StringUtils.compare_versions_greater.side_effect = [False, True, True, True] + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 1) + + def test_no_version_in_steps(self): + scene = [{}] + cluster = {"version": "1.0"} + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_no_matching_version(self): + scene = [{"version": "[1.0,2.0]"}, {"version": "[2.0,3.0]"}] + cluster = {"version": "3.5"} + StringUtils.compare_versions_greater.return_value = False + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, -1) + + def test_wildcard_min_version(self): + scene = [{"version": "[*,2.0]"}] + cluster = {"version": "1.0"} + StringUtils.compare_versions_greater.side_effect = [True, True] + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + def test_wildcard_max_version(self): + scene = [{"version": "[1.0,*]"}] + cluster = {"version": "3.0"} + StringUtils.compare_versions_greater.side_effect = [True, True] + result = filter_by_version(scene, cluster, self.stdio) + self.assertEqual(result, 0) + + @patch('common.scene.get_observer_version') + def test_get_observer_version(self, mock_get_observer_version): + mock_get_observer_version.return_value = "1.0.0" + result = get_version_by_type(self.context, "observer") + self.assertEqual(result, "1.0.0") + mock_get_observer_version.assert_called_once_with(self.context) + + @patch('common.scene.get_observer_version') + def test_get_other_version(self, mock_get_observer_version): + mock_get_observer_version.return_value = "2.0.0" + result = get_version_by_type(self.context, "other") + self.assertEqual(result, "2.0.0") + mock_get_observer_version.assert_called_once_with(self.context) + + @patch('common.scene.get_observer_version') + def 
test_get_observer_version_fail(self, mock_get_observer_version): + mock_get_observer_version.side_effect = Exception("Observer error") + with self.assertRaises(Exception) as context: + get_version_by_type(self.context, "observer") + self.assertIn("can't get observer version", str(context.exception)) + self.context.stdio.warn.assert_called_once() + + @patch('common.scene.get_obproxy_version') + def test_get_obproxy_version(self, mock_get_obproxy_version): + mock_get_obproxy_version.return_value = "3.0.0" + result = get_version_by_type(self.context, "obproxy") + self.assertEqual(result, "3.0.0") + mock_get_obproxy_version.assert_called_once_with(self.context) + + def test_unsupported_type(self): + with self.assertRaises(Exception) as context: + get_version_by_type(self.context, "unsupported") + self.assertIn("No support to get the version", str(context.exception)) + + @patch('common.scene.get_observer_version') + def test_general_exception_handling(self, mock_get_observer_version): + mock_get_observer_version.side_effect = Exception("Unexpected error") + with self.assertRaises(Exception) as context: + get_version_by_type(self.context, "observer") + self.assertIn("can't get observer version", str(context.exception)) + self.context.stdio.exception.assert_called_once() + + +if __name__ == '__main__': + unittest.main()