Skip to content

Commit

Permalink
feat: add confirmation for ticket booking (#407)
Browse files Browse the repository at this point in the history
Existing booking ticket require the right information for all the
fields. Agent retrieve information from chat history, hence, sometimes
it will hallucinate information if it is not recorded in the chat
history.

**Existing booking workflow:**
user request to book flight -> present confirmation page (sometimes with
hallucinated information) -> book ticket (will fail if information
inaccurate) -> ticket booked confirmation or failure

**New booking workflow:**
user request to book flight -> validate ticket (using airline, flight
number, departure airport, departure time) -> retrieve other information
from flight (such as arrival airport, arrival time etc.) -> present
confirmation page to user (with accurate information) -> book ticket ->
ticket booked confirmation

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
  • Loading branch information
Yuan325 and kurtisvg authored Jun 14, 2024
1 parent 18c7b6e commit 6987280
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from pytz import timezone

from ..orchestrator import BaseOrchestrator, classproperty
from .tools import get_confirmation_needing_tools, initialize_tools, insert_ticket
from .tools import (
get_confirmation_needing_tools,
initialize_tools,
insert_ticket,
validate_ticket,
)

set_verbose(bool(os.getenv("DEBUG", default=False)))
BASE_HISTORY = {
Expand Down Expand Up @@ -122,14 +127,19 @@ async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
response = await user_session.insert_ticket(params)
return response

def check_and_add_confirmations(cls, response: Dict[str, Any]):
async def check_and_add_confirmations(self, response: Dict[str, Any]):
for step in response.get("intermediate_steps") or []:
if len(step) > 0:
# Find the called tool in the step
called_tool = step[0]
# Check to see if the agent has made a decision to call Prepare Insert Ticket
# This tool is a no-op and requires user confirmation before continuing
if called_tool.tool in cls.confirmation_needing_tools:
if called_tool.tool in self.confirmation_needing_tools:
if called_tool.tool == "Insert Ticket":
flight_info = await validate_ticket(
self.client, called_tool.tool_input
)
return {"tool": called_tool.tool, "params": flight_info}
return {"tool": called_tool.tool, "params": called_tool.tool_input}
return None

Expand All @@ -155,7 +165,7 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]:
# Send prompt to LLM
agent_response = await user_session.invoke(prompt)
# Check for calls that may require confirmation to proceed
confirmation = self.check_and_add_confirmations(agent_response)
confirmation = await self.check_and_add_confirmations(agent_response)
# Build final response
response = {}
response["output"] = agent_response.get("output")
Expand Down
50 changes: 39 additions & 11 deletions llm_demo/orchestrator/langchain_tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import json
import os
from datetime import datetime
from typing import Optional
from datetime import date, datetime
from typing import Any, Dict, Optional

import aiohttp
import google.oauth2.id_token # type: ignore
Expand All @@ -28,7 +28,7 @@
CREDENTIALS = None


def filter_none_values(params: dict) -> dict:
def filter_none_values(params: Dict) -> Dict:
return {key: value for key, value in params.items() if value is not None}


Expand Down Expand Up @@ -179,19 +179,19 @@ class TicketInput(BaseModel):
departure_airport: str = Field(
description="Departure airport 3-letter code",
)
arrival_airport: str = Field(description="Arrival airport 3-letter code")
departure_time: datetime = Field(description="Flight departure datetime")
arrival_time: datetime = Field(description="Flight arrival datetime")
arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code")
arrival_time: Optional[datetime] = Field(description="Flight arrival datetime")


def generate_insert_ticket(client: aiohttp.ClientSession):
async def insert_ticket(
airline: str,
flight_number: str,
departure_airport: str,
arrival_airport: str,
departure_time: datetime,
arrival_time: datetime,
airline: str | None = None,
flight_number: str | None = None,
departure_airport: str | None = None,
arrival_airport: str | None = None,
departure_time: datetime | date | None = None,
arrival_time: datetime | date | None = None,
):
return f"Booking ticket on {airline} {flight_number}"

Expand All @@ -216,6 +216,34 @@ async def insert_ticket(client: aiohttp.ClientSession, params: str):
return response


async def validate_ticket(client: aiohttp.ClientSession, ticket_info: Dict[Any, Any]):
response = await client.get(
url=f"{BASE_URL}/tickets/validate",
params=filter_none_values(
{
"airline": ticket_info.get("airline"),
"flight_number": ticket_info.get("flight_number"),
"departure_airport": ticket_info.get("departure_airport"),
"departure_time": ticket_info.get("departure_time", "").replace(
"T", " "
),
}
),
headers=get_headers(client),
)
response_json = await response.json()

flight_info = {
"airline": response_json.get("airline"),
"flight_number": response_json.get("flight_number"),
"departure_airport": response_json.get("departure_airport"),
"arrival_airport": response_json.get("arrival_airport"),
"departure_time": response_json.get("departure_time"),
"arrival_time": response_json.get("arrival_time"),
}
return flight_info


def generate_list_tickets(client: aiohttp.ClientSession):
async def list_tickets():
response = await client.get(
Expand Down
60 changes: 60 additions & 0 deletions retrieval_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,66 @@ def test_search_flights_with_bad_params(m_datastore, app, params):
assert response.status_code == 422


validate_ticket_params = [
pytest.param(
"validate_ticket",
{
"airline": "CY",
"flight_number": "888",
"departure_airport": "LAX",
"departure_time": "2024-01-01 08:08:08",
},
[
models.Flight(
id=1,
airline="validate_ticket",
flight_number="FOOBAR",
departure_airport="FOO",
arrival_airport="BAR",
departure_time=datetime.strptime(
"2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"
),
arrival_time=datetime.strptime(
"2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"
),
departure_gate="BAZ",
arrival_gate="QUX",
)
],
[
{
"id": 1,
"airline": "validate_ticket",
"flight_number": "FOOBAR",
"departure_airport": "FOO",
"arrival_airport": "BAR",
"departure_time": "2023-01-01T05:57:00",
"arrival_time": "2023-01-01T12:13:00",
"departure_gate": "BAZ",
"arrival_gate": "QUX",
}
],
id="validate_ticket",
),
]


@pytest.mark.parametrize(
"method_name, params, mock_return, expected", validate_ticket_params
)
@patch.object(datastore, "create")
def test_validate_ticket(m_datastore, app, method_name, params, mock_return, expected):
with TestClient(app) as client:
with patch.object(
m_datastore.return_value, method_name, AsyncMock(return_value=mock_return)
) as mock_method:
response = client.get("/tickets/validate", params=params)
assert response.status_code == 200
output = response.json()
assert output == expected
assert models.Flight.model_validate(output[0])


policies_search_params = [
pytest.param(
"policies_search",
Expand Down
18 changes: 18 additions & 0 deletions retrieval_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,24 @@ async def insert_ticket(
return result


@routes.get("/tickets/validate")
async def validate_ticket(
request: Request,
airline: str,
flight_number: str,
departure_airport: str,
departure_time: str,
):
ds: datastore.Client = request.app.state.datastore
result = await ds.validate_ticket(
airline,
flight_number,
departure_airport,
departure_time,
)
return result


@routes.get("/tickets/list")
async def list_tickets(
request: Request,
Expand Down
2 changes: 1 addition & 1 deletion retrieval_service/coverage/.app-coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ omit =
[report]
show_missing = true
precision = 2
fail_under = 74
fail_under = 75.73
2 changes: 1 addition & 1 deletion retrieval_service/coverage/.cloudsql-coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ omit =
[report]
show_missing = true
precision = 2
fail_under = 94
fail_under = 90
10 changes: 10 additions & 0 deletions retrieval_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,16 @@ async def search_flights_by_airports(
) -> list[models.Flight]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def validate_ticket(
self,
airline: str,
flight_number: str,
departure_airport: str,
departure_time: str,
) -> Optional[models.Flight]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def insert_ticket(
self,
Expand Down
31 changes: 9 additions & 22 deletions retrieval_service/datastore/providers/alloydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,35 +488,31 @@ async def validate_ticket(
airline: str,
flight_number: str,
departure_airport: str,
arrival_airport: str,
departure_time: datetime,
arrival_time: datetime,
) -> bool:
departure_time: str,
) -> Optional[models.Flight]:
departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
async with self.__pool.connect() as conn:
s = text(
"""
SELECT * FROM flights
WHERE airline ILIKE :airline
AND flight_number ILIKE :flight_number
AND departure_airport ILIKE :departure_airport
AND arrival_airport ILIKE :arrival_airport
AND departure_time = :departure_time
AND arrival_time = :arrival_time
"""
)
params = {
"airline": airline,
"flight_number": flight_number,
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"departure_time": departure_time,
"arrival_time": arrival_time,
"departure_time": departure_time_datetime,
}
results = (await conn.execute(s, params)).mappings().fetchall()
result = (await conn.execute(s, params)).mappings().fetchone()

if len(results) == 1:
return True
return False
if result is None:
return None
res = models.Flight.model_validate(result)
return res

async def insert_ticket(
self,
Expand All @@ -532,15 +528,6 @@ async def insert_ticket(
):
departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S")
if not await self.validate_ticket(
airline,
flight_number,
departure_airport,
arrival_airport,
departure_time_datetime,
arrival_time_datetime,
):
raise Exception("Flight information not in database")

async with self.__pool.connect() as conn:
s = text(
Expand Down
9 changes: 9 additions & 0 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,15 @@ async def search_flights_by_airports(
res = [models.Flight.model_validate(r) for r in results]
return res

async def validate_ticket(
self,
airline: str,
flight_number: str,
departure_airport: str,
departure_time: str,
) -> Optional[models.Flight]:
raise NotImplementedError("Not Implemented")

async def insert_ticket(
self,
user_id: str,
Expand Down
9 changes: 9 additions & 0 deletions retrieval_service/datastore/providers/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,15 @@ async def search_flights_by_airports(
flights.append(models.Flight.model_validate(flight_dict))
return flights

async def validate_ticket(
self,
airline: str,
flight_number: str,
departure_airport: str,
departure_time: str,
) -> Optional[models.Flight]:
raise NotImplementedError("Not Implemented")

async def insert_ticket(
self,
user_id: str,
Expand Down
Loading

0 comments on commit 6987280

Please sign in to comment.