Skip to content

Commit

Permalink
Merge branch 'refs/heads/feat/web-app-sso' into deploy/dev
Browse files Browse the repository at this point in the history
* refs/heads/feat/web-app-sso:
  feat: web sso app
  Fixed a bug where permission was clearly displaye… (#6934)
  fix: The permissions issue of the editor role accessing some backend … (#6945)
  Fix: tag & settings modal in dataset card in Firefox (#6953)
  fix: ensure db migration in docker entry script running with `upgrade-db` command for proper locking (#6946)
  chore: fix markdown format and one typo (#6939)
  fix: restore xinference secret field (#6941)
  Fix increase_usage of total_price in agent_runner (#6688)
  fix: import workflow errors (#6937)
  Workflow TTS playback node filtering issue. (#6877)
  compatible xinference reranker server (#6927)
  fix: workflow trace user_id error (#6932)
  fix: sending app trace data to other app trace provider (#6931)
  • Loading branch information
ZhouhaoJiang committed Aug 6, 2024
2 parents f5320d1 + 425f871 commit 17e307d
Show file tree
Hide file tree
Showing 28 changed files with 199 additions and 105 deletions.
36 changes: 12 additions & 24 deletions api/controllers/console/app/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class AnnotationReplyActionApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -47,8 +46,7 @@ class AppAnnotationSettingDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -61,8 +59,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -82,8 +79,7 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

job_id = str(job_id)
Expand All @@ -110,8 +106,7 @@ class AnnotationListApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

page = request.args.get('page', default=1, type=int)
Expand All @@ -135,8 +130,7 @@ class AnnotationExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -154,8 +148,7 @@ class AnnotationCreateApi(Resource):
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -174,8 +167,7 @@ class AnnotationUpdateDeleteApi(Resource):
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -191,8 +183,7 @@ def post(self, app_id, annotation_id):
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -207,8 +198,7 @@ class AnnotationBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
Expand All @@ -232,8 +222,7 @@ class AnnotationBatchImportStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

job_id = str(job_id)
Expand All @@ -259,8 +248,7 @@ class AnnotationHitHistoryListApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

page = request.args.get('page', default=1, type=int)
Expand Down
28 changes: 28 additions & 0 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from libs.login import login_required
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.feature_service import FeatureService

ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']

Expand Down Expand Up @@ -362,6 +363,32 @@ def post(self, app_id):
return {"result": "success"}


class AppSSOApi(Resource):

@setup_required
@login_required
@account_initialization_required
def get(self):
return FeatureService.get_system_features().model_dump()

@setup_required
@login_required
@account_initialization_required
def patch(self):
parser = reqparse.RequestParser()
parser.add_argument('exclude_app_id_list', type=list, location='json')

if not current_user.is_editor:
raise Forbidden()

args = parser.parse_args()

current_user_id = current_user.id
FeatureService.update_web_sso_exclude_apps(args['exclude_app_id_list'], current_user_id)

return {"result": "success"}


api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
Expand All @@ -373,3 +400,4 @@ def post(self, app_id):
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
api.add_resource(AppSSOApi, '/apps/web-sso')
4 changes: 2 additions & 2 deletions api/controllers/console/app/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ChatConversationApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_model):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
Expand Down Expand Up @@ -245,7 +245,7 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)

Expand Down
3 changes: 1 addition & 2 deletions api/controllers/console/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ class MessageAnnotationApi(Resource):
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
Expand Down
3 changes: 2 additions & 1 deletion api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def patch(self, dataset_id):
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get('partial_member_list')
)
else:
# clear partial member list when permission is only_me or all_team_members
elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members':
DatasetPermissionService.clear_partial_member_list(dataset_id_str)

partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
Expand Down
5 changes: 2 additions & 3 deletions api/controllers/console/datasets/datasets_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ def post(self, dataset_id, document_id):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
Expand Down Expand Up @@ -347,7 +346,7 @@ def delete(self, dataset_id, document_id, segment_id):
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
Expand Down
6 changes: 4 additions & 2 deletions api/controllers/web/passport.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

class PassportResource(Resource):
"""Base resource for passport."""
def get(self):

def get(self, app_id):
system_features = FeatureService.get_system_features()
if system_features.sso_enforced_for_web:
web_sso_exclude_apps = system_features.sso_exclude_apps

if system_features.sso_enforced_for_web and app_id not in web_sso_exclude_apps:
raise WebSSOAuthRequiredError()

app_code = request.headers.get('X-App-Code')
Expand Down
1 change: 1 addition & 0 deletions api/core/agent/cot_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price

model_instance = self.model_instance

Expand Down
1 change: 1 addition & 0 deletions api/core/agent/fc_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price

model_instance = self.model_instance

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import re
import threading

from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
)
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

Expand Down Expand Up @@ -88,6 +93,8 @@ def _runtime(self):
self.msg_text += message.event.chunk.delta.message.content
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get('output', '')
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
Expand Down
7 changes: 6 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,12 @@ def _process_stream_response(
:return:
"""
for message in self._queue_manager.listen():
if publisher:
if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event

Expand Down
3 changes: 2 additions & 1 deletion api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def generate(self, app_model: App,
)

# get tracing instance
trace_manager = TraceQueueManager(app_model.id)
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)

# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
Expand Down
3 changes: 2 additions & 1 deletion api/core/app/apps/workflow/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def generate(
)

# get tracing instance
trace_manager = TraceQueueManager(app_model.id)
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)

# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
Expand Down
2 changes: 2 additions & 0 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _workflow_run_success(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)

Expand Down Expand Up @@ -173,6 +174,7 @@ def _workflow_run_failed(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)

Expand Down
53 changes: 45 additions & 8 deletions api/core/model_runtime/model_providers/xinference/rerank/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,22 @@ def _invoke(self, model: str, credentials: dict,
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

params = {
'documents': docs,
'query': query,
'top_n': top_n,
'return_documents': True
}
try:
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
return_documents=True
)
response = handle.rerank(**params)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
if "rerank hasn't support extra parameter" not in str(e):
raise InvokeServerUnavailableError(str(e))

# compatible xinference server between v0.10.1 - v0.12.1, not support 'return_len'
handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
response = handle.rerank(**params)

rerank_documents = []
for idx, result in enumerate(response['results']):
Expand Down Expand Up @@ -167,8 +172,40 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
model_properties={},
parameter_rules=[]
)

return entity


class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):

def rerank(
self,
documents: list[str],
query: str,
top_n: Optional[int] = None,
max_chunks_per_doc: Optional[int] = None,
return_documents: Optional[bool] = None,
**kwargs
):
url = f"{self._base_url}/v1/rerank"
request_body = {
"model": self._model_uid,
"documents": documents,
"query": query,
"top_n": top_n,
"max_chunks_per_doc": max_chunks_per_doc,
"return_documents": return_documents,
}

import requests

response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise InvokeServerUnavailableError(
f"Failed to rerank documents, detail: {response.json()['detail']}"
)
response_data = response.json()
return response_data
Loading

0 comments on commit 17e307d

Please sign in to comment.