Skip to content

Commit

Permalink
Merge pull request #177 from SUNET/bugfix.mgmtdomain_valid_ip_and_output
Browse files Browse the repository at this point in the history
Mgmtdomains API, better validation and messages
  • Loading branch information
indy-independence authored Jul 12, 2021
2 parents 8df1862 + 5d04574 commit ee5947e
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 51 deletions.
40 changes: 40 additions & 0 deletions src/cnaas_nms/api/generic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import re
from typing import List

from flask import request
import sqlalchemy

from cnaas_nms.db.settings import get_pydantic_error_value, get_pydantic_field_descr


FILTER_RE = re.compile(r"^filter\[([a-zA-Z0-9_.]+)\](\[[a-z]+\])?$")
DEFAULT_PER_PAGE = 50
Expand Down Expand Up @@ -127,3 +130,40 @@ def empty_result(status='success', data=None):
'status': status,
'message': data if data else "Unknown error"
}


def parse_pydantic_error(e: Exception, schema, data: dict) -> List[str]:
errors = []
for num, error in enumerate(e.errors()):
loc = error['loc']
origin = 'unknown'
errors.append("Validation error for setting {}, bad value: {}".format(
'->'.join(str(x) for x in loc),
get_pydantic_error_value(data, loc)
))
try:
pydantic_descr = get_pydantic_field_descr(schema.schema(), loc)
if pydantic_descr:
pydantic_descr_msg = ", field should be: {}".format(pydantic_descr)
else:
pydantic_descr_msg = ""
except Exception as e_pydantic_descr:
pydantic_descr_msg = ", exception while getting pydantic description"
errors.append("Message: {}{}".format(error['msg'], pydantic_descr_msg))
return errors


def update_sqla_object(instance, new_data: dict) -> bool:
"""Update SQLalchemy object instance with data dict.
Returns:
Returns True if any values were changed
"""
changed = False
for k, v in new_data.items():
try:
if getattr(instance, k) != v:
setattr(instance, k, v)
changed = True
except AttributeError:
continue
return changed
117 changes: 66 additions & 51 deletions src/cnaas_nms/api/mgmtdomain.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from typing import Optional

from sqlalchemy.exc import IntegrityError
from flask import request
from flask_restx import Resource, Namespace, fields
from flask_jwt_extended import jwt_required

from pydantic import BaseModel, validator
from pydantic.error_wrappers import ValidationError
from ipaddress import IPv4Interface

from cnaas_nms.api.generic import build_filter, empty_result, limit_results
from cnaas_nms.db.device import Device
from cnaas_nms.db.mgmtdomain import Mgmtdomain
from cnaas_nms.db.session import sqla_session
from cnaas_nms.version import __api_version__
from cnaas_nms.db.settings_fields import vlan_id_schema_optional
from cnaas_nms.api.generic import parse_pydantic_error, update_sqla_object


mgmtdomains_api = Namespace('mgmtdomains', description='API for handling management domains',
Expand All @@ -21,9 +27,33 @@
'device_b': fields.String(required=True),
'vlan': fields.Integer(required=True),
'ipv4_gw': fields.String(required=True),
'description': fields.String(required=False),
})


class f_mgmtdomain(BaseModel):
vlan: Optional[int] = vlan_id_schema_optional
ipv4_gw: Optional[str] = None
description: Optional[str] = None

@validator('ipv4_gw')
def ipv4_gw_valid_address(cls, v, values, **kwargs):
try:
addr = IPv4Interface(v)
prefix_len = int(addr.network.prefixlen)
except:
raise ValueError('Invalid ipv4_gw received. Must be correct IPv4 address with mask')
else:
if addr.ip == addr.network.network_address:
raise ValueError("Specify gateway address, not subnet address")
if addr.ip == addr.network.broadcast_address:
raise ValueError("Specify gateway address, not broadcast address")
if prefix_len >= 31 or prefix_len <= 16:
raise ValueError("Bad prefix length {} for management network".format(prefix_len))

return v


class MgmtdomainByIdApi(Resource):
@jwt_required
def get(self, mgmtdomain_id):
Expand Down Expand Up @@ -60,38 +90,28 @@ def delete(self, mgmtdomain_id):
def put(self, mgmtdomain_id):
""" Modify management domain """
json_data = request.get_json()
data = {}
errors = []
if 'vlan' in json_data:
try:
vlan_id_int = int(json_data['vlan'])
except:
errors.append('Invalid VLAN received.')
else:
data['vlan'] = vlan_id_int
if 'ipv4_gw' in json_data:
try:
addr = IPv4Interface(json_data['ipv4_gw'])
prefix_len = int(addr.network.prefixlen)
except:
errors.append('Invalid ipv4_gw received. Must be correct IPv4 address with mask')
else:
if prefix_len <= 31 and prefix_len >= 16:
data['ipv4_gw'] = str(addr)
else:
errors.append("Bad prefix length for management network: {}".format(
prefix_len))
try:
f_mgmtdomain(**json_data).dict()
except ValidationError as e:
errors += parse_pydantic_error(e, f_mgmtdomain, json_data)

if errors:
return empty_result('error', errors), 400

