Skip to content

Commit

Permalink
Add full_info flag in /userdata endpoint to list out file size and la…
Browse files Browse the repository at this point in the history
…st modified timestamp (comfyanonymous#4905)

* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
  • Loading branch information
huchenlei authored Sep 13, 2024
1 parent f6b7194 commit cb12ad7
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 8 deletions.
33 changes: 25 additions & 8 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,38 @@ async def listuserdata(request):
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)

path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)

if not os.path.exists(path):
return web.Response(status=404)

recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]

full_info = request.rel_url.query.get('full_info', '').lower() == "true"

# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')

results = glob.glob(pattern, recursive=recurse)

if full_info:
results = [
{
'path': os.path.relpath(x, path),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]

split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
if split_path and not full_info:
results = [[x] + x.split(os.sep) for x in results]

return web.json_response(results)
Expand Down
90 changes: 90 additions & 0 deletions tests-unit/prompt_server_test/user_manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager

pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um


@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app


async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404


async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")

client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]


async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")

client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", os.path.join("subdir", "file2.txt")}


async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")

client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]


async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")

client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
[os.path.join("subdir", "file1.txt"), "subdir", "file1.txt"]
]


async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400

0 comments on commit cb12ad7

Please sign in to comment.