Skip to content

Commit

Permalink
Merge branch 'chatgpt-v2' into test
Browse files Browse the repository at this point in the history
  • Loading branch information
LanaBot committed Nov 7, 2023
2 parents f4c5aa8 + 4792491 commit 971567e
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 169 deletions.
15 changes: 13 additions & 2 deletions kairos/api/event_logs_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from flask import Blueprint
from flask_cors import CORS

from kairos.services import event_logs_service, cases_service
from kairos.services import event_logs_service, cases_service, openai_service

event_logs_api = Blueprint('event_logs_api','event_logs_api',url_prefix='/event_logs')

Expand All @@ -12,6 +12,7 @@
event_logs_api.route('/<event_log_id>/cases/<case_completion>',methods=['GET'])(cases_service.get_cases_by_log_and_completion)

event_logs_api.route('', methods=['GET'])(event_logs_service.get_logs)

event_logs_api.route('', methods=['POST'])(event_logs_service.save_log)

event_logs_api.route('/<event_log_id>',methods=['GET'])(event_logs_service.get_log)
Expand All @@ -21,6 +22,7 @@
event_logs_api.route('/<event_log_id>/column_types', methods=['PUT'])(event_logs_service.define_log_column_types)

event_logs_api.route('/<event_log_id>/parameters',methods=['GET'])(event_logs_service.get_log_parameters)

event_logs_api.route('/<event_log_id>/parameters',methods=['POST'])(event_logs_service.define_log_parameters)

event_logs_api.route('/<event_log_id>/prescriptions',methods=['GET'])(event_logs_service.get_log_prescriptions)
Expand All @@ -30,8 +32,17 @@
event_logs_api.route('/<event_log_id>/status',methods=['GET'])(event_logs_service.get_project_status)

event_logs_api.route('/<event_log_id>/simulate/start', methods=['PUT'])(event_logs_service.start_simulation)

event_logs_api.route('/<event_log_id>/simulate/stop', methods=['PUT'])(event_logs_service.stop_simulation)

event_logs_api.route('/<event_log_id>/simulate/clear', methods=['PUT'])(event_logs_service.clear_stream)

event_logs_api.route('/<event_log_id>/results', methods=['GET'])(event_logs_service.get_static_results)
event_logs_api.route('/<event_log_id>/results', methods=['GET'])(event_logs_service.get_static_results)

# openai

event_logs_api.route('/<event_log_id>/openai/history', methods=['GET'])(openai_service.get_messages_for_log)

event_logs_api.route('/<event_log_id>/cases/<case_id>/openai/history', methods=['GET'])(openai_service.get_messages_for_case)

event_logs_api.route('/<event_log_id>/cases/<case_id>/openai', methods=['POST'])(openai_service.get_answer)
41 changes: 41 additions & 0 deletions kairos/models/messages_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from werkzeug.local import LocalProxy
from kairos.models.db import get_db

db = LocalProxy(get_db)

def get_messages():
return list(db.messages.find({'context': False},{'_id': False}))

def get_messages_by_log_id(event_log_id):
return list(db.messages.find({'context': False, 'event_log_id': int(event_log_id)},{'_id': False}))

def get_messages_by_case_id(case_id):
return list(db.messages.find({'context': False, 'case_id': case_id},{'_id': False}))

def get_context():
return list(db.messages.find({},{'_id': False,'context': False, 'event_log_id': False, 'case_id': False}))

def get_context_by_log_id(event_log_id):
return list(db.messages.find({'event_log_id': int(event_log_id)},{'_id': False,'context': False, 'event_log_id': False, 'case_id': False}))

def get_context_by_case_id(case_id):
return list(db.messages.find({'case_id': case_id},{'_id': False,'context': False, 'event_log_id': False, 'case_id': False}))

def save_message(role,content,context=False,event_log_id=None,case_id=None):
new_message = {
'role': role,
'content': content,
'context': context,
'event_log_id': event_log_id,
'case_id': case_id
}
return db.messages.insert_one(new_message)

def count_messages():
return db.messages.count_documents({'context': False})

def delete_messages():
return db.messages.deleteMany({})

def get_system_messages():
return list(db.messages.find({'role': 'system'}))
2 changes: 2 additions & 0 deletions kairos/services/event_logs_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import kairos.models.event_logs_model as event_logs_db
import kairos.models.cases_model as cases_db

from kairos.enums.project_status import Status as PROJECT_STATUS

import kairos.services.prcore_service as prcore_service
import kairos.utils as k_utils

