Skip to content

Commit

Permalink
Refactor msgCatalog and create service
Browse files Browse the repository at this point in the history
  • Loading branch information
Robi9 committed Oct 25, 2023
1 parent df2d2d8 commit b49e253
Show file tree
Hide file tree
Showing 9 changed files with 374 additions and 208 deletions.
7 changes: 7 additions & 0 deletions core/goflow/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var classificationFactory func(*runtime.Config) engine.ClassificationServiceFact
var ticketFactory func(*runtime.Config) engine.TicketServiceFactory
var airtimeFactory func(*runtime.Config) engine.AirtimeServiceFactory
var externalServiceFactory func(*runtime.Config) engine.ExternalServiceServiceFactory
var msgCatalogFactory func(*runtime.Config) engine.MsgCatalogServiceFactory

// RegisterEmailServiceFactory can be used by outside callers to register a email factory
// for use by the engine
Expand Down Expand Up @@ -49,6 +50,10 @@ func RegisterExternalServiceServiceFactory(f func(*runtime.Config) engine.Extern
externalServiceFactory = f
}

func RegisterMsgCatalogServiceFactory(f func(*runtime.Config) engine.MsgCatalogServiceFactory) {
msgCatalogFactory = f
}

// Engine returns the global engine instance for use with real sessions
func Engine(c *runtime.Config) flows.Engine {
engInit.Do(func() {
Expand All @@ -65,6 +70,7 @@ func Engine(c *runtime.Config) flows.Engine {
WithEmailServiceFactory(emailFactory(c)).
WithTicketServiceFactory(ticketFactory(c)).
WithExternalServiceServiceFactory(externalServiceFactory((c))).
WithMsgCatalogServiceFactory(msgCatalogFactory((c))). // msg catalog
WithAirtimeServiceFactory(airtimeFactory(c)).
WithMaxStepsPerSprint(c.MaxStepsPerSprint).
WithMaxResumesPerSession(c.MaxResumesPerSession).
Expand All @@ -88,6 +94,7 @@ func Simulator(c *runtime.Config) flows.Engine {
WithWebhookServiceFactory(webhooks.NewServiceFactory(httpClient, nil, httpAccess, webhookHeaders, c.WebhooksMaxBodyBytes)).
WithClassificationServiceFactory(classificationFactory(c)). // simulated sessions do real classification
WithExternalServiceServiceFactory(externalServiceFactory((c))). // and real external services
WithMsgCatalogServiceFactory(msgCatalogFactory((c))). // msg catalog
WithEmailServiceFactory(simulatorEmailServiceFactory). // but faked emails
WithTicketServiceFactory(simulatorTicketServiceFactory). // and faked tickets
WithAirtimeServiceFactory(simulatorAirtimeServiceFactory). // and faked airtime transfers
Expand Down
124 changes: 0 additions & 124 deletions core/handlers/msg_catalog_created.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@ package handlers

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"

"github.com/nyaruka/gocommon/urns"
"github.com/nyaruka/goflow/envs"
"github.com/nyaruka/goflow/flows"
"github.com/nyaruka/goflow/flows/events"
"github.com/nyaruka/mailroom/core/goflow"
"github.com/nyaruka/mailroom/core/hooks"
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/mailroom/services/external/openai/chatgpt"
"github.com/nyaruka/mailroom/services/external/weni/sentenx"
"github.com/nyaruka/mailroom/services/external/weni/wenigpt"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
Expand Down Expand Up @@ -117,38 +110,6 @@ func handleMsgCatalogCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.
}
}

// if is smart product catalog msg
if len(event.Msg.Products()) == 0 && event.Msg.Smart() {
content := event.Msg.ProductSearch()
productList, err := GetProductListFromWeniGPT(ctx, rt, content)
if err != nil {
return err
}
catalog, err := models.GetActiveCatalogFromChannel(ctx, *rt.DB, channel.ID())
if err != nil {
return err
}
channelThreshold := channel.ConfigValue("threshold", "1.5")
searchThreshold, err := strconv.ParseFloat(channelThreshold, 64)
if err != nil {
return err
}

productRetailerIDS := []string{}

for _, product := range productList {
searchResult, err := GetProductListFromSentenX(product, catalog.FacebookCatalogID(), searchThreshold, rt)
if err != nil {
return errors.Wrapf(err, "on iterate to search products on sentenx")
}
for _, prod := range searchResult {
productRetailerIDS = append(productRetailerIDS, prod["product_retailer_id"])
}
}

event.Msg.Products_ = productRetailerIDS
} // if is not smart catalog, event already have products

msg, err := models.NewOutgoingFlowMsgCatalog(rt, oa.Org(), channel, scene.Session(), event.Msg, event.CreatedOn())
if err != nil {
return errors.Wrapf(err, "error creating outgoing message to %s", event.Msg.URN())
Expand All @@ -164,88 +125,3 @@ func handleMsgCatalogCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.

return nil
}

func GetProductListFromWeniGPT(ctx context.Context, rt *runtime.Runtime, content string) ([]string, error) {
httpClient, httpRetries, _ := goflow.HTTP(rt.Config)
weniGPTClient := wenigpt.NewClient(httpClient, httpRetries, rt.Config.WeniGPTBaseURL, rt.Config.WeniGPTAuthToken, rt.Config.WeniGPTCookie)

prompt := fmt.Sprintf(`Give me an unformatted JSON list containing strings with the name of each product taken from the user prompt. Never repeat the same product. Always use this pattern: {\"products\": []}. Request: %s. Response:`, content)

dr := wenigpt.NewWenigptRequest(
prompt,
0,
0.0,
0.0,
true,
wenigpt.DefaultStopSequences,
)

response, _, err := weniGPTClient.WeniGPTRequest(dr)
if err != nil {
return nil, errors.Wrapf(err, "error on wewnigpt call fot list products")
}

productsJson := response.Output.Text

var products map[string][]string
err = json.Unmarshal([]byte(productsJson), &products)
if err != nil {
return nil, errors.Wrapf(err, "error on unmarshalling product list")
}
return products["products"], nil
}

func GetProductListFromChatGPT(ctx context.Context, rt *runtime.Runtime, content string) ([]string, error) {
httpClient, httpRetries, _ := goflow.HTTP(rt.Config)
chatGPTClient := chatgpt.NewClient(httpClient, httpRetries, rt.Config.ChatGPTBaseURL, rt.Config.ChatGPTKey)

prompt1 := chatgpt.ChatCompletionMessage{
Role: chatgpt.ChatMessageRoleSystem,
Content: "Give me an unformatted JSON list containing strings with the name of each product taken from the user prompt.",
}
prompt2 := chatgpt.ChatCompletionMessage{
Role: chatgpt.ChatMessageRoleSystem,
Content: "Never repeat the same product.",
}
prompt3 := chatgpt.ChatCompletionMessage{
Role: chatgpt.ChatMessageRoleSystem,
Content: "Always use this pattern: {\"products\": []}",
}
question := chatgpt.ChatCompletionMessage{
Role: chatgpt.ChatMessageRoleUser,
Content: content,
}
completionRequest := chatgpt.NewChatCompletionRequest([]chatgpt.ChatCompletionMessage{prompt1, prompt2, prompt3, question})
response, _, err := chatGPTClient.CreateChatCompletion(completionRequest)
if err != nil {
return nil, errors.Wrapf(err, "error on chatgpt call for list products")
}

productsJson := response.Choices[0].Message.Content

var products map[string][]string
err = json.Unmarshal([]byte(productsJson), &products)
if err != nil {
return nil, errors.Wrapf(err, "error on unmarshalling product list")
}
return products["products"], nil
}

func GetProductListFromSentenX(productSearch string, catalogID string, threshold float64, rt *runtime.Runtime) ([]map[string]string, error) {
client := sentenx.NewClient(http.DefaultClient, nil, rt.Config.SentenXBaseURL)

searchParams := sentenx.NewSearchRequest(productSearch, catalogID, threshold)

searchResponse, _, err := client.SearchProducts(searchParams)
if err != nil {
return nil, err
}

pmap := []map[string]string{}
for _, p := range searchResponse.Products {
mapElement := map[string]string{"product_retailer_id": p.ProductRetailerID}
pmap = append(pmap, mapElement)
}

return pmap, nil
}
11 changes: 11 additions & 0 deletions core/models/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ type OrgAssets struct {
externalServices []assets.ExternalService
externalServicesByID map[ExternalServiceID]*ExternalService
externalServicesByUUID map[assets.ExternalServiceUUID]*ExternalService

msgCatalogs []assets.MsgCatalog
}

var ErrNotFound = errors.New("not found")
Expand Down Expand Up @@ -381,6 +383,10 @@ func NewOrgAssets(ctx context.Context, rt *runtime.Runtime, orgID OrgID, prev *O
oa.externalServicesByUUID = prev.externalServicesByUUID
}

if prev == nil || refresh&RefreshMsgCatalogs > 0 {
oa.msgCatalogs = []assets.MsgCatalog{}
}

// intialize our session assets
oa.sessionAssets, err = engine.NewSessionAssets(oa.Env(), oa, goflow.MigrationConfig(rt.Config))
if err != nil {
Expand Down Expand Up @@ -414,6 +420,7 @@ const (
RefreshTopics = Refresh(1 << 15)
RefreshUsers = Refresh(1 << 16)
RefreshExternalServices = Refresh(1 << 17)
RefreshMsgCatalogs = Refresh(1 << 18)
)

// GetOrgAssets creates or gets org assets for the passed in org
Expand Down Expand Up @@ -706,3 +713,7 @@ func (a *OrgAssets) ExternalServiceByID(id ExternalServiceID) *ExternalService {
func (a *OrgAssets) ExternalServiceByUUID(uuid assets.ExternalServiceUUID) *ExternalService {
return a.externalServicesByUUID[uuid]
}

func (a *OrgAssets) MsgCatalogs() ([]assets.MsgCatalog, error) {
return a.msgCatalogs, nil
}
88 changes: 67 additions & 21 deletions core/models/catalog_products.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
package models

import (
"context"
"database/sql"
"database/sql/driver"
"net/http"
"time"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/httpx"
"github.com/nyaruka/gocommon/uuids"
"github.com/nyaruka/goflow/flows"
"github.com/nyaruka/goflow/flows/engine"
"github.com/nyaruka/mailroom/core/goflow"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/null"
"github.com/pkg/errors"
)

type CatalogID null.Int

func (i CatalogID) MarshalJSON() ([]byte, error) {
return null.Int(i).MarshalJSON()
}

func (i *CatalogID) UnmarshalJSON(b []byte) error {
return null.UnmarshalInt(b, (*null.Int)(i))
}

func (i CatalogID) Value() (driver.Value, error) {
return null.Int(i).Value()
}

func (i *CatalogID) Scan(value interface{}) error {
return null.ScanInt(value, (*null.Int)(i))
}

// CatalogProduct represents a product catalog from Whatsapp channels.
type CatalogProduct struct {
c struct {
Expand All @@ -38,24 +58,50 @@ func (c *CatalogProduct) IsActive() bool { return c.c.IsActive }
func (c *CatalogProduct) ChannelID() ChannelID { return c.c.ChannelID }
func (c *CatalogProduct) OrgID() OrgID { return c.c.OrgID }

const getActiveCatalogSQL = `
SELECT
id, uuid, facebook_catalog_id, name, created_on, modified_on, is_active, channel_id, org_id
FROM public.wpp_products_catalog
WHERE channel_id = $1 AND is_active = true
`

// GetActiveCatalogFromChannel returns the active catalog from the given channel
func GetActiveCatalogFromChannel(ctx context.Context, db sqlx.DB, channelID ChannelID) (*CatalogProduct, error) {
var catalog CatalogProduct

err := db.GetContext(ctx, &catalog.c, getActiveCatalogSQL, channelID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrapf(err, "error getting active catalog for channelID: %d", channelID)
type MsgCatalog struct {
e struct {
ID CatalogID `json:"id,omitempty"`
ChannelUUID uuids.UUID `json:"uuid,omitempty"`
OrgID OrgID `json:"org_id,omitempty"`
Name string `json:"name,omitempty"`
Config map[string]string `json:"config,omitempty"`
Type string `json:"type,omitempty"`
}
}

func (c *MsgCatalog) ChannelUUID() uuids.UUID { return c.e.ChannelUUID }
func (c *MsgCatalog) Name() string { return c.e.Name }
func (c *MsgCatalog) Type() string { return c.e.Type }

func init() {
goflow.RegisterMsgCatalogServiceFactory(msgCatalogServiceFactory)
}

func msgCatalogServiceFactory(c *runtime.Config) engine.MsgCatalogServiceFactory {
return func(session flows.Session, msgCatalog *flows.MsgCatalog) (flows.MsgCatalogService, error) {
return msgCatalog.Asset().(*MsgCatalog).AsService(c, msgCatalog)
}
}

func (e *MsgCatalog) AsService(cfg *runtime.Config, msgCatalog *flows.MsgCatalog) (MsgCatalogService, error) {
httpClient, httpRetries, _ := goflow.HTTP(cfg)

initFunc := msgCatalogServices["msg_catalog"]
if initFunc != nil {
return initFunc(cfg, httpClient, httpRetries, msgCatalog, e.e.Config)
}

return &catalog, nil
return nil, errors.Errorf("unrecognized product catalog '%s'", e.e.Name)
}

type MsgCatalogServiceFunc func(*runtime.Config, *http.Client, *httpx.RetryConfig, *flows.MsgCatalog, map[string]string) (MsgCatalogService, error)

var msgCatalogServices = map[string]MsgCatalogServiceFunc{}

type MsgCatalogService interface {
flows.MsgCatalogService
}

func RegisterMsgCatalogService(name string, initFunc MsgCatalogServiceFunc) {
msgCatalogServices[name] = initFunc
}
Loading

0 comments on commit b49e253

Please sign in to comment.