diff --git a/job/models.py b/job/models.py index e85ca0cd4..97d23760c 100644 --- a/job/models.py +++ b/job/models.py @@ -78,6 +78,7 @@ class JobOperation(enum.IntEnum): CREATE_ENTRY_V2 = 27 EDIT_ENTRY_V2 = 28 DELETE_ENTRY_V2 = 29 + IMPORT_ROLE_V2 = 30 @enum.unique @@ -392,6 +393,7 @@ def method_table(kls): JobOperation.CREATE_ENTRY_V2: entry_task.create_entry_v2, JobOperation.EDIT_ENTRY_V2: entry_task.edit_entry_v2, JobOperation.DELETE_ENTRY_V2: entry_task.delete_entry_v2, + JobOperation.IMPORT_ROLE_V2: role_task.import_role_v2, } for operation_num, task in CUSTOM_TASKS.items(): custom_task = kls.get_task_module("custom_view.tasks") @@ -667,3 +669,13 @@ def _get_job_timeout(kls) -> int: return settings.AIRONE["JOB_TIMEOUT"] else: return kls.DEFAULT_JOB_TIMEOUT + + @classmethod + def new_role_import_v2(kls, user: User, text="", params={}) -> "Job": + return kls._create_new_job( + user, + None, + JobOperation.IMPORT_ROLE_V2, + text, + json.dumps(params, default=_support_time_default, sort_keys=True), + ) diff --git a/role/api_v2/views.py b/role/api_v2/views.py index 9d66ccf96..e197d3218 100644 --- a/role/api_v2/views.py +++ b/role/api_v2/views.py @@ -1,11 +1,9 @@ from rest_framework import generics, serializers, status, viewsets from rest_framework.permissions import BasePermission, IsAuthenticated -from rest_framework.request import Request from rest_framework.response import Response -from acl.models import ACLBase from airone.lib.drf import YAMLParser, YAMLRenderer -from group.models import Group +from job.models import Job from role.api_v2.serializers import ( RoleCreateUpdateSerializer, RoleImportExportChildSerializer, @@ -43,91 +41,25 @@ def get_serializer_class(self): class RoleImportAPI(generics.GenericAPIView): parser_classes = [YAMLParser] - permission_classes = [IsAuthenticated] serializer_class = serializers.Serializer - def post(self, request: Request) -> Response: + def post(self, request): import_datas = request.data + user: User = request.user serializer = RoleImportSerializer(data=import_datas) serializer.is_valid(raise_exception=True) - # TODO better to move the saving logic into the serializer - for role_data in import_datas: - if "name" not in role_data: - return Response("Role name is required", status=status.HTTP_400_BAD_REQUEST) - - if "id" in role_data: - # update group by id - role = Role.objects.filter(id=role_data["id"]).first() - if not role: - return Response( - "Specified id role does not exist(id:%s, group:%s)" - % (role_data["id"], role_data["name"]), - status=status.HTTP_400_BAD_REQUEST, - ) - - # check new name is not used - if (role.name != role_data["name"]) and ( - Role.objects.filter(name=role_data["name"]).count() > 0 - ): - return Response( - "New role name is already used(id:%s, group:%s->%s)" - % (role_data["id"], role.name, role_data["name"]), - status=status.HTTP_400_BAD_REQUEST, - ) - - role.name = role_data["name"] - else: - # update group by name - role = Role.objects.filter(name=role_data["name"]).first() - if not role: - # create group - role = Role.objects.create(name=role_data["name"]) - else: - # clear registered members (users, groups and administrative ones) to that role - for key in ["users", "groups", "admin_users", "admin_groups"]: - getattr(role, key).clear() - - role.description = role_data["description"] - - # set registered members (users, groups and administrative ones) to that role - for key in ["users", "admin_users"]: - for name in role_data[key]: - user: User | None = User.objects.filter(username=name, is_active=True).first() - if not user: - return Response( - "specified user is not found (username: %s)" % name, - status=status.HTTP_400_BAD_REQUEST, - ) - getattr(role, key).add(user) - for key in ["groups", "admin_groups"]: - for name in role_data[key]: - group: Group | None = Group.objects.filter(name=name, is_active=True).first() - if not group: - return Response( - "specified group is not found (name: %s)" % name, - status=status.HTTP_400_BAD_REQUEST, - ) - getattr(role, key).add(group) - - for permission in role_data.get("permissions", []): - acl: ACLBase | None = ACLBase.objects.filter(id=permission["obj_id"]).first() - if not acl: - return Response( - "Invalid obj_id given: %s" % str(permission["obj_id"]), - status=status.HTTP_400_BAD_REQUEST, - ) - - if permission["permission"] == "readable": - acl.readable.roles.add(role) - elif permission["permission"] == "writable": - acl.writable.roles.add(role) - elif permission["permission"] == "full": - acl.full.roles.add(role) - - role.save() - - return Response() + job_ids = [] + error_list = [] + + job = Job.new_role_import_v2( + user, text="Preparing to import role data", params=import_datas + ) + job.run() + job_ids.append(job.id) + return Response( + {"result": {"job_ids": job_ids, "error": error_list}}, status=status.HTTP_200_OK + ) class RoleExportAPI(generics.ListAPIView): diff --git a/role/tasks.py b/role/tasks.py index 62ad9c7c2..2f51976e0 100644 --- a/role/tasks.py +++ b/role/tasks.py @@ -1,9 +1,13 @@ import json +from acl.models import ACLBase from airone.celery import app from airone.lib.job import may_schedule_until_job_is_ready +from airone.lib.log import Logger +from group.models import Group from job.models import Job, JobStatus from role.models import Role +from user.models import User @app.task(bind=True) @@ -16,3 +20,103 @@ def edit_role_referrals(self, job: Job) -> JobStatus: entry.register_es() return JobStatus.DONE + + +@app.task(bind=True) +@may_schedule_until_job_is_ready +def import_role_v2(self, job: Job) -> tuple[JobStatus, str, None] | None: + import_data = json.loads(job.params) + err_msg = [] + total_count = len(import_data) + + for index, role_data in enumerate(import_data): + job.text = "Now importing roles... (progress: [%5d/%5d])" % (index + 1, total_count) + job.save(update_fields=["text"]) + + # Interrupt processing if the job is canceled + if job.is_canceled(): + job.status = JobStatus.CANCELED + job.save(update_fields=["status"]) + return None + + # Skip processing if the role name is not provided + if "name" not in role_data: + err_msg.append("Role name is required") + continue + + # Retrieve or create roles + if "id" in role_data: + role = Role.objects.filter(id=role_data["id"]).first() + if not role: + err_msg.append(f"Role with ID {role_data['id']} does not exist.") + continue + + if (role.name != role_data["name"]) and ( + Role.objects.filter(name=role_data["name"]).count() > 0 + ): + err_msg.append( + "New role name is already used(id:%s, group:%s->%s)" + % (role_data["id"], role.name, role_data["name"]) + ) + continue + + role.name = role_data["name"] + else: + # Update the group by name + role = Role.objects.filter(name=role_data["name"]).first() + if not role: + # create group + role = Role.objects.create(name=role_data["name"]) + else: + # Clear registered members (users, groups, and administrative ones) for that role + for key in ["users", "groups", "admin_users", "admin_groups"]: + getattr(role, key).clear() + + # Update role information + role.description = role_data.get("description", "") + + # Configure associated users and groups + for key in ["users", "admin_users"]: + for name in role_data[key]: + instance = User.objects.filter(username=name, is_active=True).first() + if not instance: + err_msg.append("specified user is not found (username: %s)" % name) + continue + getattr(role, key).add(instance) + + for key in ["groups", "admin_groups"]: + for name in role_data[key]: + instance = Group.objects.filter(name=name, is_active=True).first() + if not instance: + err_msg.append("specified group is not found (name: %s)" % name) + continue + getattr(role, key).add(instance) + + # Configure ACL + for permission in role_data.get("permissions", []): + acl = ACLBase.objects.filter(id=permission["obj_id"]).first() + if not acl: + raise ValueError(f"Invalid obj_id: {permission['obj_id']}") + if permission["permission"] == "readable": + acl.readable.roles.add(role) + elif permission["permission"] == "writable": + acl.writable.roles.add(role) + elif permission["permission"] == "full": + acl.full.roles.add(role) + + try: + role.save() + except Exception as e: + err_msg.append(role_data["name"]) + Logger.warning("failed to save role: name=%s, error=%s" % (role_data["name"], str(e))) + + # Update the job based on the result of the process + if err_msg: + return ( + JobStatus.WARNING, + "Imported Role count: %d, Failed import Roles: %s" + % (total_count - len(err_msg), err_msg), + None, + ) + else: + return JobStatus.DONE, "Imported Role count: %d" % total_count, None diff --git a/role/tests/test_api_v2.py b/role/tests/test_api_v2.py index 1e2180a59..91d575115 100644 --- a/role/tests/test_api_v2.py +++ b/role/tests/test_api_v2.py @@ -1,9 +1,12 @@ import json +from unittest.mock import Mock, patch import yaml from airone.lib.test import AironeViewTest from group.models import Group +from job.models import Job, JobOperation, JobStatus +from role import tasks from role.models import Role from user.models import User @@ -269,6 +272,7 @@ def test_delete_without_permission(self): resp = self.client.delete(f"/role/api/v2/{role.id}") self.assertEqual(resp.status_code, 403) + @patch("role.tasks.import_role_v2.delay", Mock(side_effect=tasks.import_role_v2)) def test_import(self): self.admin_login() @@ -278,6 +282,7 @@ def test_import(self): self.assertEqual(resp.status_code, 200) self.assertEqual(Role.objects.filter(name="role1").count(), 1) + @patch("role.tasks.import_role_v2.delay", Mock(side_effect=tasks.import_role_v2)) def test_import_for_update(self): self.admin_login() @@ -290,6 +295,8 @@ def test_import_for_update(self): import_data = fp.read().replace("", str(role.id)) resp = self.client.post("/role/api/v2/import", import_data, content_type="application/yaml") self.assertEqual(resp.status_code, 200) + job = Job.objects.get(operation=JobOperation.IMPORT_ROLE_V2) + self.assertEqual(job.status, JobStatus.DONE) # This confirms role instance is updated as expected values role = Role.objects.filter(name="role1").first() @@ -300,6 +307,7 @@ def test_import_for_update(self): self.assertEqual([x.name for x in role.groups.all()], ["group2"]) self.assertEqual([x.name for x in role.admin_groups.all()], ["group1"]) + @patch("role.tasks.import_role_v2.delay", Mock(side_effect=tasks.import_role_v2)) def test_import_with_permissions(self): admin = self.admin_login() @@ -308,6 +316,8 @@ def test_import_with_permissions(self): fp = self.open_fixture_file("import_roles_with_permissions.yaml") import_data = fp.read().replace("", str(entity.id)) resp = self.client.post("/role/api/v2/import", import_data, content_type="application/yaml") + job = Job.objects.get(operation=JobOperation.IMPORT_ROLE_V2) + self.assertEqual(job.status, JobStatus.DONE) self.assertEqual(resp.status_code, 200) role = Role.objects.filter(name="role1").first()