Expand Down
41 changes: 41 additions & 0 deletions kairos/services/openai_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from flask import request, jsonify, current_app

import kairos.models.messages_model as messages_db
import kairos.utils.openai as openai_utils

def get_messages_for_log(event_log_id):
try:
messages = messages_db.get_messages_by_log_id(event_log_id=event_log_id)
except Exception as e:
current_app.logger.error(f'{request.method} {request.path} 500 - {e}')
return jsonify(error=str(e)),500
current_app.logger.info(f'{request.method} {request.path} 200')
return jsonify(memory = messages),200

def get_messages_for_case(event_log_id,case_id):
try:
messages = messages_db.get_messages_by_case_id(case_id=case_id)
except Exception as e:
current_app.logger.error(f'{request.method} {request.path} 500 - {e}')
return jsonify(error=str(e)),500
current_app.logger.info(f'{request.method} {request.path} 200')
return jsonify(memory = messages),200

def get_answer(event_log_id,case_id):
if not event_log_id or not case_id:
current_app.logger.error('Event log id and case id cannot be null.')
return jsonify(error='Please specify event_log_id and case_id.'),403

question = request.get_json().get('question')
if not question:
current_app.logger.error('Question cannot be null.')
return jsonify(error='Please specify a question.'),403

try:
answer = openai_utils.ask_ai(content=question, event_log_id=event_log_id,case_id=case_id)
except Exception as e:
current_app.logger.error(f'{request.method} {request.path} 500 - {e}')
return jsonify(error=str(e)),500

current_app.logger.info(f'{request.method} {request.path} 200')
return jsonify(answer = answer),200
10 changes: 3 additions & 7 deletions kairos/services/prcore_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sseclient
from flask import current_app

import kairos.utils as k_utils
import kairos.utils.events as events_utils

def response(res,status=False,json=True):
if res.status_code != 200:
Expand Down Expand Up @@ -85,16 +85,12 @@ def start_stream(project_id):

event_data = json.loads(event.data)
first_event = event_data[0]
# print(f"ID: {event.id}")

case_id = k_utils.record_event(first_event,event.id,project_id)

# print("-" * 24)

case_id = events_utils.record_event(first_event,event.id,project_id)

def get_static_results(project_id,result_key):
res = requests.get(current_app.config.get('PRCORE_BASE_URL') + f'/project/{project_id}/result/{result_key}', headers=current_app.config.get('PRCORE_HEADERS'))
res_json = response(res)
message = res_json.get('message')
k_utils.record_results(project_id,res_json)
events_utils.record_results(project_id,res_json)
return message
160 changes: 0 additions & 160 deletions kairos/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dateutil import parser
from datetime import timedelta
from pymongo.errors import DuplicateKeyError
from collections import Counter
from flask import current_app

Expand All @@ -12,7 +11,6 @@
from kairos.enums.column_type import Column_type as COLUMN_TYPE

import kairos.models.cases_model as cases_db
import kairos.models.event_logs_model as event_logs_db

