Skip to content

Commit

Permalink
Fix bugs in agent api and update api document (infiniflow#3996)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Fix bugs in agent api and update api document

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
  • Loading branch information
Feiue and liuhua authored Dec 13, 2024
1 parent 68d46b2 commit 1ecb687
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 57 deletions.
1 change: 1 addition & 0 deletions agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def run(self, **kwargs):
self.path.append(["begin"])

self.path.append([])

ran = -1
waiting = []
without_dependent_checking = []
Expand Down
45 changes: 28 additions & 17 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def create_agent_session(tenant_id, agent_id):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

canvas = Canvas(cvs.dsl, tenant_id)
if canvas.get_preset_param():
return get_error_data_result("The agent can't create a session directly")
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
Expand Down Expand Up @@ -112,6 +114,8 @@ def update(tenant_id, chat_id, session_id):
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
if not req or not req.get("session_id"):
req = {"question":""}
if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
return get_error_data_result(f"You don't own the chat {chat_id}")
if req.get("session_id"):
Expand All @@ -125,7 +129,6 @@ def chat_completion(tenant_id, chat_id):
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")

return resp

else:
answer = None
for ans in rag_completion(tenant_id, chat_id, **req):
Expand All @@ -137,22 +140,28 @@ def chat_completion(tenant_id, chat_id):
@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
if req.get("session_id"):
if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
return get_error_data_result(f"You don't own the session {req['session_id']}")
if req.get("stream", True):
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp

for answer in agent_completion(tenant_id, agent_id, **req):
return get_result(data=answer)
req = request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
if req.get("session_id"):
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
if not conv:
return get_error_data_result(f"You don't own the session {req['session_id']}")
else:
req["question"]=""
if req.get("stream", True):
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
try:
for answer in agent_completion(tenant_id, agent_id, **req):
return get_result(data=answer)
except Exception as e:
return get_error_data_result(str(e))


@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
Expand Down Expand Up @@ -420,3 +429,5 @@ def agent_bot_completions(agent_id):

for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
return get_result(data=answer)


38 changes: 20 additions & 18 deletions api/db/services/canvas_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,39 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
assert cvs.user_id == tenant_id, "You do not own the agent."

if not isinstance(cvs.dsl, str):
if not isinstance(cvs.dsl,str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
canvas.reset()
message_id = str(uuid4())

if not session_id:
query = canvas.get_preset_param()
if query:
for ele in query:
if not ele["optional"]:
if not kwargs.get(ele["key"]):
assert False, f"`{ele['key']}` is required"
ele["value"] = kwargs[ele["key"]]
if ele["optional"]:
if kwargs.get(ele["key"]):
ele["value"] = kwargs[ele['key']]
else:
if "value" in ele:
ele.pop("value")
cvs.dsl = json.loads(str(canvas))
temp_dsl = cvs.dsl
UserCanvasService.update_by_id(agent_id, cvs.to_dict())
else:
temp_dsl = json.loads(cvs.dsl)
session_id = get_uuid()
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": kwargs.get("user_id", ""),
"source": "agent",
"dsl": json.loads(cvs.dsl)
"dsl": temp_dsl
}
API4ConversationService.save(**conv)
if canvas.get_preset_param():
yield "data:" + json.dumps({"code": 0,
"message": "",
"data": {
"session_id": session_id,
"answer": "",
"reference": [],
"param": canvas.get_preset_param()
}
},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
return
conv = API4Conversation(**conv)
else:
e, conv = API4ConversationService.get_by_id(session_id)
Expand All @@ -104,7 +107,6 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
conv.reference.append({"chunks": [], "doc_aggs": []})

final_ans = {"reference": [], "content": ""}

if stream:
try:
for ans in canvas.run(stream=stream):
Expand Down
3 changes: 2 additions & 1 deletion api/db/services/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
assert dia, "You do not own the chat."

if not session_id:
session_id = get_uuid()
conv = {
"id": get_uuid(),
"id":session_id ,
"dialog_id": chat_id,
"name": name,
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}]
Expand Down
Loading

0 comments on commit 1ecb687

Please sign in to comment.