with sqla_session() as session:
instance: Mgmtdomain = session.query(Mgmtdomain).\
filter(Mgmtdomain.id == mgmtdomain_id).one_or_none()
if instance:
instance.device_a.synchronized = False
instance.device_b.synchronized = False
#TODO: auto loop through class members and match
if 'vlan' in data:
instance.vlan = data['vlan']
if 'ipv4_gw' in data:
instance.ipv4_gw = data['ipv4_gw']
changed: bool = update_sqla_object(instance, json_data)
if changed:
instance.device_a.synchronized = False
instance.device_b.synchronized = False
return empty_result(status='success', data={"updated_mgmtdomain": instance.as_dict()}), 200
else:
return empty_result(status='success', data={"unchanged_mgmtdomain": instance.as_dict()}), 200
else:
return empty_result(status='error', data="mgmtdomain not found"), 400


class MgmtdomainsApi(Resource):
Expand Down Expand Up @@ -142,35 +162,30 @@ def post(self):
errors.append(f"Device with hostname {hostname_b} not found")
else:
data['device_b'] = device_b
if 'vlan' in json_data:
try:
vlan_id_int = int(json_data['vlan'])
except:
errors.append('Invalid VLAN received.')
else:
data['vlan'] = vlan_id_int
if 'ipv4_gw' in json_data:
try:
addr = IPv4Interface(json_data['ipv4_gw'])
prefix_len = int(addr.network.prefixlen)
except:
errors.append(('Invalid ipv4_gw received. '
'Must be correct IPv4 address with mask'))
else:
if prefix_len <= 31 and prefix_len >= 16:
data['ipv4_gw'] = str(addr)
else:
errors.append("Bad prefix length for management network: {}".format(
prefix_len))

try:
data = {**data, **f_mgmtdomain(**json_data).dict()}
except ValidationError as e:
errors += parse_pydantic_error(e, f_mgmtdomain, json_data)

required_keys = ['device_a', 'device_b', 'vlan', 'ipv4_gw']
if all([key in data for key in required_keys]):
if all([key in data for key in required_keys]) and \
all([key in json_data for key in required_keys]):
new_mgmtd = Mgmtdomain()
new_mgmtd.device_a = data['device_a']
new_mgmtd.device_b = data['device_b']
new_mgmtd.ipv4_gw = data['ipv4_gw']
new_mgmtd.vlan = data['vlan']
session.add(new_mgmtd)
session.flush()
try:
session.add(new_mgmtd)
session.flush()
except IntegrityError as e:
session.rollback()
if 'duplicate' in str(e):
return empty_result('error', "Duplicate value: {}".format(e.orig.args[0])), 400
else:
return empty_result('error', "Integrity error: {}".format(e)), 400

device_a.synchronized = False
device_b.synchronized = False
return empty_result(status='success', data={"added_mgmtdomain": new_mgmtd.as_dict()}), 200
Expand Down
34 changes: 34 additions & 0 deletions src/cnaas_nms/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,40 @@ def test_get_managementdomain(self):
# The one result should have the same ID we asked for
self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int)

def test_update_managementdomain(self):
result = self.client.get('/api/v1.0/mgmtdomains?per_page=1')
self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int)
id = result.json['data']['mgmtdomains'][0]['id']
data = {"vlan": 601}
result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data)
self.assertEqual(result.status_code, 200)
self.assertIn('updated_mgmtdomain', result.json['data'])
# Make sure returned data inclueds new vlan
self.assertEqual(result.json['data']['updated_mgmtdomain']['vlan'], data['vlan'])
# Change back to old vlan
data["vlan"] = 600
result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data)
self.assertEqual(result.status_code, 200)
self.assertIn('updated_mgmtdomain', result.json['data'])
# Check that no change is made when applying same vlan twice
data["vlan"] = 600
result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data)
self.assertEqual(result.status_code, 200)
self.assertIn('unchanged_mgmtdomain', result.json['data'])

def test_validate_managementdomain(self):
result = self.client.get('/api/v1.0/mgmtdomains?per_page=1')
self.assertIsInstance(result.json['data']['mgmtdomains'][0]['id'], int)
id = result.json['data']['mgmtdomains'][0]['id']
# Check that you get error if using invalid gw
data = {"ipv4_gw": "10.0.6.0/24"}
result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data)
self.assertEqual(result.status_code, 400)
# Check that you get error if using invalid vlan id
data = {"vlan": 5000}
result = self.client.put('/api/v1.0/mgmtdomain/{}'.format(id), json=data)
self.assertEqual(result.status_code, 400)

def test_repository_refresh(self):
data = {"action": "refresh"}
result = self.client.put('/api/v1.0/repository/settings', json=data)
Expand Down
1 change: 1 addition & 0 deletions src/cnaas_nms/db/settings_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
description="Max 32 alphanumeric chars, " +
"beginning with a non-numeric character")
vlan_id_schema = Field(..., gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095")
vlan_id_schema_optional = Field(None, gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095")
vxlan_vni_schema = Field(..., gt=0, lt=16777215, description="VXLAN Network Identifier")
vrf_id_schema = Field(..., gt=0, lt=65536, description="VRF identifier, integer between 1-65535")
mtu_schema = Field(None, ge=68, le=9214,
Expand Down

0 comments on commit ee5947e

Please sign in to comment.