Skip to content

Commit

Permalink
Add user avatar uploads (#687)
Browse files Browse the repository at this point in the history
* Add user avatar uploads

* Update empire/test/test_user_api.py
  • Loading branch information
vinnybod authored Sep 16, 2023
1 parent c98d30c commit 3bb06b8
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Add avatars to users (@Vinnybod)
- Update plugin documentation, update embedded plugins to not abuse notifications (@Vinnybod)
- Add additional pre-commit hooks for code cleanup (@Vinnybod)
- Report test coverage on pull requests (@Vinnybod)
Expand Down
20 changes: 19 additions & 1 deletion empire/server/api/v2/user/user_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import timedelta

from fastapi import Depends, HTTPException
from fastapi import Depends, File, HTTPException, UploadFile
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from starlette import status
Expand Down Expand Up @@ -160,3 +160,21 @@ async def update_user_password(
raise HTTPException(status_code=400, detail=err)

return domain_to_dto_user(resp)


@router.post("/api/v2/users/{uid}/avatar", status_code=201)
async def create_avatar(
uid: int,
db: Session = Depends(get_db),
user: models.User = Depends(get_current_active_user),
file: UploadFile = File(...),
):
if not user.id == uid:
raise HTTPException(
status_code=403, detail="User does not have access to update this resource."
)

if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image.")

user_service.update_user_avatar(db, user, file)
13 changes: 12 additions & 1 deletion empire/server/api/v2/user/user_dto.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from datetime import datetime
from typing import List
from typing import List, Optional

from pydantic import BaseModel

from empire.server.api.v2.shared_dto import (
DownloadDescription,
domain_to_dto_download_description,
)


def domain_to_dto_user(user):
if user.avatar:
download_description = domain_to_dto_download_description(user.avatar)
else:
download_description = None
return User(
id=user.id,
username=user.username,
enabled=user.enabled,
is_admin=user.admin,
created_at=user.created_at,
updated_at=user.updated_at,
avatar=download_description,
)


Expand All @@ -20,6 +30,7 @@ class User(BaseModel):
username: str
enabled: bool
is_admin: bool
avatar: Optional[DownloadDescription]
created_at: datetime
updated_at: datetime

Expand Down
2 changes: 1 addition & 1 deletion empire/server/common/empire.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def __init__(self, args=None):

self.listenertemplatesv2 = ListenerTemplateService(self)
self.stagertemplatesv2 = StagerTemplateService(self)
self.usersv2 = UserService(self)
self.bypassesv2 = BypassService(self)
self.obfuscationv2 = ObfuscationService(self)
self.profilesv2 = ProfileService(self)
self.credentialsv2 = CredentialService(self)
self.hostsv2 = HostService(self)
self.processesv2 = HostProcessService(self)
self.downloadsv2 = DownloadService(self)
self.usersv2 = UserService(self)
self.listenersv2 = ListenerService(self)
self.stagersv2 = StagerService(self)
self.modulesv2 = ModuleService(self)
Expand Down
2 changes: 2 additions & 0 deletions empire/server/core/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class User(Base):
updated_at = Column(
UtcDateTime, default=utcnow(), onupdate=utcnow(), nullable=False
)
avatar = relationship("Download")
avatar_id = Column(Integer, ForeignKey("downloads.id"), nullable=True)

def __repr__(self):
return "<User(username='%s')>" % (self.username)
Expand Down
8 changes: 8 additions & 0 deletions empire/server/core/user_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from fastapi import UploadFile
from sqlalchemy.orm import Session

from empire.server.core.db import models
from empire.server.core.download_service import DownloadService


class UserService(object):
def __init__(self, main_menu):
self.main_menu = main_menu
self.download_service: DownloadService = main_menu.downloadsv2

@staticmethod
def get_all(db: Session):
Expand Down Expand Up @@ -57,3 +60,8 @@ def update_user_password(db: Session, db_user: models.User, hashed_password: str
db.flush()

return db_user, None

def update_user_avatar(self, db: Session, db_user: models.User, file: UploadFile):
download = self.download_service.create_download(db, db_user, file)

db_user.avatar = download
Binary file added empire/test/avatar.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added empire/test/avatar2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 8 additions & 3 deletions empire/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from fastapi import FastAPI
from starlette.testclient import TestClient

from empire.client.src.utils.data_util import get_random_string

SERVER_CONFIG_LOC = "empire/test/test_server_config.yaml"
CLIENT_CONFIG_LOC = "empire/test/test_client_config.yaml"
DEFAULT_ARGV = ["", "server", "--config", SERVER_CONFIG_LOC]
Expand Down Expand Up @@ -431,15 +433,18 @@ def credential(client, admin_auth_header):
json={
"credtype": "hash",
"domain": "the-domain",
"username": "user",
"password": "hunter2",
"username": get_random_string(8),
"password": get_random_string(8),
"host": "host1",
},
)

yield resp.json()["id"]

client.delete(f"/api/v2/credentials/{resp.json()['id']}", headers=admin_auth_header)
with suppress(Exception):
client.delete(
f"/api/v2/credentials/{resp.json()['id']}", headers=admin_auth_header
)


@pytest.fixture(scope="function")
Expand Down
49 changes: 34 additions & 15 deletions empire/test/test_credential_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_create_credential(client, admin_auth_header, base_credential):
)

assert response.status_code == 201
assert response.json()["id"] == 1
assert response.json()["id"] > 0
assert response.json()["credtype"] == "hash"
assert response.json()["domain"] == "the-domain"
assert response.json()["username"] == "user"
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_update_credential_not_found(client, admin_auth_header, base_credential)


def test_update_credential_unique_constraint_failure(
client, admin_auth_header, base_credential
client, admin_auth_header, base_credential, credential
):
credential_2 = copy.deepcopy(base_credential)
credential_2["domain"] = "the-domain-2"
Expand All @@ -59,19 +59,26 @@ def test_update_credential_unique_constraint_failure(
assert response.status_code == 201

response = client.put(
"/api/v2/credentials/2", headers=admin_auth_header, json=base_credential
f"/api/v2/credentials/{credential}",
headers=admin_auth_header,
json=base_credential,
)

assert response.status_code == 400
assert response.json()["detail"] == "Credential not updated. Duplicate detected."


def test_update_credential(client, admin_auth_header, base_credential):
updated_credential = base_credential
def test_update_credential(client, admin_auth_header, credential):
response = client.get(
f"/api/v2/credentials/{credential}", headers=admin_auth_header
)
updated_credential = response.json()
updated_credential["domain"] = "new-domain"
updated_credential["password"] = "password3"
response = client.put(
"/api/v2/credentials/1", headers=admin_auth_header, json=updated_credential
f"/api/v2/credentials/{updated_credential['id']}",
headers=admin_auth_header,
json=updated_credential,
)

assert response.status_code == 200
Expand All @@ -86,11 +93,13 @@ def test_get_credential_not_found(client, admin_auth_header):
assert response.json()["detail"] == "Credential not found for id 9999"


def test_get_credential(client, admin_auth_header):
response = client.get("/api/v2/credentials/1", headers=admin_auth_header)
def test_get_credential(client, admin_auth_header, credential):
response = client.get(
f"/api/v2/credentials/{credential}", headers=admin_auth_header
)

assert response.status_code == 200
assert response.json()["id"] == 1
assert response.json()["id"] > 0


def test_get_credentials(client, admin_auth_header):
Expand All @@ -100,12 +109,18 @@ def test_get_credentials(client, admin_auth_header):
assert len(response.json()["records"]) > 0


def test_get_credentials_search(client, admin_auth_header):
response = client.get("/api/v2/credentials?search=hunt", headers=admin_auth_header)
def test_get_credentials_search(client, admin_auth_header, credential):
response = client.get(
f"/api/v2/credentials/{credential}", headers=admin_auth_header
)
password = response.json()["password"]
response = client.get(
f"/api/v2/credentials?search={password[:3]}", headers=admin_auth_header
)

assert response.status_code == 200
assert len(response.json()["records"]) == 1
assert response.json()["records"][0]["password"] == "hunter2"
assert response.json()["records"][0]["password"] == password

response = client.get(
"/api/v2/credentials?search=qwerty", headers=admin_auth_header
Expand All @@ -115,11 +130,15 @@ def test_get_credentials_search(client, admin_auth_header):
assert len(response.json()["records"]) == 0


def test_delete_credential(client, admin_auth_header):
response = client.delete("/api/v2/credentials/1", headers=admin_auth_header)
def test_delete_credential(client, admin_auth_header, credential):
response = client.delete(
f"/api/v2/credentials/{credential}", headers=admin_auth_header
)

assert response.status_code == 204

response = client.get("/api/v2/credentials/1", headers=admin_auth_header)
response = client.get(
f"/api/v2/credentials/{credential}", headers=admin_auth_header
)

assert response.status_code == 404
4 changes: 3 additions & 1 deletion empire/test/test_download_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def test_download_download(client, admin_auth_header):
)

assert response.status_code == 200
assert response.headers.get("content-disposition").startswith(
assert response.headers.get("content-disposition").lower().startswith(
'attachment; filename="test-upload-2'
) or response.headers.get("content-disposition").lower().startswith(
"attachment; filename*=utf-8''test-upload-2"
)


Expand Down
26 changes: 15 additions & 11 deletions empire/test/test_plugin_task_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@


@pytest.fixture(scope="module", autouse=True)
def plugin_task_1(main, db, models, plugin_name):
db.add(
models.PluginTask(
def plugin_task_1(main, session_local, models, plugin_name):
with session_local.begin() as db:
task = models.PluginTask(
plugin_id=plugin_name,
input="This is the trimmed input for the task.",
input_full="This is the full input for the task.",
user_id=1,
)
)
db.commit()
yield
db.add(task)
db.flush()

task_id = task.id

yield task_id

db.query(models.PluginTask).delete()
db.commit()
with session_local.begin() as db:
db.query(models.PluginTask).delete()


def test_get_tasks_for_plugin_not_found(client, admin_auth_header):
Expand Down Expand Up @@ -60,10 +63,11 @@ def test_get_task_for_plugin_not_found(client, admin_auth_header, plugin_name):
)


def test_get_task_for_plugin(client, admin_auth_header, plugin_name, db):
def test_get_task_for_plugin(client, admin_auth_header, plugin_name, db, plugin_task_1):
response = client.get(
f"/api/v2/plugins/{plugin_name}/tasks/1", headers=admin_auth_header
f"/api/v2/plugins/{plugin_name}/tasks/{plugin_task_1}",
headers=admin_auth_header,
)
assert response.status_code == 200
assert response.json()["id"] == 1
assert response.json()["id"] == plugin_task_1
assert response.json()["plugin_id"] == plugin_name
Loading

0 comments on commit 3bb06b8

Please sign in to comment.