Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check last_pruned instead of is_pruning #2748

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backend/danswer/server/documents/cc_pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from datetime import datetime
from http import HTTPStatus

from fastapi import APIRouter
Expand Down Expand Up @@ -201,12 +202,12 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")


@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> bool:
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
Expand All @@ -216,11 +217,10 @@ def get_cc_pair_latest_prune(
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
detail="cc_pair not found for current user's permissions",
)

rcp = RedisConnectorPruning(cc_pair.id)
return rcp.is_pruning(db_session, get_redis_client())
return cc_pair.last_pruned


@router.post("/admin/cc-pair/{cc_pair_id}/prune")
Expand Down
23 changes: 16 additions & 7 deletions backend/tests/integration/common_utils/managers/cc_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,31 +274,40 @@ def prune(
result.raise_for_status()

@staticmethod
def is_pruning(
def last_pruned(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> bool:
) -> datetime | None:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
response_bool = response.json()
return response_bool
response_str = response.json()

# If the response itself is a datetime string, parse it
if not isinstance(response_str, str):
return None

try:
return datetime.fromisoformat(response_str)
except ValueError:
return None

@staticmethod
def wait_for_prune(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
result = CCPairManager.is_pruning(cc_pair, user_performing_action)
if not result:
last_pruned = CCPairManager.last_pruned(cc_pair, user_performing_action)
if last_pruned and last_pruned > after:
break

elapsed = time.monotonic() - start
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ def test_slack_prune(
)

# Prune the cc_pair
now = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, now, user_performing_action=admin_user)

# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure admin user can't see deleted messages
Expand Down
54 changes: 50 additions & 4 deletions backend/tests/integration/tests/pruning/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from time import sleep
from typing import Any

import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from danswer.server.documents.models import DocumentSource
from danswer.utils.logger import setup_logger
from tests.integration.common_utils.managers.api_key import APIKeyManager
Expand All @@ -21,10 +25,50 @@
logger = setup_logger()


# FastAPI application used to serve a directory tree as static files
def create_fastapi_app(directory: str) -> FastAPI:
    """Build a FastAPI app that serves ``directory`` as static files.

    The directory is mounted at the root path (``/``) under the mount name
    ``static``. ``StaticFiles`` is created with ``html=True``, i.e. in its
    HTML mode (per the Starlette docs this serves ``index.html`` for
    directory requests when such a file exists).

    Args:
        directory: Path of the directory whose contents should be served.

    Returns:
        The configured ``FastAPI`` application.
    """
    static_files = StaticFiles(directory=directory, html=True)

    server_app = FastAPI()
    server_app.mount("/", static_files, name="static")
    return server_app


# As far as we know, this server doesn't hang when crawled. This is good.
@contextmanager
def fastapi_server_context(
    directory: str, port: int = 8000
) -> Generator[None, None, None]:
    """Serve ``directory`` over HTTP from a background uvicorn server.

    The app built by ``create_fastapi_app(directory)`` is run by a
    ``uvicorn.Server`` on ``0.0.0.0:<port>`` in a daemon thread. The server
    stays up for the duration of the ``with`` block; on exit it is asked to
    stop and the thread is joined.

    Args:
        directory: Directory whose contents are served as static files.
        port: TCP port to listen on.

    Yields:
        None. Control is handed to the caller once the server reports that
        it has started, or after roughly five seconds at the latest.
    """
    app = create_fastapi_app(directory)

    config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level="info")
    server = uvicorn.Server(config)

    # Daemon thread: guarantees the thread cannot keep the process alive
    # when the main program exits.
    server_thread = threading.Thread(target=server.run, daemon=True)

    # Start outside the try block: if start() itself fails there is nothing
    # to shut down, and joining a never-started thread raises RuntimeError,
    # which would mask the real error.
    server_thread.start()

    try:
        # uvicorn sets `Server.started` once startup has completed. Poll it
        # instead of sleeping unconditionally; the wait is capped at ~5s
        # (50 * 0.1s). Stop waiting early if the server thread has already
        # died (e.g. it could not bind the port).
        for _ in range(50):
            if server.started or not server_thread.is_alive():
                break
            sleep(0.1)

        yield  # Hand control back to the body of the `with` statement
    finally:
        # Ask uvicorn to exit its serve loop, then wait for the thread.
        server.should_exit = True
        server_thread.join()


# Leaving this here for posterity and experimentation. The reason we're
# not using it is that Python's built-in web servers hang frequently when
# crawled, which is obviously not good for a unit test.
@contextmanager
def http_server_context(
directory: str, port: int = 8000
) -> Generator[http.server.HTTPServer, None, None]:
) -> Generator[http.server.ThreadingHTTPServer, None, None]:
# Create a handler that serves files from the specified directory
def handler_class(
*args: Any, **kwargs: Any
Expand All @@ -34,7 +78,7 @@ def handler_class(
)

# Create an HTTPServer instance
httpd = http.server.HTTPServer(("0.0.0.0", port), handler_class)
httpd = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler_class)

# Define a thread that runs the server in the background
server_thread = threading.Thread(target=httpd.serve_forever)
Expand All @@ -45,6 +89,7 @@ def handler_class(
try:
# Start the server in the background
server_thread.start()
sleep(5) # give it a few seconds to start
yield httpd
finally:
# Shutdown the server and wait for the thread to finish
Expand All @@ -70,7 +115,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
website_src = os.path.join(test_directory, "website")
website_tgt = os.path.join(temp_dir, "website")
shutil.copytree(website_src, website_tgt)
with http_server_context(os.path.join(temp_dir, "website"), port):
with fastapi_server_context(os.path.join(temp_dir, "website"), port):
sleep(1) # sleep a tiny bit before starting everything

hostname = os.getenv("TEST_WEB_HOSTNAME", "localhost")
Expand Down Expand Up @@ -105,9 +150,10 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
logger.info("Removing courses.html.")
os.remove(os.path.join(website_tgt, "courses.html"))

now = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair_1, user_performing_action=admin_user)
CCPairManager.wait_for_prune(
cc_pair_1, timeout=60, user_performing_action=admin_user
cc_pair_1, now, timeout=60, user_performing_action=admin_user
)

selected_cc_pair = CCPairManager.get_one(
Expand Down