diff --git a/models/contacts.go b/models/contacts.go index 7062794d9..1127355c8 100644 --- a/models/contacts.go +++ b/models/contacts.go @@ -205,18 +205,26 @@ func ContactIDsFromReferences(ctx context.Context, tx Queryer, orgID OrgID, refs // ContactIDsFromUUIDs queries the contacts for the passed in org, returning the contact ids for the UUIDs func ContactIDsFromUUIDs(ctx context.Context, tx Queryer, orgID OrgID, uuids []flows.ContactUUID) ([]ContactID, error) { - ids := make([]ContactID, 0, len(uuids)) - rows, err := tx.QueryxContext(ctx, `SELECT id FROM contacts_contact WHERE org_id = $1 AND uuid = ANY($2) AND is_active = TRUE`, orgID, pq.Array(uuids)) + ids, err := queryContactIDs(ctx, tx, `SELECT id FROM contacts_contact WHERE org_id = $1 AND uuid = ANY($2) AND is_active = TRUE`, orgID, pq.Array(uuids)) if err != nil { return nil, errors.Wrapf(err, "error selecting contact ids by UUID") } + return ids, nil +} + +func queryContactIDs(ctx context.Context, tx Queryer, query string, args ...interface{}) ([]ContactID, error) { + ids := make([]ContactID, 0, 10) + rows, err := tx.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } defer rows.Close() var id ContactID for rows.Next() { err := rows.Scan(&id) if err != nil { - return nil, errors.Wrapf(err, "error scanning contact id") + return nil, err } ids = append(ids, id) } @@ -671,17 +679,29 @@ func CreateContact(ctx context.Context, db *sqlx.DB, org *OrgAssets, userID User } // GetOrCreateContact creates a new contact for the passed in org with the passed in URNs -func GetOrCreateContact(ctx context.Context, db *sqlx.DB, org *OrgAssets, urn urns.URN) (*Contact, *flows.Contact, error) { +func GetOrCreateContact(ctx context.Context, db *sqlx.DB, org *OrgAssets, urnz []urns.URN) (*Contact, *flows.Contact, error) { created := true - contactID, err := insertContactAndURNs(ctx, db, org, UserID(1), "", envs.NilLanguage, []urns.URN{urn}) + contactID, err := insertContactAndURNs(ctx, db, org, UserID(1), "", envs.NilLanguage, urnz) if err != nil { if dbutil.IsUniqueViolation(err) { - // if this was a duplicate URN, we should be able to fetch this contact instead - err := db.GetContext(ctx, &contactID, `SELECT contact_id FROM contacts_contacturn WHERE org_id = $1 AND identity = $2`, org.OrgID(), urn.Identity()) + // if we blew up because URNs are already taken by other contacts, find who owns them + identities := make([]string, len(urnz)) + for i := range urnz { + identities[i] = string(urnz[i].Identity()) + } + + contactIDs, err := queryContactIDs(ctx, db, `SELECT DISTINCT contact_id FROM contacts_contacturn WHERE org_id = $1 AND identity = ANY($2)`, org.OrgID(), pq.Array(identities)) if err != nil { - return nil, nil, errors.Wrapf(err, "unable to load contact") + return nil, nil, errors.Wrapf(err, "error querying contacts with URNs") } + + if len(contactIDs) == 1 { + contactID = contactIDs[0] + } else { + // TODO !! + } + created = false } else { return nil, nil, err @@ -1257,23 +1277,9 @@ func UpdateContactURNs(ctx context.Context, tx Queryer, org *OrgAssets, changes if len(inserts) > 0 { // find the unique ids of the contacts that may be affected by our URN inserts - rows, err := tx.QueryxContext(ctx, - `SELECT contact_id FROM contacts_contacturn WHERE identity = ANY($1) AND org_id = $2 AND contact_id IS NOT NULL`, - pq.Array(identities), org.OrgID(), - ) + orphanedIDs, err := queryContactIDs(ctx, tx, `SELECT contact_id FROM contacts_contacturn WHERE identity = ANY($1) AND org_id = $2 AND contact_id IS NOT NULL`, pq.Array(identities), org.OrgID()) if err != nil { - return errors.Wrapf(err, "error finding contacts for urns") - } - defer rows.Close() - - orphanedIDs := make([]ContactID, 0, len(inserts)) - for rows.Next() { - var contactID ContactID - err := rows.Scan(&contactID) - if err != nil { - return errors.Wrapf(err, "error reading orphaned contacts") - } - orphanedIDs = append(orphanedIDs, contactID) + return errors.Wrapf(err, "error finding contacts for URNs") } // then insert new urns, we do these one by one since we have to deal with conflicts diff --git a/models/contacts_test.go b/models/contacts_test.go index 04d2f9048..075622921 100644 --- a/models/contacts_test.go +++ b/models/contacts_test.go @@ -559,20 +559,20 @@ func TestGetOrCreateContact(t *testing.T) { tcs := []struct { OrgID OrgID - URN urns.URN + URNs []urns.URN ContactID ContactID }{ - {Org1, CathyURN, CathyID}, - {Org1, urns.URN(CathyURN.String() + "?foo=bar"), CathyID}, - {Org1, urns.URN("telegram:12345678"), ContactID(maxContactID + 3)}, - {Org1, urns.URN("telegram:12345678"), ContactID(maxContactID + 3)}, + {Org1, []urns.URN{CathyURN}, CathyID}, + {Org1, []urns.URN{urns.URN(CathyURN.String() + "?foo=bar")}, CathyID}, + {Org1, []urns.URN{urns.URN("telegram:12345678")}, ContactID(maxContactID + 3)}, + {Org1, []urns.URN{urns.URN("telegram:12345678")}, ContactID(maxContactID + 3)}, } org, err := GetOrgAssets(ctx, db, Org1) assert.NoError(t, err) for i, tc := range tcs { - contact, _, err := GetOrCreateContact(ctx, db, org, tc.URN) + contact, _, err := GetOrCreateContact(ctx, db, org, tc.URNs) assert.NoError(t, err, "%d: error creating contact", i) assert.Equal(t, tc.ContactID, contact.ID(), "%d: mismatch in contact id", i) } diff --git a/models/imports.go b/models/imports.go index 05fad17ab..8e78de981 100644 --- a/models/imports.go +++ b/models/imports.go @@ -162,7 +162,7 @@ func (b *ContactImportBatch) getOrCreateContacts(ctx context.Context, db *sqlx.D } else { // TODO get or create by multiple URNs - imp.contact, imp.flowContact, err = GetOrCreateContact(ctx, db, oa, spec.URNs[0]) + imp.contact, imp.flowContact, err = GetOrCreateContact(ctx, db, oa, spec.URNs) if err != nil { addError("Unable to get or create contact with URN '%s'", spec.URNs[0]) continue diff --git a/web/ivr/ivr.go b/web/ivr/ivr.go index 900a248c5..b5f9d5819 100644 --- a/web/ivr/ivr.go +++ b/web/ivr/ivr.go @@ -102,7 +102,7 @@ func handleIncomingCall(ctx context.Context, s *web.Server, r *http.Request, w h } // get the contact for this URN - contact, _, err := models.GetOrCreateContact(ctx, s.DB, oa, urn) + contact, _, err := models.GetOrCreateContact(ctx, s.DB, oa, []urns.URN{urn}) if err != nil { return channel, nil, client.WriteErrorResponse(w, errors.Wrapf(err, "unable to get contact by urn")) } diff --git a/web/surveyor/surveyor.go b/web/surveyor/surveyor.go index 0ed1a323e..c9f8f3d5b 100644 --- a/web/surveyor/surveyor.go +++ b/web/surveyor/surveyor.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" + "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/flows" @@ -95,7 +96,7 @@ func handleSubmit(ctx context.Context, s *web.Server, r *http.Request) (interfac // create / fetch our contact based on the highest priority URN urn := fs.Contact().URNs()[0].URN() - _, flowContact, err = models.GetOrCreateContact(ctx, s.DB, oa, urn) + _, flowContact, err = models.GetOrCreateContact(ctx, s.DB, oa, []urns.URN{urn}) if err != nil { return nil, http.StatusInternalServerError, errors.Wrapf(err, "unable to look up contact") }