diff --git a/pkg/store/mock/store.go b/pkg/store/mock/store.go index f24581c4..b3236578 100644 --- a/pkg/store/mock/store.go +++ b/pkg/store/mock/store.go @@ -147,6 +147,26 @@ func (s *Store) DeleteCounterparty(ctx context.Context, counterpartyID ulid.ULID return nil } +func (s *Store) ListContacts(ctx context.Context, counterpartyID ulid.ULID, page *models.PageInfo) (*models.ContactsPage, error) { + return nil, nil +} + +func (s *Store) CreateContact(context.Context, *models.Contact) error { + return nil +} + +func (s *Store) RetrieveContact(ctx context.Context, contactID, counterpartyID ulid.ULID) (*models.Contact, error) { + return nil, nil +} + +func (s *Store) UpdateContact(context.Context, *models.Contact) error { + return nil +} + +func (s *Store) DeleteContact(ctx context.Context, contactID, counterpartyID ulid.ULID) error { + return nil +} + func (s *Store) UseTravelAddressFactory(models.TravelAddressFactory) { } diff --git a/pkg/store/models/pagination.go b/pkg/store/models/pagination.go index 18cd1b54..0df38bbc 100644 --- a/pkg/store/models/pagination.go +++ b/pkg/store/models/pagination.go @@ -35,6 +35,11 @@ type CounterpartyPage struct { Page *PageInfo `json:"page"` } +type ContactsPage struct { + Contacts []*Contact `json:"contacts"` + Page *PageInfo `json:"page"` +} + type UserPage struct { Users []*User `json:"users"` Page *PageInfo `json:"page"` diff --git a/pkg/store/sqlite/accounts.go b/pkg/store/sqlite/accounts.go index fba0cf58..45921d69 100644 --- a/pkg/store/sqlite/accounts.go +++ b/pkg/store/sqlite/accounts.go @@ -101,10 +101,7 @@ func (s *Store) CreateAccount(ctx context.Context, account *models.Account) (err } } - if err = tx.Commit(); err != nil { - return err - } - return nil + return tx.Commit() } const retreiveAccountSQL = "SELECT * FROM accounts WHERE id=:id" @@ -168,10 +165,7 @@ func (s *Store) UpdateAccount(ctx context.Context, a *models.Account) (err error return dberr.ErrNotFound } - if err = tx.Commit(); err != nil { - return err - } - return nil + return tx.Commit() } const deleteAccountSQL = "DELETE FROM accounts WHERE id=:id" @@ -334,6 +328,8 @@ func (s *Store) RetrieveCryptoAddress(ctx context.Context, accountID, cryptoAddr return nil, err } + // TODO: retrieve account and associate it with the crypto address. + tx.Commit() return addr, nil } @@ -373,10 +369,7 @@ func (s *Store) UpdateCryptoAddress(ctx context.Context, addr *models.CryptoAddr return dberr.ErrNotFound } - if err = tx.Commit(); err != nil { - return err - } - return nil + return tx.Commit() } const deleteCryptoAddressSQL = "DELETE FROM crypto_addresses WHERE id=:cryptoAddressID and account_id=:accountID" @@ -398,8 +391,5 @@ func (s *Store) DeleteCryptoAddress(ctx context.Context, accountID, cryptoAddres return dberr.ErrNotFound } - if err = tx.Commit(); err != nil { - return err - } - return nil + return tx.Commit() } diff --git a/pkg/store/sqlite/counterparty.go b/pkg/store/sqlite/counterparty.go index d267db4d..fe30a7fe 100644 --- a/pkg/store/sqlite/counterparty.go +++ b/pkg/store/sqlite/counterparty.go @@ -42,6 +42,10 @@ func (s *Store) ListCounterparties(ctx context.Context, page *models.PageInfo) ( return nil, err } + // Ensure that contacts is non-nil and zero-valued + counterparty.SetContacts(make([]*models.Contact, 0)) + + // Append counterparty to the page out.Counterparties = append(out.Counterparties, counterparty) } @@ -84,6 +88,7 @@ func (s *Store) ListCounterpartySourceInfo(ctx context.Context, source string) ( const createCounterpartySQL = "INSERT INTO counterparties (id, source, directory_id, registered_directory, protocol, common_name, endpoint, name, website, country, business_category, vasp_categories, verified_on, ivms101, created, modified) VALUES (:id, :source, :directoryID, :registeredDirectory, :protocol, :commonName, :endpoint, :name, :website, :country, :businessCategory, :vaspCategories, :verifiedOn, :ivms101, :created, :modified)" func (s *Store) CreateCounterparty(ctx context.Context, counterparty *models.Counterparty) (err error) { + // Basic validation if !ulids.IsZero(counterparty.ID) { return dberr.ErrNoIDOnCreate } @@ -94,15 +99,26 @@ func (s *Store) CreateCounterparty(ctx context.Context, counterparty *models.Cou } defer tx.Rollback() + // Update the model metadata in place and create a new ID counterparty.ID = ulids.New() counterparty.Created = time.Now() counterparty.Modified = counterparty.Created + // Insert the counterparty if _, err = tx.Exec(createCounterpartySQL, counterparty.Params()...); err != nil { // TODO: handle constraint violations return err } + // Insert any associated contacts with the counterparty + contacts, _ := counterparty.Contacts() + for _, contact := range contacts { + contact.CounterpartyID = counterparty.ID + if err = s.createContact(tx, contact); err != nil { + return err + } + } + return tx.Commit() } @@ -115,6 +131,19 @@ func (s *Store) RetrieveCounterparty(ctx context.Context, counterpartyID ulid.UL } defer tx.Rollback() + if counterparty, err = retrieveCounterparty(tx, counterpartyID); err != nil { + return nil, err + } + + if err = s.listContacts(tx, counterparty); err != nil { + return nil, err + } + + tx.Commit() + return counterparty, nil +} + +func retrieveCounterparty(tx *sql.Tx, counterpartyID ulid.ULID) (counterparty *models.Counterparty, err error) { counterparty = &models.Counterparty{} if err = counterparty.Scan(tx.QueryRow(retreiveCounterpartySQL, sql.Named("id", counterpartyID))); err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -122,8 +151,6 @@ func (s *Store) RetrieveCounterparty(ctx context.Context, counterpartyID ulid.UL } return nil, err } - - tx.Commit() return counterparty, nil } @@ -151,6 +178,7 @@ func (s *Store) LookupCounterparty(ctx context.Context, commonName string) (coun const updateCounterpartySQL = "UPDATE counterparties SET source=:source, directory_id=:directoryID, registered_directory=:registeredDirectory, protocol=:protocol, common_name=:commonName, endpoint=:endpoint, name=:name, website=:website, country=:country, business_category=:businessCategory, vasp_categories=:vaspCategories, verified_on=:verifiedOn, ivms101=:ivms101, modified=:modified WHERE id=:id" func (s *Store) UpdateCounterparty(ctx context.Context, counterparty *models.Counterparty) (err error) { + // Basic validation if ulids.IsZero(counterparty.ID) { return dberr.ErrMissingID } @@ -161,8 +189,10 @@ func (s *Store) UpdateCounterparty(ctx context.Context, counterparty *models.Cou } defer tx.Rollback() + // Update modified timestamp (in place). counterparty.Modified = time.Now() + // Execute the update into the database var result sql.Result if result, err = tx.Exec(updateCounterpartySQL, counterparty.Params()...); err != nil { // TODO: handle constraint violations @@ -192,3 +222,195 @@ func (s *Store) DeleteCounterparty(ctx context.Context, counterpartyID ulid.ULID return tx.Commit() } + +const listContactsSQL = "SELECT * FROM contacts WHERE counterparty_id=:counterpartyID" + +// List contacts associated with the specified counterparty. +func (s *Store) ListContacts(ctx context.Context, counterpartyID ulid.ULID, page *models.PageInfo) (out *models.ContactsPage, err error) { + var tx *sql.Tx + if tx, err = s.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}); err != nil { + return nil, err + } + defer tx.Rollback() + + // Check to ensure the associated counterparty exists + var counterparty *models.Counterparty + if counterparty, err = retrieveCounterparty(tx, counterpartyID); err != nil { + return nil, err + } + + // TODO: handle pagination + out = &models.ContactsPage{ + Contacts: make([]*models.Contact, 0), + Page: models.PageInfoFrom(page), + } + + var rows *sql.Rows + if rows, err = tx.Query(listContactsSQL, sql.Named("counterpartyID", counterpartyID)); err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + contact := &models.Contact{} + if err = contact.Scan(rows); err != nil { + return nil, err + } + + contact.SetCounterparty(counterparty) + out.Contacts = append(out.Contacts, contact) + } + + if errors.Is(rows.Err(), sql.ErrNoRows) { + return nil, dberr.ErrNotFound + } + + tx.Commit() + return out, nil +} + +func (s *Store) listContacts(tx *sql.Tx, counterparty *models.Counterparty) (err error) { + var rows *sql.Rows + if rows, err = tx.Query(listContactsSQL, sql.Named("counterpartyID", counterparty.ID)); err != nil { + return err + } + defer rows.Close() + + contacts := make([]*models.Contact, 0) + for rows.Next() { + contact := &models.Contact{} + if err = contact.Scan(rows); err != nil { + return err + } + + contact.SetCounterparty(counterparty) + contacts = append(contacts, contact) + } + + counterparty.SetContacts(contacts) + return nil +} + +const createContactSQL = "INSERT INTO contacts (id, name, email, role, counterparty_id, created, modified) VALUES (:id, :name, :email, :role, :counterpartyID, :created, :modified)" + +func (s *Store) CreateContact(ctx context.Context, contact *models.Contact) (err error) { + var tx *sql.Tx + if tx, err = s.BeginTx(ctx, nil); err != nil { + return err + } + defer tx.Rollback() + + if err = s.createContact(tx, contact); err != nil { + return err + } + + return tx.Commit() +} + +func (s *Store) createContact(tx *sql.Tx, contact *models.Contact) (err error) { + if !ulids.IsZero(contact.ID) { + return dberr.ErrNoIDOnCreate + } + + if ulids.IsZero(contact.CounterpartyID) { + return dberr.ErrMissingReference + } + + // Update the model metadata in place and create a new ID + contact.ID = ulids.New() + contact.Created = time.Now() + contact.Modified = contact.Created + + if _, err = tx.Exec(createContactSQL, contact.Params()...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return dberr.ErrNotFound + } + + // TODO: handle constraint violations + return err + } + return nil +} + +const retrieveContactSQL = "SELECT * FROM contacts WHERE id=:id and counterparty_id=:counterpartyID" + +func (s *Store) RetrieveContact(ctx context.Context, contactID, counterpartyID ulid.ULID) (contact *models.Contact, err error) { + var tx *sql.Tx + if tx, err = s.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}); err != nil { + return nil, err + } + defer tx.Rollback() + + contact = &models.Contact{} + if err = contact.Scan(tx.QueryRow(retrieveContactSQL, sql.Named("id", contactID), sql.Named("counterpartyID", counterpartyID))); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, dberr.ErrNotFound + } + return nil, err + } + + // TODO: retrieve counterparty and associate it with the contact. + + tx.Commit() + return contact, nil +} + +// TODO: this must be an upsert/delete since the data is being modified on the relation +const updateContactSQL = "UPDATE contacts SET name=:name, email=:email, role=:role, modified=:modified WHERE id=:id AND counterparty_id=:counterpartyID" + +func (s *Store) UpdateContact(ctx context.Context, contact *models.Contact) (err error) { + // Basic validation + if ulids.IsZero(contact.ID) { + return dberr.ErrMissingID + } + + if ulids.IsZero(contact.CounterpartyID) { + return dberr.ErrMissingReference + } + + var tx *sql.Tx + if tx, err = s.BeginTx(ctx, nil); err != nil { + return err + } + defer tx.Rollback() + + // Update modified timestamp (in place). + contact.Modified = time.Now() + + // Execute the update into the database + var result sql.Result + if result, err = tx.Exec(updateContactSQL, contact.Params()...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return dberr.ErrNotFound + } + + // TODO: handle constraint violations + return err + } else if nRows, _ := result.RowsAffected(); nRows == 0 { + return dberr.ErrNotFound + } + + return tx.Commit() +} + +const deleteContact = "DELETE FROM contacts WHERE id=:id AND counterparty_id=:counterpartyID" + +func (s *Store) DeleteContact(ctx context.Context, contactID, counterpartyID ulid.ULID) (err error) { + var tx *sql.Tx + if tx, err = s.BeginTx(ctx, nil); err != nil { + return err + } + defer tx.Rollback() + + var result sql.Result + if result, err = tx.Exec(deleteContact, sql.Named("id", contactID), sql.Named("counterpartyID", counterpartyID)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return dberr.ErrNotFound + } + return err + } else if nRows, _ := result.RowsAffected(); nRows == 0 { + return dberr.ErrNotFound + } + + return tx.Commit() +} diff --git a/pkg/store/store.go b/pkg/store/store.go index ada2c102..c784d5cc 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -43,6 +43,7 @@ type Store interface { TransactionStore AccountStore CounterpartyStore + ContactStore UserStore APIKeyStore } @@ -124,6 +125,14 @@ type CounterpartyStore interface { DeleteCounterparty(ctx context.Context, counterpartyID ulid.ULID) error } +type ContactStore interface { + ListContacts(ctx context.Context, counterpartyID ulid.ULID, page *models.PageInfo) (*models.ContactsPage, error) + CreateContact(context.Context, *models.Contact) error + RetrieveContact(ctx context.Context, contactID, counterpartyID ulid.ULID) (*models.Contact, error) + UpdateContact(context.Context, *models.Contact) error + DeleteContact(ctx context.Context, contactID, counterpartyID ulid.ULID) error +} + type TravelAddressStore interface { UseTravelAddressFactory(models.TravelAddressFactory) }