diff --git a/src/cnaas_nms/api/generic.py b/src/cnaas_nms/api/generic.py index 1c4fb22c..356d1a79 100644 --- a/src/cnaas_nms/api/generic.py +++ b/src/cnaas_nms/api/generic.py @@ -1,8 +1,11 @@ import re +from typing import List from flask import request import sqlalchemy +from cnaas_nms.db.settings import get_pydantic_error_value, get_pydantic_field_descr + FILTER_RE = re.compile(r"^filter\[([a-zA-Z0-9_.]+)\](\[[a-z]+\])?$") DEFAULT_PER_PAGE = 50 @@ -127,3 +130,40 @@ def empty_result(status='success', data=None): 'status': status, 'message': data if data else "Unknown error" } + + +def parse_pydantic_error(e: Exception, schema, data: dict) -> List[str]: + errors = [] + for num, error in enumerate(e.errors()): + loc = error['loc'] + origin = 'unknown' + errors.append("Validation error for setting {}, bad value: {}".format( + '->'.join(str(x) for x in loc), + get_pydantic_error_value(data, loc) + )) + try: + pydantic_descr = get_pydantic_field_descr(schema.schema(), loc) + if pydantic_descr: + pydantic_descr_msg = ", field should be: {}".format(pydantic_descr) + else: + pydantic_descr_msg = "" + except Exception as e_pydantic_descr: + pydantic_descr_msg = ", exception while getting pydantic description" + errors.append("Message: {}{}".format(error['msg'], pydantic_descr_msg)) + return errors + + +def update_sqla_object(instance, new_data: dict) -> bool: + """Update SQLalchemy object instance with data dict. + Returns: + Returns True if any values were changed + """ + changed = False + for k, v in new_data.items(): + try: + if getattr(instance, k) != v: + setattr(instance, k, v) + changed = True + except AttributeError: + continue + return changed diff --git a/src/cnaas_nms/api/mgmtdomain.py b/src/cnaas_nms/api/mgmtdomain.py index 0c51eb20..1f14ec34 100644 --- a/src/cnaas_nms/api/mgmtdomain.py +++ b/src/cnaas_nms/api/mgmtdomain.py @@ -1,7 +1,11 @@ +from typing import Optional + +from sqlalchemy.exc import IntegrityError from flask import request from flask_restx import Resource, Namespace, fields from flask_jwt_extended import jwt_required - +from pydantic import BaseModel, validator +from pydantic.error_wrappers import ValidationError from ipaddress import IPv4Interface from cnaas_nms.api.generic import build_filter, empty_result, limit_results @@ -9,6 +13,8 @@ from cnaas_nms.db.mgmtdomain import Mgmtdomain from cnaas_nms.db.session import sqla_session from cnaas_nms.version import __api_version__ +from cnaas_nms.db.settings_fields import vlan_id_schema_optional +from cnaas_nms.api.generic import parse_pydantic_error, update_sqla_object mgmtdomains_api = Namespace('mgmtdomains', description='API for handling management domains', @@ -21,9 +27,33 @@ 'device_b': fields.String(required=True), 'vlan': fields.Integer(required=True), 'ipv4_gw': fields.String(required=True), + 'description': fields.String(required=False), }) +class f_mgmtdomain(BaseModel): + vlan: Optional[int] = vlan_id_schema_optional + ipv4_gw: Optional[str] = None + description: Optional[str] = None + + @validator('ipv4_gw') + def ipv4_gw_valid_address(cls, v, values, **kwargs): + try: + addr = IPv4Interface(v) + prefix_len = int(addr.network.prefixlen) + except: + raise ValueError('Invalid ipv4_gw received. Must be correct IPv4 address with mask') + else: + if addr.ip == addr.network.network_address: + raise ValueError("Specify gateway address, not subnet address") + if addr.ip == addr.network.broadcast_address: + raise ValueError("Specify gateway address, not broadcast address") + if prefix_len >= 31 or prefix_len <= 16: + raise ValueError("Bad prefix length {} for management network".format(prefix_len)) + + return v + + class MgmtdomainByIdApi(Resource): @jwt_required def get(self, mgmtdomain_id): @@ -60,38 +90,28 @@ def delete(self, mgmtdomain_id): def put(self, mgmtdomain_id): """ Modify management domain """ json_data = request.get_json() - data = {} errors = [] - if 'vlan' in json_data: - try: - vlan_id_int = int(json_data['vlan']) - except: - errors.append('Invalid VLAN received.') - else: - data['vlan'] = vlan_id_int - if 'ipv4_gw' in json_data: - try: - addr = IPv4Interface(json_data['ipv4_gw']) - prefix_len = int(addr.network.prefixlen) - except: - errors.append('Invalid ipv4_gw received. Must be correct IPv4 address with mask') - else: - if prefix_len <= 31 and prefix_len >= 16: - data['ipv4_gw'] = str(addr) - else: - errors.append("Bad prefix length for management network: {}".format( - prefix_len)) + try: + f_mgmtdomain(**json_data).dict() + except ValidationError as e: + errors += parse_pydantic_error(e, f_mgmtdomain, json_data) + + if errors: + return empty_result('error', errors), 400 + with sqla_session() as session: instance: Mgmtdomain = session.query(Mgmtdomain).\ filter(Mgmtdomain.id == mgmtdomain_id).one_or_none() if instance: - instance.device_a.synchronized = False - instance.device_b.synchronized = False - #TODO: auto loop through class members and match - if 'vlan' in data: - instance.vlan = data['vlan'] - if 'ipv4_gw' in data: - instance.ipv4_gw = data['ipv4_gw'] + changed: bool = update_sqla_object(instance, json_data) + if changed: + instance.device_a.synchronized = False + instance.device_b.synchronized = False + return empty_result(status='success', data={"updated_mgmtdomain": instance.as_dict()}), 200 + else: + return empty_result(status='success', data={"unchanged_mgmtdomain": instance.as_dict()}), 200 + else: + return empty_result(status='error', data="mgmtdomain not found"), 400 class MgmtdomainsApi(Resource): @@ -142,35 +162,30 @@ def post(self): errors.append(f"Device with hostname {hostname_b} not found") else: data['device_b'] = device_b - if 'vlan' in json_data: - try: - vlan_id_int = int(json_data['vlan']) - except: - errors.append('Invalid VLAN received.') - else: - data['vlan'] = vlan_id_int - if 'ipv4_gw' in json_data: - try: - addr = IPv4Interface(json_data['ipv4_gw']) - prefix_len = int(addr.network.prefixlen) - except: - errors.append(('Invalid ipv4_gw received. ' - 'Must be correct IPv4 address with mask')) - else: - if prefix_len <= 31 and prefix_len >= 16: - data['ipv4_gw'] = str(addr) - else: - errors.append("Bad prefix length for management network: {}".format( - prefix_len)) + + try: + data = {**data, **f_mgmtdomain(**json_data).dict()} + except ValidationError as e: + errors += parse_pydantic_error(e, f_mgmtdomain, json_data) + required_keys = ['device_a', 'device_b', 'vlan', 'ipv4_gw'] - if all([key in data for key in required_keys]): + if all([key in data for key in required_keys]) and \ + all([key in json_data for key in required_keys]): new_mgmtd = Mgmtdomain() new_mgmtd.device_a = data['device_a'] new_mgmtd.device_b = data['device_b'] new_mgmtd.ipv4_gw = data['ipv4_gw'] new_mgmtd.vlan = data['vlan'] - session.add(new_mgmtd) - session.flush() + try: + session.add(new_mgmtd) + session.flush() + except IntegrityError as e: + session.rollback() + if 'duplicate' in str(e): + return empty_result('error', "Duplicate value: {}".format(e.orig.args[0])), 400 + else: + return empty_result('error', "Integrity error: {}".format(e)), 400 + device_a.synchronized = False device_b.synchronized = False return empty_result(status='success', data={"added_mgmtdomain": new_mgmtd.as_dict()}), 200 diff --git a/src/cnaas_nms/api/tests/test_api.py b/src/cnaas_nms/api/tests/test_api.py index d0d9d95d..6a30fc76 100644 --- a/src/cnaas_nms/api/tests/test_api.py +++ b/src/cnaas_nms/api/tests/test_api.py @@ -58,6 +58,40 @@ def test_get_managementdomain(self): # The one result should have the same ID we asked for self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int) + def test_update_managementdomain(self): + result = self.client.get('/api/v1.0/mgmtdomains?per_page=1') + self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int) + id = result.json['data']['mgmtdomains'][0]['id'] + data = {"vlan": 601} + result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data) + self.assertEqual(result.status_code, 200) + self.assertIn('updated_mgmtdomain', result.json['data']) + # Make sure returned data inclueds new vlan + self.assertEqual(result.json['data']['updated_mgmtdomain']['vlan'], data['vlan']) + # Change back to old vlan + data["vlan"] = 600 + result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data) + self.assertEqual(result.status_code, 200) + self.assertIn('updated_mgmtdomain', result.json['data']) + # Check that no change is made when applying same vlan twice + data["vlan"] = 600 + result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data) + self.assertEqual(result.status_code, 200) + self.assertIn('unchanged_mgmtdomain', result.json['data']) + + def test_validate_managementdomain(self): + result = self.client.get('/api/v1.0/mgmtdomains?per_page=1') + self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int) + id = result.json['data']['mgmtdomains'][0]['id'] + # Check that you get error if using invalid gw + data = {"ipv4_gw": "10.0.6.0/24"} + result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data) + self.assertEqual(result.status_code, 400) + # Check that you get error if using invalid vlan id + data = {"vlan": 5000} + result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data) + self.assertEqual(result.status_code, 400) + def test_repository_refresh(self): data = {"action": "refresh"} result = self.client.put('/api/v1.0/repository/settings', json=data) diff --git a/src/cnaas_nms/db/settings_fields.py b/src/cnaas_nms/db/settings_fields.py index 1779a23a..8dd8e0df 100644 --- a/src/cnaas_nms/db/settings_fields.py +++ b/src/cnaas_nms/db/settings_fields.py @@ -44,6 +44,7 @@ description="Max 32 alphanumeric chars, " + "beginning with a non-numeric character") vlan_id_schema = Field(..., gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095") +vlan_id_schema_optional = Field(None, gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095") vxlan_vni_schema = Field(..., gt=0, lt=16777215, description="VXLAN Network Identifier") vrf_id_schema = Field(..., gt=0, lt=65536, description="VRF identifier, integer between 1-65535") mtu_schema = Field(None, ge=68, le=9214,