Skip to content

Commit

Permalink
refactor: improve error handling and code readability in API views
Browse files Browse the repository at this point in the history
Overview
- Enhanced error handling and logging across multiple API views.
- Refactored code for better readability and maintainability.
- Added new test cases to cover edge scenarios and improve test coverage.
- Updated Docker configuration for development environment.

Details
- OllamaManager: Simplified error handling and added comments for clarity.
- GPTAttackSuggestion: Improved error handling and refactored code for better readability.
- FetchMostCommonVulnerability: Refactored query logic and improved error handling.
- AddTarget: Added validation for domain names and improved error messages.
- DeleteVulnerability: Added validation for input data and improved error handling.
- ListTechnology: Refactored query logic for better readability.
- get_ips_from_cidr_range: Improved error handling and logging.
- Test cases: Added new test cases for various scenarios, including failure cases and edge cases.
- Docker: Updated docker-compose.dev.yml to enable remote debugging and added a new port.
- Miscellaneous: Various minor improvements and bug fixes across different files.
  • Loading branch information
psyray committed Sep 11, 2024
1 parent 5e47026 commit 1f21887
Show file tree
Hide file tree
Showing 23 changed files with 329 additions and 248 deletions.
22 changes: 22 additions & 0 deletions web/api/tests/test_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.urls import reverse
from rest_framework import status
from utils.test_base import BaseTestCase
import socket

__all__ = [
'TestIpAddressViewSet',
Expand Down Expand Up @@ -65,6 +66,27 @@ def test_ip_to_domain(self, mock_gethostbyaddr):
response.data["ip_address"][0]["domain"], self.data_generator.domain.name
)

@patch("api.views.socket.gethostbyaddr")
def test_ip_to_domain_failure(self, mock_gethostbyaddr):
"""Test IP to domain resolution when it fails."""
mock_gethostbyaddr.side_effect = socket.herror
url = reverse("api:ip_to_domain")
response = self.client.get(url, {"ip_address": "192.0.2.1"})
self.assertEqual(response.status_code, 200)
self.assertTrue(response.data["status"])
self.assertEqual(response.data["ip_address"][0]["domain"], "192.0.2.1")

@patch("api.views.socket.gethostbyaddr")
def test_ip_to_domain_multiple(self, mock_gethostbyaddr):
"""Test IP to domain resolution with multiple domains."""
mock_domains = ["example.com", "example.org"]
mock_gethostbyaddr.return_value = (mock_domains[0], mock_domains, ["192.0.2.1"])
url = reverse("api:ip_to_domain")
response = self.client.get(url, {"ip_address": "192.0.2.1"})
self.assertEqual(response.status_code, 200)
self.assertIn("domains", response.data["ip_address"][0])
self.assertEqual(response.data["ip_address"][0]["domains"], mock_domains)

class TestDomainIPHistory(BaseTestCase):
"""Test case for domain IP history lookup."""

Expand Down
9 changes: 9 additions & 0 deletions web/api/tests/test_organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.urls import reverse
from rest_framework import status
from utils.test_base import BaseTestCase
from targetApp.models import Organization

__all__ = [
'TestListOrganizations',
Expand All @@ -20,6 +21,14 @@ def setUp(self):
super().setUp()
self.data_generator.create_project_full()

def test_list_empty_organizations(self):
"""Test listing organizations when the database is empty."""
Organization.objects.all().delete()
url = reverse("api:listOrganizations")
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()['organizations']), 0)

def test_list_organizations(self):
"""Test listing all organizations."""
url = reverse("api:listOrganizations")
Expand Down
5 changes: 5 additions & 0 deletions web/api/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def test_scan_status(self):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("scans", response.data)
self.assertIn("tasks", response.data)
self.assertIsInstance(response.data["scans"], dict)
self.assertIsInstance(response.data["tasks"], dict)
if response.data["scans"]:
self.assertIn("id", response.data["scans"]["completed"][0])
self.assertIn("scan_status", response.data["scans"]["completed"][0])

class TestListScanHistory(BaseTestCase):
"""Test case for listing scan history."""
Expand Down
9 changes: 9 additions & 0 deletions web/api/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ def test_universal_search_no_query(self):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(response.data["status"])
self.assertEqual(response.data["message"], "No query parameter provided!")

def test_universal_search_with_special_characters(self):
"""Test the universal search functionality with special characters."""
api_url = reverse("api:search")
special_query = "admin'; DROP TABLE users;--"
response = self.client.get(api_url, {"query": special_query})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(response.data["status"])
self.assertNotIn("users", response.data["results"])
7 changes: 7 additions & 0 deletions web/api/tests/test_subdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def test_delete_subdomain(self):
Subdomain.objects.filter(id=self.data_generator.subdomain.id).exists()
)

def test_delete_nonexistent_subdomain(self):
"""Test deleting a non-existent subdomain."""
api_url = reverse("api:delete_subdomain")
data = {"subdomain_ids": ["nonexistent_id"]}
response = self.client.post(api_url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

class TestListSubdomains(BaseTestCase):
"""Test case for listing subdomains."""

Expand Down
6 changes: 6 additions & 0 deletions web/api/tests/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def setUp(self):
"""Set up test environment."""
super().setUp()
self.data_generator.create_project_base()
Domain.objects.all().delete()

def test_add_target(self):
"""Test adding a new target."""
Expand All @@ -38,6 +39,11 @@ def test_add_target(self):
Domain.objects.filter(name=self.data_generator.domain.name).exists()
)

# Test adding duplicate target
response = self.client.post(api_url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertFalse(response.data["status"])

class TestListTargetsDatatableViewSet(BaseTestCase):
"""Tests for the List Targets Datatable API."""

Expand Down
12 changes: 11 additions & 1 deletion web/api/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,20 @@ def test_get_download_model(self, mock_post):
"""Test downloading an Ollama model."""
mock_post.return_value.json.return_value = {"status": "success"}
api_url = reverse("api:ollama_manager")
response = self.client.get(api_url, data={"model": "gpt-4"})
response = self.client.get(api_url, data={"model": "llama2"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(response.data["status"])

@patch("requests.post")
def test_get_download_model_failure(self, mock_post):
"""Test failed downloading of an Ollama model."""
mock_post.return_value.json.return_value = {"error": "pull model manifest: file does not exist"}
api_url = reverse("api:ollama_manager")
response = self.client.get(api_url, data={"model": "invalid-model"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["error"], "pull model manifest: file does not exist")
self.assertFalse(response.data["status"])

@patch("requests.delete")
def test_delete_model(self, mock_delete):
"""Test deleting an Ollama model."""
Expand Down
5 changes: 5 additions & 0 deletions web/api/tests/test_vulnerability.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def test_list_vulnerabilities(self):
api_url = reverse("api:vulnerabilities-list")
response = self.client.get(api_url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('count', response.data)
self.assertIn('next', response.data)
self.assertIn('previous', response.data)
self.assertIn('results', response.data)
self.assertIsInstance(response.data['results'][0], dict)
self.assertGreaterEqual(len(response.data["results"]), 1)
self.assertEqual(
response.data["results"][0]["name"],
Expand Down
Loading

0 comments on commit 1f21887

Please sign in to comment.