Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Billing #2667

Merged
merged 22 commits into from
Oct 12, 2024
Merged

Billing #2667

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf

Expand Down
33 changes: 5 additions & 28 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
Expand Down Expand Up @@ -41,10 +42,8 @@
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
Expand Down Expand Up @@ -129,7 +128,10 @@ def verify_email_is_invited(email: str) -> None:
if not email:
raise PermissionError("Email must be specified")

email_info = validate_email(email) # can raise EmailNotValidError
try:
email_info = validate_email(email)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")

for email_whitelist in whitelist:
try:
Expand Down Expand Up @@ -652,28 +654,3 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []


async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")

auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")

token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
24 changes: 20 additions & 4 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,27 @@
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")


# Cloud configuration

# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"

# Security and authentication
SECRET_JWT_KEY = os.environ.get(
"SECRET_JWT_KEY", ""
) # Used for encryption of the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
) # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
"EXPECTED_API_KEY", ""
) # Additional security check for the control plane API

DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)

ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# JWT configuration
JWT_ALGORITHM = "HS256"
16 changes: 14 additions & 2 deletions backend/danswer/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session

from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
Expand All @@ -33,10 +35,20 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()


def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users


async def get_user_count() -> int:
async with get_async_session_with_tenant() as asession:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/document_index/vespa/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ def ensure_indices_exist(
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = schema.replace(TENANT_ID_PAT, "")

schema = add_ngrams_to_schema(schema) if needs_reindexing else schema

zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")

if self.secondary_index_name:
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/server/auth_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute

from danswer.auth.users import control_plane_dep
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.server.tenants.access import control_plane_dep
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not import ee modules from the MIT stuff. It confuses the separation between the versions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other places this is done but we should address those.



PUBLIC_ENDPOINT_SPECS = [
Expand Down
67 changes: 54 additions & 13 deletions backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from datetime import timezone

import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Body
Expand Down Expand Up @@ -35,6 +37,7 @@
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.auth import get_total_users
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
Expand All @@ -60,6 +63,7 @@
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here and around this line, it should not import from ee

from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant

Expand Down Expand Up @@ -174,19 +178,29 @@ def list_all_users(
def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""

if current_user is None:
raise HTTPException(
status_code=400, detail="Auth is disabled, cannot invite users"
)

tenant_id = current_tenant_id.get()

normalized_emails = []
for email in emails:
email_info = validate_email(email) # can raise EmailNotValidError
normalized_emails.append(email_info.normalized) # type: ignore
try:
for email in emails:
email_info = validate_email(email)
normalized_emails.append(email_info.normalized) # type: ignore

except (EmailUndeliverableError, EmailNotValidError):
raise HTTPException(
status_code=400,
detail="One or more emails in the list are invalid",
)

if MULTI_TENANT:
try:
Expand All @@ -199,30 +213,58 @@ def bulk_invite_users(
)
raise

all_emails = list(set(normalized_emails) | set(get_invited_users()))
initial_invited_users = get_invited_users()

if MULTI_TENANT and ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
all_emails = list(set(normalized_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)

return write_invited_users(all_emails)
if not MULTI_TENANT:
return number_of_invited_users
try:
logger.info("Registering tenant users")
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")

return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
remove_users_from_tenant(normalized_emails, tenant_id)
raise e


@router.patch("/manage/admin/remove-invited-user")
def remove_invited_user(
user_email: UserByEmail,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]

tenant_id = current_tenant_id.get()
remove_users_from_tenant([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)

try:
if MULTI_TENANT:
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
"This may cause synchronization issues/out of date enforcement of seat limits."
)
raise

return write_invited_users(remaining_users)
return number_of_invited_users


@router.patch("/manage/admin/deactivate-user")
Expand Down Expand Up @@ -421,7 +463,6 @@ def get_current_token_creation(

@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
Expand Down
7 changes: 7 additions & 0 deletions backend/danswer/server/settings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ class PageType(str, Enum):
SEARCH = "search"


class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features


class Notification(BaseModel):
id: int
notif_type: NotificationType
Expand All @@ -38,6 +44,7 @@ class Settings(BaseModel):
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE

def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled
Expand Down
22 changes: 13 additions & 9 deletions backend/danswer/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from typing import Any

from danswer.configs.app_configs import SMTP_PASS
Expand Down Expand Up @@ -58,22 +59,25 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
def send_user_email_invite(user_email: str, current_user: User) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Invitation to Join Danswer Workspace"
msg["To"] = user_email
msg["From"] = current_user.email
msg["To"] = user_email

email_body = f"""
Hello,
email_body = dedent(
f"""\
Hello,

You have been invited to join a workspace on Danswer.
You have been invited to join a workspace on Danswer.

To join the workspace, please do so at the following link:
{WEB_DOMAIN}/auth/login
To join the workspace, please visit the following link:

Best regards,
The Danswer Team"""
{WEB_DOMAIN}/auth/login

msg.attach(MIMEText(email_body, "plain"))
Best regards,
The Danswer Team
"""
)

msg.attach(MIMEText(email_body, "plain"))
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
smtp_server.starttls()
smtp_server.login(SMTP_USER, SMTP_PASS)
Expand Down
4 changes: 4 additions & 0 deletions backend/ee/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@
# Auto Permission Sync
#####
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)


STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
6 changes: 4 additions & 2 deletions backend/ee/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def get_application() -> FastAPI:

# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, query_history_router)
Expand All @@ -107,6 +105,10 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
include_router_with_global_prefix_prepended(application, usage_export_router)

if MULTI_TENANT:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)

# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

Expand Down
Loading
Loading