Skip to content

Commit

Permalink
Adding reranking function #39
Browse files Browse the repository at this point in the history
  • Loading branch information
MrCsabaToth committed Oct 18, 2024
1 parent de89be9 commit a841d6f
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
109 changes: 109 additions & 0 deletions functions/fn_impl/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from firebase_functions import https_fn
from firebase_admin import initialize_app, storage
import firebase_admin
import json

from google.cloud import discoveryengine_v1beta as discoveryengine

@https_fn.on_request()
def rerank(req: https_fn.Request) -> https_fn.Response:
"""Synthesizes speech from the input string of text or ssml.
Returns:
Encoded audio file in the body.
Note: ssml must be well-formed according to:
https://www.w3.org/TR/speech-synthesis/
"""
# Set CORS headers for the preflight request
if req.method == 'OPTIONS':
# Allows GET requests from any origin with the Content-Type
# header and caches preflight response for an 3600s
headers = {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST',
'Access-Control-Allow-Headers': 'Content-Type',
'Access-Control-Max-Age': '3600'
}

return ('', 204, headers)

if not firebase_admin._apps:
initialize_app()

request_json = req.get_json(silent=True)
request_args = req.args
request_form = req.form

if request_json and 'data' in request_json:
request_json = request_json['data']

if request_json and 'top_n' in request_json:
top_n = request_json['top_n']
elif request_args and 'top_n' in request_args:
top_n = request_args['top_n']
elif request_form and 'top_n' in request_form:
top_n = request_form['top_n']
else:
top_n = 10

if request_json and 'query' in request_json:
query = request_json['query']
elif request_args and 'query' in request_args:
query = request_args['query']
elif request_form and 'query' in request_form:
query = request_form['query']
else:
query = ''

if not query:
return [], 400

if request_json and 'records' in request_json:
records = request_json['records']
elif request_args and 'records' in request_args:
records = request_args['records']
elif request_form and 'records' in request_form:
records = request_form['records']
else:
records = []

if not records:
return [], 400

ranking_records = []
for record in records:
ranking_records.append(
discoveryengine.RankingRecord(
id=record['id'],
title=record['title'],
content=record['content'],
)
)

project_id = 'open-mmpa'
region = 'us-central1'
client = discoveryengine.RankServiceClient()

# The full resource name of the ranking config.
# Format: projects/{project_id}/locations/{location}/rankingConfigs/default_ranking_config
ranking_config = client.ranking_config_path(
project=project_id,
location=region,
ranking_config="default_ranking_config",
)
# https://cloud.google.com/generative-ai-app-builder/docs/ranking#models
# semantic-ranker-512-003, Text (25 languages)
request = discoveryengine.RankRequest(
ranking_config=ranking_config,
model="semantic-ranker-512@latest",
top_n=top_n,
query=query,
records=ranking_records,
)

response = client.rank(request=request)

return https_fn.Response(
json.dumps(dict(data=response)),
status=200,
content_type='application/json',
)
Empty file removed functions/fn_impl/reranking.py
Empty file.
1 change: 1 addition & 0 deletions functions/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
firebase_functions~=0.4.2
google-api-core~=2.21.0
google-cloud-discoveryengine~=0.12.3
google-cloud-logging~=3.11.3
google-cloud-speech~=2.27.0
google-cloud-texttospeech~=2.18.0
Expand Down

0 comments on commit a841d6f

Please sign in to comment.