EVALUATION_METHODS = {
'EQUAL':lambda x,y: x == y,'NOT_EQUAL':lambda x,y: x!=y,'CONTAINS': lambda x,y: y in x,'NOT_CONTAINS':lambda x,y: y not in x,
Expand Down Expand Up @@ -78,92 +76,6 @@ def format_positive_outcome(positive_outcome):
item.pop('unit')
return prcore_outcome

def format_additional_info(additional_info):
if not additional_info: return None
prcore_additional_info = copy.deepcopy(additional_info)

treatment_duration = prcore_additional_info['plugin_causallift_resource_allocation']['treatment_duration']
prcore_additional_info['plugin_causallift_resource_allocation']['treatment_duration'] = f'{treatment_duration.get("value")}{treatment_duration.get("unit")}'
prcore_additional_info['plugin-causallift-resource-allocation'] = prcore_additional_info.pop('plugin_causallift_resource_allocation')
return prcore_additional_info

def record_event(event_data,event_id,project_id):
try:
log = event_logs_db.get_event_log_by_project_id(project_id)
except Exception as e:
current_app.logger.error(f'Error while getting event log by project_id {project_id}: {str(e)}')
return
event_log_id = log.get('_id')
columns_definition = log.get("columns_definition")
columns_definition_reverse = log.get('columns_definition_reverse')
case_attributes_definition = log.get('case_attributes')

case_id = 0
activity = {'event_id': event_id}
case_attributes = {}


for column,value in event_data.get('data').items():
column_type = columns_definition.get(column)
value = parse_value(column_type,value)

if column_type == 'CASE_ID':
case_id = value
elif column in case_attributes_definition:
case_attributes[column] = value
else:
activity[column] = value

prescriptions = event_data.get("prescriptions")
prescriptions_with_output = [prescriptions[p] for p in prescriptions if prescriptions[p]["output"]]

for prescription in prescriptions_with_output:
if prescription.get('type') == 'TREATMENT_EFFECT':
try:
category = categorize_cate(event_log_id,prescription)
prescription.get('output',{})['cate_category'] = category
except Exception as e:
current_app.logger.error(f"Error occured while categorizing cate in case {case_id}: {str(e)}")

case_completed = event_data.get('case_completed')
if case_completed:
prescriptions_with_output = []
activity['prescriptions'] = prescriptions_with_output

try:
old_case = cases_db.get_case(case_id)
except Exception:
old_case = None

if not old_case:
_id = cases_db.save_case(case_id,event_log_id,case_completed,[activity],case_attributes).inserted_id
else:
try:
update_case_prescriptions(old_case,activity,columns_definition_reverse.get(COLUMN_TYPE.ACTIVITY))
except Exception as e:
current_app.logger.error(f'Failed to update case {case_id} prescriptions: {e}')

cases_db.update_case(case_id,case_completed,activity)

case_performance = {}
try:
case_performance = calculate_case_performance(case_id,log.get('positive_outcome'),columns_definition, columns_definition_reverse)
except Exception as e:
current_app.logger.error(f'Failed to calculate case {case_id} performance: {e}')

try:
cases_db.update_case_performance(case_id,case_performance)
except Exception as e:
current_app.logger.error(f'Failed to update case {case_id} performance: {e}')

current_app.logger.info(f'''STREAMING RESULT:
event_log_id: {event_log_id},
project_id: {project_id},
case_id: {case_id},
prescriptions: {prescriptions}''')
return case_id


def update_case_prescriptions(my_case,new_activity,activity_column):
new_activities = my_case.get('activities')
last_activity = new_activities[-1]
Expand Down Expand Up @@ -298,78 +210,6 @@ def parse_value(column_type,value):

return value

def record_results(project_id,result):
if result.get('cases') == None:
return

try:
log = event_logs_db.get_event_log_by_project_id(project_id)
except Exception as e:
current_app.logger.error(f'Error while getting event log by project id {project_id}: {str(e)}')
return

event_log_id = log.get('_id')
columns = result.get('columns')
columns_definition = log.get("columns_definition")
columns_definition_reverse = log.get('columns_definition_reverse')
case_attributes_definition = log.get('case_attributes')
suffix = generate_suffix()

for case_id, case_body in result.get('cases',{}).items():
events = case_body.get('events',[])
prescriptions = case_body.get('prescriptions',[])
prescriptions_with_output = [p for p in prescriptions if p["output"]]
activities = []
case_attributes = {}

for prescription in prescriptions_with_output:
if prescription.get('type') == 'TREATMENT_EFFECT':
try:
category = categorize_cate(event_log_id,prescription)
prescription.get('output',{})['cate_category'] = category
except Exception as e:
current_app.logger.error(f"Error occured while categorizing cate in case {case_id}: {str(e)}")

for i in range(len(events)):
event_data = events[i]
event_data = dict(zip(columns, event_data))
activity = {'event_id': i}

for column,value in event_data.items():
column_type = columns_definition.get(column)
value = parse_value(column_type,value)

if column_type == 'CASE_ID':
continue
elif column in case_attributes_definition:
case_attributes[column] = value
else:
activity[column] = value
if i == (len(events) - 1):
activity['prescriptions'] = prescriptions_with_output

activities.append(activity)
case_completed = False

while True:
try:
_id = cases_db.save_case(suffix + str(case_id),event_log_id,case_completed,activities,case_attributes).inserted_id
break
except DuplicateKeyError:
suffix = generate_suffix()

case_performance = {}
try:
case_performance = calculate_case_performance(_id,log.get('positive_outcome'),columns_definition, columns_definition_reverse)
except Exception as e:
current_app.logger(f'Failed to calculate case {case_id} performance: {e}')

try:
cases_db.update_case_performance(_id,case_performance)
except Exception as e:
current_app.logger(f'Failed to update case {case_id} performance: {e}')

event_logs_db.update_event_log(event_log_id,{'got_results': True})

def generate_suffix():
rand = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
Expand Down
Loading

0 comments on commit 971567e

Please sign in to comment.