diff --git a/config/main.py b/config/main.py index 4e2ae68de6..baa71732c2 100644 --- a/config/main.py +++ b/config/main.py @@ -278,6 +278,40 @@ def interface_alias_to_name(config_db, interface_alias): # portchannel is passed in as argument, which does not have an alias return interface_alias if sub_intf_sep_idx == -1 else interface_alias + VLAN_SUB_INTERFACE_SEPARATOR + vlan_id +def loopback_name_is_valid(config_db, loopback_name): + """Check if the loopback name is valid + """ + # If the input parameter config_db is None, try DEFAULT_NAMESPACE. + if config_db is None: + namespace = DEFAULT_NAMESPACE + config_db = ConfigDBConnector(use_unix_socket_path=True, namespace=namespace) + + config_db.connect() + loopback_dict = config_db.get_table('LOOPBACK_INTERFACE') + + if loopback_name is not None and loopback_dict: + for loopback_dict_keys in loopback_dict.keys(): + if loopback_name == loopback_dict_keys: + return True + return False + +def vlan_name_is_valid(config_db, vlan_name): + """Check if the vlan name is valid + """ + # If the input parameter config_db is None, try DEFAULT_NAMESPACE. + if config_db is None: + namespace = DEFAULT_NAMESPACE + config_db = ConfigDBConnector(use_unix_socket_path=True, namespace=namespace) + + config_db.connect() + vlan_dict = config_db.get_table('VLAN') + + if vlan_name is not None and vlan_dict: + for vlan_dict_keys in vlan_dict.keys(): + if vlan_name == vlan_dict_keys: + return True + return False + def interface_name_is_valid(config_db, interface_name): """Check if the interface name is valid """ @@ -3823,6 +3857,11 @@ def add(ctx, interface_name, ip_addr, gw): return + if not (interface_name_is_valid(config_db, interface_name) + or vlan_name_is_valid(config_db, interface_name) + or loopback_name_is_valid(config_db, interface_name)): + ctx.fail("'interface_name' is not valid. Valid names [Ethernet/PortChannel/Vlan/Loopback]") + table_name = get_interface_table_name(interface_name) if table_name == "": ctx.fail("'interface_name' is not valid. Valid names [Ethernet/PortChannel/Vlan/Loopback]") diff --git a/tests/config_int_ip_test.py b/tests/config_int_ip_test.py index 6968fcbe45..bd5efedf28 100644 --- a/tests/config_int_ip_test.py +++ b/tests/config_int_ip_test.py @@ -14,6 +14,10 @@ sys.path.insert(0, test_path) mock_db_path = os.path.join(test_path, "int_ip_input") +ERROR_MSG_WRONG_INTERFACE_NAME = '''Usage: add [OPTIONS] +Try "add --help" for help. + +Error: \'interface_name\' is not valid. Valid names [Ethernet/PortChannel/Vlan/Loopback]''' class TestIntIp(object): @pytest.fixture(scope="class", autouse=True) @@ -155,4 +159,40 @@ def test_config_int_ip_rem_static_multiasic( print(result.exit_code, result.output) assert result.exit_code != 0 assert "Error: Cannot remove the last IP entry of interface Ethernet8. A static ipv6 route is still bound to the RIF." in result.output - assert mock_run_command.call_count == 0 \ No newline at end of file + assert mock_run_command.call_count == 0 + + +class TestConfigIP_wrong_name(object): + @classmethod + def setup_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "1" + print("SETUP") + + def test_add_interface_invalid_vlan(self): + db = Db() + runner = CliRunner() + obj = {'config_db':db.cfgdb} + import config.main as config + + #Try to set wrong VLAN: config int ip add Vlan100500 100.50.20.1/24 + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Vlan100500", "100.50.20.1/24"], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert ERROR_MSG_WRONG_INTERFACE_NAME in result.output + + def test_add_interface_invalid_name(self): + db = Db() + runner = CliRunner() + obj = {'config_db':db.cfgdb} + import config.main as config + + #Try to set IP on wrong interface: config int ip add Ethernet2abc 100.50.20.1/24 + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet2abc", "100.50.20.1/24"], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert ERROR_MSG_WRONG_INTERFACE_NAME in result.output + + @classmethod + def teardown_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "0" + print("TEARDOWN")