Skip to content

Commit

Permalink
Cleaning Artemis API (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazet committed Mar 26, 2024
1 parent f1f18b7 commit 08ee90f
Show file tree
Hide file tree
Showing 31 changed files with 381 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 40
steps:
- name: Check out repository
uses: actions/checkout@v2
Expand Down
92 changes: 56 additions & 36 deletions artemis/api.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,55 @@
from typing import Any, Dict, List, Optional
from typing import Annotated, Any, Dict, List, Optional

from fastapi import APIRouter, HTTPException, Query, Request
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request
from karton.core.backend import KartonBackend
from karton.core.config import Config as KartonConfig
from karton.core.inspect import KartonState

from artemis.config import Config
from artemis.db import DB, ColumnOrdering, TaskFilter
from artemis.modules.classifier import Classifier
from artemis.producer import create_tasks
from artemis.templating import render_analyses_table_row, render_task_table_row

router = APIRouter()
db = DB()


@router.get("/task/{task_id}")
def get_task(task_id: str) -> Dict[str, Any]:
if result := db.get_task_by_id(task_id):
return result
raise HTTPException(status_code=404, detail="Task not found")
def verify_api_token(x_api_token: Annotated[str, Header()]) -> None:
    """FastAPI dependency: reject the request unless the X-Api-Token header matches the configured token.

    Raises HTTPException(401) when no token is configured at all, or when the
    provided header does not match the configured value.
    """
    expected_token = Config.Miscellaneous.API_TOKEN
    if not expected_token:
        raise HTTPException(
            status_code=401,
            detail="Please fill the API_TOKEN variable in .env in order to use the API",
        )
    if x_api_token != expected_token:
        raise HTTPException(status_code=401, detail="Invalid API token")


@router.post("/add", dependencies=[Depends(verify_api_token)])
def add(
    targets: List[str],
    tag: Annotated[Optional[str], Body()] = None,
    disabled_modules: List[str] = Config.Miscellaneous.MODULES_DISABLED_BY_DEFAULT,
) -> Dict[str, Any]:
    """Add targets to be scanned."""
    # Reject the whole request as soon as one target is not something Artemis can scan.
    for task in targets:
        if Classifier.is_supported(task):
            continue
        return {"error": f"Invalid task: {task}"}

    create_tasks(targets, tag, disabled_modules=disabled_modules)
    return {"ok": True}

@router.get("/analysis")

@router.get("/analyses", dependencies=[Depends(verify_api_token)])
def list_analysis() -> List[Dict[str, Any]]:
    """Returns the list of analysed targets. Any scanned target would be listed here."""
    analyses = db.list_analysis()
    return analyses


@router.get("/num-queued-tasks")
def num_queued_tasks(karton_names: Optional[List[str]] = Query(default=None)) -> int:
@router.get("/num-queued-tasks", dependencies=[Depends(verify_api_token)])
def num_queued_tasks(karton_names: Optional[List[str]] = None) -> int:
"""Return the number of queued tasks for all or only some kartons."""
# We check the backend redis queue length directly to avoid the long runtimes of
# KartonState.get_all_tasks()
backend = KartonBackend(config=KartonConfig())
Expand All @@ -39,14 +63,25 @@ def num_queued_tasks(karton_names: Optional[List[str]] = Query(default=None)) ->
return sum([backend.redis.llen(key) for key in backend.redis.keys("karton.queue.*")])


@router.get("/analysis/{root_id}")
def get_analysis(root_id: str) -> Dict[str, Any]:
if result := db.get_analysis_by_id(root_id):
return result
raise HTTPException(status_code=404, detail="Analysis not found")


@router.get("/analyses-table")
@router.get("/task-results", dependencies=[Depends(verify_api_token)])
def get_task_results(
    only_interesting: bool = False,
    page: int = 1,
    page_size: int = 100,
    analysis_id: Optional[str] = None,
    search: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return one page of task results, ordered oldest-first by creation time."""
    # Pages are 1-based in the API; translate to a 0-based row offset.
    offset = (page - 1) * page_size
    ordering = [ColumnOrdering(column_name="created_at", ascending=True)]
    # Optionally narrow the listing down to "interesting" findings only.
    task_filter = TaskFilter.INTERESTING if only_interesting else None

    paginated = db.get_paginated_task_results(
        start=offset,
        length=page_size,
        ordering=ordering,
        search_query=search,
        analysis_id=analysis_id,
        task_filter=task_filter,
    )
    return paginated.data


@router.get("/analyses-table", include_in_schema=False)
def get_analyses_table(
request: Request,
draw: int = Query(),
Expand All @@ -71,8 +106,7 @@ def get_analyses_table(
{
"id": entry["id"],
"tag": entry["tag"],
"payload": entry["task"]["payload"],
"payload_persistent": entry["task"]["payload_persistent"],
"target": entry["target"],
"num_active_tasks": num_active_tasks,
"stopped": entry.get("stopped", None),
}
Expand All @@ -86,7 +120,7 @@ def get_analyses_table(
}


@router.get("/task-results-table")
@router.get("/task-results-table", include_in_schema=False)
def get_task_results_table(
request: Request,
analysis_id: Optional[str] = Query(default=None),
Expand All @@ -101,34 +135,20 @@ def get_task_results_table(
)
search_query = _get_search_query(request)

fields = [
"created_at",
"target_string",
"headers",
"payload_persistent",
"status",
"status_reason",
"priority",
"uid",
"decision_type",
"operator_comment",
]

if analysis_id:
if not db.get_analysis_by_id(analysis_id):
raise HTTPException(status_code=404, detail="Analysis not found")
result = db.get_paginated_task_results(
start,
length,
ordering,
fields=fields,
search_query=search_query,
analysis_id=analysis_id,
task_filter=task_filter,
)
else:
result = db.get_paginated_task_results(
start, length, ordering, fields=fields, search_query=search_query, task_filter=task_filter
start, length, ordering, search_query=search_query, task_filter=task_filter
)

return {
Expand Down
4 changes: 4 additions & 0 deletions artemis/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class Limits:
] = get_config("REQUESTS_PER_SECOND", default=0, cast=float)

class Miscellaneous:
API_TOKEN: Annotated[str, "The token to authenticate to the API. Provide one to use the API."] = get_config(
"API_TOKEN", default=None
)

BLOCKLIST_FILE: Annotated[
str,
"A file that determines what should not be scanned or reported",
Expand Down
32 changes: 21 additions & 11 deletions artemis/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class Analysis(Base): # type: ignore
target = Column(String, index=True)
tag = Column(String, index=True)
stopped = Column(Boolean, index=True)
task = Column(JSON)

fulltext = Column(
TSVector(),
Expand Down Expand Up @@ -153,7 +152,7 @@ def __init__(self) -> None:

def list_analysis(self) -> List[Dict[str, Any]]:
    """Fetch every Analysis row as a plain dict, with internal DB bookkeeping fields removed."""
    with self.session() as session:
        rows = session.query(Analysis).all()
        return [self._strip_internal_db_info(row.__dict__) for row in rows]

def mark_analysis_as_stopped(self, analysis_id: str) -> None:
with self.session() as session:
Expand All @@ -164,16 +163,12 @@ def mark_analysis_as_stopped(self, analysis_id: str) -> None:

def create_analysis(self, analysis: Task) -> None:
analysis_dict = self.task_to_dict(analysis)
del analysis_dict["status"]
if "status_reason" in analysis_dict:
del analysis_dict["status_reason"]

analysis = Analysis(
id=analysis.uid,
target=analysis_dict["payload"]["data"],
tag=analysis_dict["payload_persistent"].get("tag", None),
stopped=False,
task=analysis_dict,
)
with self.session() as session:
session.add(analysis)
Expand All @@ -194,6 +189,9 @@ def save_task_result(
# Used to allow searching in the names and values of all existing headers
headers_string=" ".join([key + " " + value for key, value in task.headers.items()]),
)

del to_save["task"]["status"] # at the moment of saving it's "started", which will be misleading

if isinstance(data, BaseModel):
to_save["result"] = data.dict()
elif isinstance(data, Exception):
Expand All @@ -215,7 +213,7 @@ def get_analysis_by_id(self, analysis_id: str) -> Optional[Dict[str, Any]]:
item = session.query(Analysis).get(analysis_id)

if item:
return item.__dict__ # type: ignore
return self._strip_internal_db_info(item.__dict__)
else:
return None
except NoResultFound:
Expand Down Expand Up @@ -247,7 +245,10 @@ def get_paginated_analyses(
query = query.filter(Analysis.fulltext.match(self._to_postgresql_query(search_query))) # type: ignore

records_count_filtered: int = query.count()
results_page = [item.__dict__ for item in query.order_by(*ordering_postgresql).slice(start, start + length)]
results_page = [
self._strip_internal_db_info(item.__dict__)
for item in query.order_by(*ordering_postgresql).slice(start, start + length)
]
return PaginatedResults(
records_count_total=records_count_total,
records_count_filtered=records_count_filtered,
Expand All @@ -259,7 +260,6 @@ def get_paginated_task_results(
start: int,
length: int,
ordering: List[ColumnOrdering],
fields: List[str],
*,
search_query: Optional[str] = None,
analysis_id: Optional[str] = None,
Expand Down Expand Up @@ -290,7 +290,10 @@ def get_paginated_task_results(
query = query.filter(getattr(TaskResult, key) == value)

records_count_filtered = query.count()
results_page = [item.__dict__ for item in query.order_by(*ordering_postgresql).slice(start, start + length)]
results_page = [
self._strip_internal_db_info(item.__dict__)
for item in query.order_by(*ordering_postgresql).slice(start, start + length)
]
return PaginatedResults(
records_count_total=records_count_total,
records_count_filtered=records_count_filtered,
Expand All @@ -303,7 +306,7 @@ def get_task_by_id(self, task_id: str) -> Optional[Dict[str, Any]]:
item = session.query(TaskResult).get(task_id)

if item:
return item.__dict__ # type: ignore
return self._strip_internal_db_info(item.__dict__)
else:
return None
except NoResultFound:
Expand Down Expand Up @@ -393,3 +396,10 @@ def _to_postgresql_query(self, query: str) -> str:
query = query.replace("\\", " ") # just in case
query = query.replace('"', " ") # just in case
return " & ".join([f'"{item}"' for item in query.split(" ") if item])

def _strip_internal_db_info(self, d: Dict[str, Any]) -> Dict[str, Any]:
    """Remove internal/bookkeeping keys from a row dict before returning it to API callers.

    Strips the SQLAlchemy instance-state marker and the search-support columns
    (``fulltext``, ``headers_string``) that exist only to power full-text search.
    The dict is mutated in place and also returned for convenient chaining.

    Fix: the original used unconditional ``del`` for ``_sa_instance_state`` and
    ``fulltext`` but guarded ``headers_string`` — ``pop`` with a default makes
    all three tolerant of dicts that lack the key (e.g. already-stripped rows).
    """
    d.pop("_sa_instance_state", None)
    d.pop("fulltext", None)
    d.pop("headers_string", None)
    return d
3 changes: 3 additions & 0 deletions artemis/db_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def _single_migration_iteration() -> None:


def migrate_and_start_thread() -> None:
if not Config.Data.LEGACY_MONGODB_CONN_STR:
return

client = MongoClient(Config.Data.LEGACY_MONGODB_CONN_STR)
client.artemis.task_results.create_index([("migrated", ASCENDING)])
client.artemis.analysis.create_index([("migrated", ASCENDING)])
Expand Down
19 changes: 17 additions & 2 deletions artemis/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Request,
Response,
)
from fastapi.responses import RedirectResponse
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_csrf_protect import CsrfProtect
from karton.core.backend import KartonBackend, KartonBind
from karton.core.config import Config as KartonConfig
Expand Down Expand Up @@ -74,7 +74,22 @@ def get_binds_that_can_be_disabled() -> List[KartonBind]:


def error_content_not_found(request: Request, exc: HTTPException) -> Response:
    """404 handler: JSON for API endpoints, a rendered template for browser pages."""
    wants_json = request.url.path.startswith("/api")
    if wants_json:
        return JSONResponse({"error": 404}, status_code=404)
    return templates.TemplateResponse("not_found.jinja2", {"request": request}, status_code=404)


if not Config.Miscellaneous.API_TOKEN:
    # Without a configured API token, Swagger docs are disabled in main.py —
    # serve an explanatory page at /docs instead.
    @router.get("/docs", include_in_schema=False)
    def api_docs_information(request: Request) -> Response:
        """Tell the user how to enable the API when no token is configured."""
        return templates.TemplateResponse(
            "no_api_token.jinja2",
            {"request": request},
        )


@router.get("/", include_in_schema=False)
Expand Down
9 changes: 8 additions & 1 deletion artemis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@

from artemis import csrf, db_migration
from artemis.api import router as router_api
from artemis.config import Config
from artemis.db import DB
from artemis.frontend import error_content_not_found
from artemis.frontend import router as router_front
from artemis.utils import read_template

app = FastAPI()
app = FastAPI(
docs_url="/docs" if Config.Miscellaneous.API_TOKEN else None,
redoc_url=None,
# This will be displayed as the additional text in Swagger docs
description=read_template("components/generating_reports_hint.jinja2"),
)
app.exception_handler(CsrfProtectError)(csrf.csrf_protect_exception_handler)
app.exception_handler(404)(error_content_not_found)

Expand Down
2 changes: 1 addition & 1 deletion artemis/reporting/export/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path

# This is the output location *inside the container*. The scripts/export_emails
# This is the output location *inside the container*. The scripts/export_reports
# script is responsible for mounting a host path to a path inside the container.
OUTPUT_LOCATION = Path("./output/autoreporter/")
17 changes: 11 additions & 6 deletions artemis/reporting/export/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ class DataLoader:
A wrapper around DB that loads data and converts them to Reports.
"""

def __init__(self, db: DB, blocklist: List[BlocklistItem], language: Language, tag: Optional[str]):
def __init__(
    self, db: DB, blocklist: List[BlocklistItem], language: Language, tag: Optional[str], silent: bool = False
):
    """Store loader configuration; actual data loading is deferred until first use.

    ``silent=True`` suppresses progress output during loading.
    """
    self._db = db
    self._blocklist = blocklist
    self._language = language
    self._tag = tag
    self._silent = silent
    # Lazy-init flag: flipped once the data has actually been loaded.
    self._data_initialized = False

def _initialize_data_if_needed(self) -> None:
"""
Expand All @@ -43,11 +46,13 @@ def _initialize_data_if_needed(self) -> None:
self._scanned_targets = set()
self._tag_stats: DefaultDict[str, int] = defaultdict(lambda: 0)

for result in tqdm(
self._db.get_task_results_since(
datetime.datetime.now() - datetime.timedelta(days=Config.Reporting.REPORTING_MAX_VULN_AGE_DAYS)
)
):
results = self._db.get_task_results_since(
datetime.datetime.now() - datetime.timedelta(days=Config.Reporting.REPORTING_MAX_VULN_AGE_DAYS)
)
if not self._silent:
results = tqdm(results) # type: ignore

for result in results:
result_tag = result["task"].get("payload_persistent", {}).get("tag", None)
self._tag_stats[result_tag] += 1

Expand Down
Loading

0 comments on commit 08ee90f

Please sign in to comment.