From d38d3a3855fdba46ed51c0c9f2254b36f5ba316a Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Fri, 31 May 2019 10:16:31 +0200 Subject: [PATCH] Move multidevice logic to its own namespace Multidevice has been extracted from the encryption logic in order to make things clearer. Some old endpoint have also been removed as not used anymore. --- Makefile | 2 +- api/backend.go | 87 ------ lib/library.go | 64 ----- mobile/status.go | 60 ---- services/shhext/api.go | 5 +- services/shhext/chat/db/migrations/bindata.go | 24 +- services/shhext/chat/encryption.go | 120 ++------ .../chat/encryption_multi_device_test.go | 105 ++++--- services/shhext/chat/encryption_test.go | 266 +++++++++++------- .../shhext/chat/multidevice/persistence.go | 12 + services/shhext/chat/multidevice/service.go | 94 +++++++ .../chat/multidevice/sql_lite_persistence.go | 168 +++++++++++ .../multidevice/sql_lite_persistence_test.go | 243 ++++++++++++++++ services/shhext/chat/persistence.go | 28 +- .../chat/{ => protobuf}/encryption.pb.go | 96 +++---- .../chat/{ => protobuf}/encryption.proto | 2 +- services/shhext/chat/protocol.go | 85 ++++-- services/shhext/chat/protocol_test.go | 15 +- services/shhext/chat/sql_lite_persistence.go | 193 ++----------- .../shhext/chat/sql_lite_persistence_test.go | 232 +-------------- services/shhext/chat/x3dh.go | 58 +--- services/shhext/chat/x3dh_test.go | 52 +--- services/shhext/filter/service.go | 11 +- services/shhext/filter/service_test.go | 5 +- services/shhext/service.go | 23 +- 25 files changed, 1005 insertions(+), 1045 deletions(-) create mode 100644 services/shhext/chat/multidevice/persistence.go create mode 100644 services/shhext/chat/multidevice/service.go create mode 100644 services/shhext/chat/multidevice/sql_lite_persistence.go create mode 100644 services/shhext/chat/multidevice/sql_lite_persistence_test.go rename services/shhext/chat/{ => protobuf}/encryption.pb.go (78%) rename services/shhext/chat/{ => protobuf}/encryption.proto (98%) diff --git a/Makefile b/Makefile index 2bfdfcbb0d5..5b8d58308c0 100644 --- a/Makefile +++ b/Makefile @@ -192,7 +192,7 @@ setup: setup-build setup-dev ##@other Prepare project for development and buildi generate: ##@other Regenerate assets and other auto-generated stuff go generate ./static ./static/chat_db_migrations ./static/mailserver_db_migrations - $(shell cd ./services/shhext/chat && exec protoc --go_out=. ./*.proto) + $(shell cd ./services/shhext/chat/protobuf && exec protoc --go_out=. ./*.proto) prepare-release: clean-release mkdir -p $(RELEASE_DIR) diff --git a/api/backend.go b/api/backend.go index 2ab6fdb10e9..ce4c61ddade 100644 --- a/api/backend.go +++ b/api/backend.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/hex" "errors" "fmt" "math/big" @@ -25,7 +24,6 @@ import ( "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/personal" "github.com/status-im/status-go/services/rpcfilters" - "github.com/status-im/status-go/services/shhext/chat" "github.com/status-im/status-go/services/shhext/chat/crypto" "github.com/status-im/status-go/services/shhext/filter" "github.com/status-im/status-go/services/subscriptions" @@ -616,91 +614,6 @@ func appendIf(condition bool, services []gethnode.ServiceConstructor, service ge return append(services, service) } -// CreateContactCode create or return the latest contact code -func (b *StatusBackend) CreateContactCode() (string, error) { - selectedChatAccount, err := b.AccountManager().SelectedChatAccount() - if err != nil { - return "", err - } - - st, err := b.statusNode.ShhExtService() - if err != nil { - return "", err - } - - bundle, err := st.GetBundle(selectedChatAccount.AccountKey.PrivateKey) - if err != nil { - return "", err - } - - return bundle.ToBase64() -} - -// GetContactCode return the latest contact code -func (b *StatusBackend) GetContactCode(identity string) (string, error) { - st, err := b.statusNode.ShhExtService() - if err != nil { - return "", err - } - - publicKeyBytes, err := hex.DecodeString(identity) - if err != nil { - return "", err - } - - publicKey, err := ethcrypto.UnmarshalPubkey(publicKeyBytes) - if err != nil { - return "", err - } - - bundle, err := st.GetPublicBundle(publicKey) - if err != nil { - return "", err - } - - if bundle == nil { - return "", nil - } - - return bundle.ToBase64() -} - -// ProcessContactCode process and adds the someone else's bundle -func (b *StatusBackend) ProcessContactCode(contactCode string) error { - selectedChatAccount, err := b.AccountManager().SelectedChatAccount() - if err != nil { - return err - } - - st, err := b.statusNode.ShhExtService() - if err != nil { - return err - } - - bundle, err := chat.FromBase64(contactCode) - if err != nil { - b.log.Error("error decoding base64", "err", err) - return err - } - - if _, err := st.ProcessPublicBundle(selectedChatAccount.AccountKey.PrivateKey, bundle); err != nil { - b.log.Error("error adding bundle", "err", err) - return err - } - - return nil -} - -// ExtractIdentityFromContactCode extract the identity of the user generating the contact code -func (b *StatusBackend) ExtractIdentityFromContactCode(contactCode string) (string, error) { - bundle, err := chat.FromBase64(contactCode) - if err != nil { - return "", err - } - - return chat.ExtractIdentity(bundle) -} - // ExtractGroupMembershipSignatures extract signatures from tuples of content/signature func (b *StatusBackend) ExtractGroupMembershipSignatures(signaturePairs [][2]string) ([]string, error) { return crypto.ExtractSignatures(signaturePairs) diff --git a/lib/library.go b/lib/library.go index e5051d2e2a8..91d76af0e30 100644 --- a/lib/library.go +++ b/lib/library.go @@ -52,70 +52,6 @@ func StopNode() *C.char { return makeJSONResponse(nil) } -// Create an X3DH bundle -//export CreateContactCode -func CreateContactCode() *C.char { - bundle, err := statusBackend.CreateContactCode() - if err != nil { - return makeJSONResponse(err) - } - - cstr := C.CString(bundle) - - return cstr -} - -//export ProcessContactCode -func ProcessContactCode(bundleString *C.char) *C.char { - err := statusBackend.ProcessContactCode(C.GoString(bundleString)) - if err != nil { - return makeJSONResponse(err) - } - - return nil -} - -// Get an X3DH bundle -//export GetContactCode -func GetContactCode(identityString *C.char) *C.char { - bundle, err := statusBackend.GetContactCode(C.GoString(identityString)) - if err != nil { - return makeJSONResponse(err) - } - - data, err := json.Marshal(struct { - ContactCode string `json:"code"` - }{ContactCode: bundle}) - if err != nil { - return makeJSONResponse(err) - } - - return C.CString(string(data)) -} - -//export ExtractIdentityFromContactCode -func ExtractIdentityFromContactCode(bundleString *C.char) *C.char { - bundle := C.GoString(bundleString) - - identity, err := statusBackend.ExtractIdentityFromContactCode(bundle) - if err != nil { - return makeJSONResponse(err) - } - - if err := statusBackend.ProcessContactCode(bundle); err != nil { - return makeJSONResponse(err) - } - - data, err := json.Marshal(struct { - Identity string `json:"identity"` - }{Identity: identity}) - if err != nil { - return makeJSONResponse(err) - } - - return C.CString(string(data)) -} - // LoadFilters load all whisper filters //export LoadFilters func LoadFilters(chatsStr *C.char) *C.char { diff --git a/mobile/status.go b/mobile/status.go index e33f98845ef..24ee9b6b922 100644 --- a/mobile/status.go +++ b/mobile/status.go @@ -65,48 +65,6 @@ func StopNode() string { return makeJSONResponse(nil) } -// CreateContactCode creates an X3DH bundle. -func CreateContactCode() string { - bundle, err := statusBackend.CreateContactCode() - if err != nil { - return makeJSONResponse(err) - } - - return bundle -} - -// ProcessContactCode processes an X3DH bundle. -// TODO(adam): it looks like the return should be error. -func ProcessContactCode(bundle string) string { - err := statusBackend.ProcessContactCode(bundle) - if err != nil { - return makeJSONResponse(err) - } - - return "" -} - -// ExtractIdentityFromContactCode extracts an identity from an X3DH bundle. -func ExtractIdentityFromContactCode(bundle string) string { - identity, err := statusBackend.ExtractIdentityFromContactCode(bundle) - if err != nil { - return makeJSONResponse(err) - } - - if err := statusBackend.ProcessContactCode(bundle); err != nil { - return makeJSONResponse(err) - } - - data, err := json.Marshal(struct { - Identity string `json:"identity"` - }{Identity: identity}) - if err != nil { - return makeJSONResponse(err) - } - - return string(data) -} - // ExtractGroupMembershipSignatures extract public keys from tuples of content/signature. func ExtractGroupMembershipSignatures(signaturePairsStr string) string { var signaturePairs [][2]string @@ -618,24 +576,6 @@ func SetSignalEventCallback(cb unsafe.Pointer) { signal.SetSignalEventCallback(cb) } -// Get an X3DH bundle -//export GetContactCode -func GetContactCode(identity string) string { - bundle, err := statusBackend.GetContactCode(identity) - if err != nil { - return makeJSONResponse(err) - } - - data, err := json.Marshal(struct { - ContactCode string `json:"code"` - }{ContactCode: bundle}) - if err != nil { - return makeJSONResponse(err) - } - - return string(data) -} - // ExportNodeLogs reads current node log and returns content to a caller. //export ExportNodeLogs func ExportNodeLogs() string { diff --git a/services/shhext/api.go b/services/shhext/api.go index 961b241d895..492e989e86e 100644 --- a/services/shhext/api.go +++ b/services/shhext/api.go @@ -21,6 +21,7 @@ import ( "github.com/status-im/status-go/db" "github.com/status-im/status-go/mailserver" "github.com/status-im/status-go/services/shhext/chat" + "github.com/status-im/status-go/services/shhext/chat/protobuf" "github.com/status-im/status-go/services/shhext/dedup" "github.com/status-im/status-go/services/shhext/filter" "github.com/status-im/status-go/services/shhext/mailservers" @@ -502,7 +503,7 @@ func (api *PublicAPI) SendDirectMessage(ctx context.Context, msg chat.SendDirect } // This is transport layer-agnostic - var protocolMessage *chat.ProtocolMessage + var protocolMessage *protobuf.ProtocolMessage // The negotiated secret var msgSpec *chat.ProtocolMessageSpec var partitionedTopicSupported bool @@ -688,7 +689,7 @@ func (api *PublicAPI) processPFSMessage(dedupMessage dedup.DeduplicateMessage) e } // Unmarshal message - protocolMessage := &chat.ProtocolMessage{} + protocolMessage := &protobuf.ProtocolMessage{} if err := proto.Unmarshal(msg.Payload, protocolMessage); err != nil { api.log.Debug("Not a protocol message", "err", err) diff --git a/services/shhext/chat/db/migrations/bindata.go b/services/shhext/chat/db/migrations/bindata.go index 8cb709743e3..5d4ca9a3b12 100644 --- a/services/shhext/chat/db/migrations/bindata.go +++ b/services/shhext/chat/db/migrations/bindata.go @@ -94,7 +94,7 @@ func _1536754952_initial_schemaDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -114,7 +114,7 @@ func _1536754952_initial_schemaUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -134,7 +134,7 @@ func _1539249977_update_ratchet_infoDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -154,7 +154,7 @@ func _1539249977_update_ratchet_infoUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -174,7 +174,7 @@ func _1540715431_add_versionDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -194,7 +194,7 @@ func _1540715431_add_versionUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -214,7 +214,7 @@ func _1541164797_add_installationsDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1541164797_add_installations.down.sql", size: 26, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1541164797_add_installations.down.sql", size: 26, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -234,7 +234,7 @@ func _1541164797_add_installationsUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1541164797_add_installations.up.sql", size: 216, mode: os.FileMode(420), modTime: time.Unix(1557996559, 0)} + info := bindataFileInfo{name: "1541164797_add_installations.up.sql", size: 216, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -254,7 +254,7 @@ func _1558084410_add_topicDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1558084410_add_topic.down.sql", size: 54, mode: os.FileMode(420), modTime: time.Unix(1558084748, 0)} + info := bindataFileInfo{name: "1558084410_add_topic.down.sql", size: 54, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -274,7 +274,7 @@ func _1558084410_add_topicUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1558084410_add_topic.up.sql", size: 298, mode: os.FileMode(420), modTime: time.Unix(1558091116, 0)} + info := bindataFileInfo{name: "1558084410_add_topic.up.sql", size: 298, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -294,7 +294,7 @@ func _1558588866_add_versionUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1558588866_add_version.up.sql", size: 57, mode: os.FileMode(420), modTime: time.Unix(1558588995, 0)} + info := bindataFileInfo{name: "1558588866_add_version.up.sql", size: 57, mode: os.FileMode(420), modTime: time.Unix(1558598292, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -314,7 +314,7 @@ func staticGo() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static.go", size: 191, mode: os.FileMode(420), modTime: time.Unix(1558084389, 0)} + info := bindataFileInfo{name: "static.go", size: 191, mode: os.FileMode(420), modTime: time.Unix(1558598279, 0)} a := &asset{bytes: bytes, info: info} return a, nil } diff --git a/services/shhext/chat/encryption.go b/services/shhext/chat/encryption.go index 7160641080b..b01e8fe1edd 100644 --- a/services/shhext/chat/encryption.go +++ b/services/shhext/chat/encryption.go @@ -5,7 +5,6 @@ import ( "crypto/ecdsa" "encoding/hex" "errors" - "fmt" "sync" "time" @@ -15,6 +14,8 @@ import ( dr "github.com/status-im/doubleratchet" "github.com/status-im/status-go/services/shhext/chat/crypto" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" ) var ErrSessionNotFound = errors.New("session not found") @@ -51,8 +52,6 @@ type EncryptionServiceConfig struct { BundleRefreshInterval int64 } -type IdentityAndIDPair [2]string - // DefaultEncryptionServiceConfig returns the default values used by the encryption service func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConfig { return EncryptionServiceConfig{ @@ -132,19 +131,9 @@ func (s *EncryptionService) ConfirmMessagesProcessed(messageIDs [][]byte) error } // CreateBundle retrieves or creates an X3DH bundle given a private key -func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) { +func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) { ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey) - installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations-1, ourIdentityKeyC) - if err != nil { - return nil, err - } - - installations = append(installations, &Installation{ - ID: s.config.InstallationID, - Version: protocolCurrentVersion, - }) - bundleContainer, err := s.persistence.GetAnyPrivateBundle(ourIdentityKeyC, installations) if err != nil { return nil, err @@ -176,7 +165,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, return nil, err } - return s.CreateBundle(privateKey) + return s.CreateBundle(privateKey, installations) } // DecryptWithDH decrypts message sent with a DH key exchange, and throws away the key after decryption @@ -224,55 +213,13 @@ func (s *EncryptionService) keyFromPassiveX3DH(myIdentityKey *ecdsa.PrivateKey, return key, nil } -func (s *EncryptionService) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { - myIdentityKeyC := ecrypto.CompressPubkey(myIdentityKey) - return s.persistence.EnableInstallation(myIdentityKeyC, installationID) -} - -func (s *EncryptionService) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { - myIdentityKeyC := ecrypto.CompressPubkey(myIdentityKey) - return s.persistence.DisableInstallation(myIdentityKeyC, installationID) -} - -// ProcessPublicBundle persists a bundle and returns a list of tuples identity/installationID -func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, b *Bundle) ([]IdentityAndIDPair, error) { - // Make sure the bundle belongs to who signed it - identity, err := ExtractIdentity(b) - if err != nil { - return nil, err - } - signedPreKeys := b.GetSignedPreKeys() - var response []IdentityAndIDPair - var installations []*Installation - myIdentityStr := fmt.Sprintf("0x%x", ecrypto.FromECDSAPub(&myIdentityKey.PublicKey)) - - // Any device from other peers will be considered enabled, ours needs to - // be explicitly enabled - fromOurIdentity := identity != myIdentityStr - - for installationID, signedPreKey := range signedPreKeys { - if installationID != s.config.InstallationID { - installations = append(installations, &Installation{ - ID: installationID, - Version: signedPreKey.GetProtocolVersion(), - }) - response = append(response, IdentityAndIDPair{identity, installationID}) - } - } - - if err = s.persistence.AddInstallations(b.GetIdentity(), b.GetTimestamp(), installations, fromOurIdentity); err != nil { - return nil, err - } - - if err = s.persistence.AddPublicBundle(b); err != nil { - return nil, err - } - - return response, nil +// ProcessPublicBundle persists a bundle +func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, b *protobuf.Bundle) error { + return s.persistence.AddPublicBundle(b) } // DecryptPayload decrypts the payload of a DirectMessageProtocol, given an identity private key and the sender's public key -func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol, messageID []byte) ([]byte, error) { +func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*protobuf.DirectMessageProtocol, messageID []byte) ([]byte, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -329,9 +276,9 @@ func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, thei } // Add installations with a timestamp of 0, as we don't have bundle informations - if err = s.persistence.AddInstallations(theirIdentityKeyC, 0, []*Installation{{ID: theirInstallationID, Version: 0}}, true); err != nil { - return nil, err - } + //if err = s.persistence.AddInstallations(theirIdentityKeyC, 0, []*Installation{{ID: theirInstallationID, Version: 0}}, true); err != nil { + // return nil, err + // } // We mark the exchange as successful so we stop sending x3dh header if err = s.persistence.RatchetInfoConfirmed(drHeader.GetId(), theirIdentityKeyC, theirInstallationID); err != nil { @@ -396,7 +343,7 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k return session, err } -func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, drInfo *RatchetInfo, payload []byte) ([]byte, *DRHeader, error) { +func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, drInfo *RatchetInfo, payload []byte) ([]byte, *protobuf.DRHeader, error) { var err error var session dr.Session @@ -430,7 +377,7 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr return nil, nil, err } - header := &DRHeader{ + header := &protobuf.DRHeader{ Id: drInfo.BundleID, Key: response.Header.DH[:], N: response.Header.N, @@ -474,7 +421,7 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr return plaintext, nil } -func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (*DirectMessageProtocol, error) { +func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (*protobuf.DirectMessageProtocol, error) { symmetricKey, ourEphemeralKey, err := PerformActiveDH(theirIdentityKey) if err != nil { return nil, err @@ -485,16 +432,16 @@ func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, pay return nil, err } - return &DirectMessageProtocol{ - DHHeader: &DHHeader{ + return &protobuf.DirectMessageProtocol{ + DHHeader: &protobuf.DHHeader{ Key: ecrypto.CompressPubkey(ourEphemeralKey), }, Payload: encryptedPayload, }, nil } -func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (map[string]*DirectMessageProtocol, error) { - response := make(map[string]*DirectMessageProtocol) +func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (map[string]*protobuf.DirectMessageProtocol, error) { + response := make(map[string]*protobuf.DirectMessageProtocol) dmp, err := s.encryptWithDH(theirIdentityKey, payload) if err != nil { return nil, err @@ -505,23 +452,16 @@ func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicK } // GetPublicBundle returns the active installations bundles for a given user -func (s *EncryptionService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*Bundle, error) { - theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey) - - installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations, theirIdentityKeyC) - if err != nil { - return nil, err - } - +func (s *EncryptionService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) { return s.persistence.GetPublicBundle(theirIdentityKey, installations) } // EncryptPayload returns a new DirectMessageProtocol with a given payload encrypted, given a recipient's public key and the sender private identity key // TODO: refactor this // nolint: gocyclo -func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, myIdentityKey *ecdsa.PrivateKey, payload []byte) (map[string]*DirectMessageProtocol, []*Installation, error) { +func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, myIdentityKey *ecdsa.PrivateKey, installations []*multidevice.Installation, payload []byte) (map[string]*protobuf.DirectMessageProtocol, []*multidevice.Installation, error) { // Which installations we are sending the message to - var targetedInstallations []*Installation + var targetedInstallations []*multidevice.Installation s.mutex.Lock() defer s.mutex.Unlock() @@ -530,19 +470,13 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey) - // Get their installationIds - installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations, theirIdentityKeyC) - if err != nil { - return nil, nil, err - } - // We don't have any, send a message with DH if installations == nil && !bytes.Equal(theirIdentityKeyC, ecrypto.CompressPubkey(&myIdentityKey.PublicKey)) { encryptedPayload, err := s.EncryptPayloadWithDH(theirIdentityKey, payload) return encryptedPayload, targetedInstallations, err } - response := make(map[string]*DirectMessageProtocol) + response := make(map[string]*protobuf.DirectMessageProtocol) for _, installation := range installations { installationID := installation.ID @@ -550,7 +484,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my if s.config.InstallationID == installationID { continue } - bundle, err := s.persistence.GetPublicBundle(theirIdentityKey, []*Installation{installation}) + bundle, err := s.persistence.GetPublicBundle(theirIdentityKey, []*multidevice.Installation{installation}) if err != nil { return nil, nil, err } @@ -570,13 +504,13 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my return nil, nil, err } - dmp := DirectMessageProtocol{ + dmp := protobuf.DirectMessageProtocol{ Payload: encryptedPayload, DRHeader: drHeader, } if drInfo.EphemeralKey != nil { - dmp.X3DHHeader = &X3DHHeader{ + dmp.X3DHHeader = &protobuf.X3DHHeader{ Key: drInfo.EphemeralKey, Id: drInfo.BundleID, } @@ -610,7 +544,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my return nil, nil, err } - x3dhHeader := &X3DHHeader{ + x3dhHeader := &protobuf.X3DHHeader{ Key: ourEphemeralKeyC, Id: theirSignedPreKey, } @@ -626,7 +560,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my return nil, nil, err } - dmp := &DirectMessageProtocol{ + dmp := &protobuf.DirectMessageProtocol{ Payload: encryptedPayload, X3DHHeader: x3dhHeader, DRHeader: drHeader, diff --git a/services/shhext/chat/encryption_multi_device_test.go b/services/shhext/chat/encryption_multi_device_test.go index a3e12b4c7e4..163b91337d8 100644 --- a/services/shhext/chat/encryption_multi_device_test.go +++ b/services/shhext/chat/encryption_multi_device_test.go @@ -8,6 +8,9 @@ import ( "os" "sort" "testing" + + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/topic" ) const ( @@ -20,8 +23,8 @@ func TestEncryptionServiceMultiDeviceSuite(t *testing.T) { } type serviceAndKey struct { - encryptionServices []*EncryptionService - key *ecdsa.PrivateKey + services []*ProtocolService + key *ecdsa.PrivateKey } type EncryptionServiceMultiDeviceSuite struct { @@ -36,8 +39,8 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error { } s.services[user] = &serviceAndKey{ - key: key, - encryptionServices: make([]*EncryptionService, n), + key: key, + services: make([]*ProtocolService, n), } for i := 0; i < n; i++ { @@ -50,11 +53,27 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error { if err != nil { return err } + // Initialize topics + multideviceConfig := &multidevice.Config{ + MaxInstallations: n - 1, + InstallationID: installationID, + ProtocolVersion: 1, + } + + topicService := topic.NewService(persistence.GetTopicStorage()) + multideviceService := multidevice.New(multideviceConfig, persistence.GetMultideviceStorage()) - config := DefaultEncryptionServiceConfig(installationID) - config.MaxInstallations = n - 1 + protocol := NewProtocolService( + NewEncryptionService( + persistence, + DefaultEncryptionServiceConfig(installationID)), + topicService, + multideviceService, + func(s []multidevice.IdentityAndIDPair) {}, + func(s []*topic.Secret) {}, + ) - s.services[user].encryptionServices[i] = NewEncryptionService(persistence, config) + s.services[user].services[i] = protocol } @@ -73,43 +92,47 @@ func (s *EncryptionServiceMultiDeviceSuite) SetupTest() { func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() { aliceKey := s.services[aliceUser].key - alice2Bundle, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) + alice2Bundle, err := s.services[aliceUser].services[1].GetBundle(aliceKey) s.Require().NoError(err) - alice2Identity, err := ExtractIdentity(alice2Bundle) + alice2IdentityPK, err := ExtractIdentity(alice2Bundle) s.Require().NoError(err) - alice3Bundle, err := s.services[aliceUser].encryptionServices[2].CreateBundle(aliceKey) + alice2Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice2IdentityPK)) + + alice3Bundle, err := s.services[aliceUser].services[2].GetBundle(aliceKey) s.Require().NoError(err) - alice3Identity, err := ExtractIdentity(alice2Bundle) + alice3IdentityPK, err := ExtractIdentity(alice2Bundle) s.Require().NoError(err) + alice3Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice3IdentityPK)) + // Add alice2 bundle - response, err := s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice2Bundle) + response, err := s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice2Bundle) s.Require().NoError(err) - s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) + s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) // Add alice3 bundle - response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice3Bundle) + response, err = s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice3Bundle) s.Require().NoError(err) - s.Require().Equal(IdentityAndIDPair{alice3Identity, "alice3"}, response[0]) + s.Require().Equal(multidevice.IdentityAndIDPair{alice3Identity, "alice3"}, response[0]) // No installation is enabled - alice1MergedBundle1, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) + alice1MergedBundle1, err := s.services[aliceUser].services[0].GetBundle(aliceKey) s.Require().NoError(err) s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys())) s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"]) // We enable the installations - err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice2") + err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice2") s.Require().NoError(err) - err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice3") + err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice3") s.Require().NoError(err) - alice1MergedBundle2, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) + alice1MergedBundle2, err := s.services[aliceUser].services[0].GetBundle(aliceKey) s.Require().NoError(err) // We get back a bundle with all the installations @@ -118,21 +141,21 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() { s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"]) s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"]) - response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice1MergedBundle2) + response, err = s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice1MergedBundle2) s.Require().NoError(err) sort.Slice(response, func(i, j int) bool { return response[i][1] < response[j][1] }) // We only get back installationIDs not equal to us s.Require().Equal(2, len(response)) - s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) - s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice3"}, response[1]) + s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) + s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice3"}, response[1]) // We disable the installations - err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice2") + err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice2") s.Require().NoError(err) - alice1MergedBundle3, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) + alice1MergedBundle3, err := s.services[aliceUser].services[0].GetBundle(aliceKey) s.Require().NoError(err) // We get back a bundle with all the installations @@ -146,23 +169,23 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder() s.Require().NoError(err) // Alice1 creates a bundle - alice1Bundle, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) + alice1Bundle, err := s.services[aliceUser].services[0].GetBundle(aliceKey) s.Require().NoError(err) // Alice2 Receives the bundle - _, err = s.services[aliceUser].encryptionServices[1].ProcessPublicBundle(aliceKey, alice1Bundle) + _, err = s.services[aliceUser].services[1].ProcessPublicBundle(aliceKey, alice1Bundle) s.Require().NoError(err) // Alice2 Creates a Bundle - _, err = s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) + _, err = s.services[aliceUser].services[1].GetBundle(aliceKey) s.Require().NoError(err) // We enable the installation - err = s.services[aliceUser].encryptionServices[1].EnableInstallation(&aliceKey.PublicKey, "alice1") + err = s.services[aliceUser].services[1].EnableInstallation(&aliceKey.PublicKey, "alice1") s.Require().NoError(err) // It should contain both bundles - alice2MergedBundle1, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) + alice2MergedBundle1, err := s.services[aliceUser].services[1].GetBundle(aliceKey) s.Require().NoError(err) s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"]) @@ -170,9 +193,9 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder() } func pairDevices(s *serviceAndKey, target int) error { - device := s.encryptionServices[target] - for i := 0; i < len(s.encryptionServices); i++ { - b, err := s.encryptionServices[i].CreateBundle(s.key) + device := s.services[target] + for i := 0; i < len(s.services); i++ { + b, err := s.services[i].GetBundle(s.key) if err != nil { return err @@ -183,7 +206,7 @@ func pairDevices(s *serviceAndKey, target int) error { return err } - err = device.EnableInstallation(&s.key.PublicKey, s.encryptionServices[i].config.InstallationID) + err = device.EnableInstallation(&s.key.PublicKey, s.services[i].encryption.config.InstallationID) if err != nil { return nil } @@ -194,14 +217,14 @@ func pairDevices(s *serviceAndKey, target int) error { func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() { err := pairDevices(s.services[aliceUser], 0) s.Require().NoError(err) - alice1 := s.services[aliceUser].encryptionServices[0] - bob1 := s.services[bobUser].encryptionServices[0] + alice1 := s.services[aliceUser].services[0] + bob1 := s.services[bobUser].services[0] aliceKey := s.services[aliceUser].key bobKey := s.services[bobUser].key // Check bundle is ok // No installation is enabled - aliceBundle, err := alice1.CreateBundle(aliceKey) + aliceBundle, err := alice1.GetBundle(aliceKey) s.Require().NoError(err) // Check all installations are correctly working, and that the oldest device is not there @@ -218,19 +241,20 @@ func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() { s.Require().NoError(err) // Bob sends a message to alice - payload, _, err := bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test")) + msg, err := bob1.BuildDirectMessage(bobKey, &aliceKey.PublicKey, []byte("test")) s.Require().NoError(err) + payload := msg.Message.GetDirectMessage() s.Require().Equal(3, len(payload)) s.Require().NotNil(payload["alice1"]) s.Require().NotNil(payload["alice3"]) s.Require().NotNil(payload["alice4"]) // We disable the last installation - err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice4") + err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice4") s.Require().NoError(err) // We check the bundle is updated - aliceBundle, err = alice1.CreateBundle(aliceKey) + aliceBundle, err = alice1.GetBundle(aliceKey) s.Require().NoError(err) // Check all installations are there @@ -247,8 +271,9 @@ func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() { s.Require().NoError(err) // Bob sends a message to alice - payload, _, err = bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test")) + msg, err = bob1.BuildDirectMessage(bobKey, &aliceKey.PublicKey, []byte("test")) s.Require().NoError(err) + payload = msg.Message.GetDirectMessage() s.Require().Equal(3, len(payload)) s.Require().NotNil(payload["alice1"]) s.Require().NotNil(payload["alice2"]) diff --git a/services/shhext/chat/encryption_test.go b/services/shhext/chat/encryption_test.go index 050a91fbf80..3917d2e3bdd 100644 --- a/services/shhext/chat/encryption_test.go +++ b/services/shhext/chat/encryption_test.go @@ -13,6 +13,9 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" + "github.com/status-im/status-go/services/shhext/chat/topic" "github.com/stretchr/testify/suite" ) @@ -27,8 +30,8 @@ func TestEncryptionServiceTestSuite(t *testing.T) { type EncryptionServiceTestSuite struct { suite.Suite - alice *EncryptionService - bob *EncryptionService + alice *ProtocolService + bob *ProtocolService aliceDBPath string bobDBPath string } @@ -56,21 +59,57 @@ func (s *EncryptionServiceTestSuite) initDatabases(baseConfig *EncryptionService bobDBKey = "bob" ) + aliceMultideviceConfig := &multidevice.Config{ + MaxInstallations: 3, + InstallationID: aliceInstallationID, + ProtocolVersion: 1, + } + alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey) if err != nil { panic(err) } + baseConfig.InstallationID = aliceInstallationID + aliceEncryptionService := NewEncryptionService(alicePersistence, *baseConfig) + + aliceTopicService := topic.NewService(alicePersistence.GetTopicStorage()) + aliceMultideviceService := multidevice.New(aliceMultideviceConfig, alicePersistence.GetMultideviceStorage()) + + s.alice = NewProtocolService( + aliceEncryptionService, + aliceTopicService, + aliceMultideviceService, + func(s []multidevice.IdentityAndIDPair) {}, + func(s []*topic.Secret) {}, + ) + bobPersistence, err := NewSQLLitePersistence(bobDBPath, bobDBKey) if err != nil { panic(err) } - baseConfig.InstallationID = aliceInstallationID - s.alice = NewEncryptionService(alicePersistence, *baseConfig) + bobMultideviceConfig := &multidevice.Config{ + MaxInstallations: 3, + InstallationID: bobInstallationID, + ProtocolVersion: 1, + } + + bobMultideviceService := multidevice.New(bobMultideviceConfig, bobPersistence.GetMultideviceStorage()) + + bobTopicService := topic.NewService(bobPersistence.GetTopicStorage()) baseConfig.InstallationID = bobInstallationID - s.bob = NewEncryptionService(bobPersistence, *baseConfig) + bobEncryptionService := NewEncryptionService(bobPersistence, *baseConfig) + + s.bob = NewProtocolService( + bobEncryptionService, + bobTopicService, + bobMultideviceService, + func(s []multidevice.IdentityAndIDPair) {}, + func(s []*topic.Secret) {}, + ) + } func (s *EncryptionServiceTestSuite) SetupTest() { @@ -82,14 +121,14 @@ func (s *EncryptionServiceTestSuite) TearDownTest() { os.Remove(s.bobDBPath) } -func (s *EncryptionServiceTestSuite) TestCreateBundle() { +func (s *EncryptionServiceTestSuite) TestGetBundle() { aliceKey, err := crypto.GenerateKey() s.Require().NoError(err) - aliceBundle1, err := s.alice.CreateBundle(aliceKey) + aliceBundle1, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) s.NotNil(aliceBundle1, "It creates a bundle") - aliceBundle2, err := s.alice.CreateBundle(aliceKey) + aliceBundle2, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) s.Equal(aliceBundle1, aliceBundle2, "It returns the same bundle") } @@ -105,9 +144,11 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() { aliceKey, err := crypto.GenerateKey() s.Require().NoError(err) - encryptionResponse1, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) + response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext) s.Require().NoError(err) + encryptionResponse1 := response1.Message.GetDirectMessage() + installationResponse1 := encryptionResponse1["none"] // That's for any device s.Require().NotNil(installationResponse1) @@ -119,14 +160,16 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() { s.NotEqual(cyphertext1, cleartext, "It encrypts the payload correctly") // On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent - decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) + decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID) s.Require().NoError(err) s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH") // The next message will not be re-using the same key - encryptionResponse2, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) + response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext) s.Require().NoError(err) + encryptionResponse2 := response2.Message.GetDirectMessage() + installationResponse2 := encryptionResponse2[aliceInstallationID] cyphertext2 := installationResponse2.GetPayload() @@ -134,7 +177,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() { s.NotEqual(cyphertext1, cyphertext2, "It does not re-use the symmetric key") s.NotEqual(ephemeralKey1, ephemeralKey2, "It does not re-use the ephemeral key") - decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID) + decryptedPayload2, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID) s.Require().NoError(err) s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH") } @@ -150,7 +193,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -158,9 +201,11 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() { s.Require().NoError(err) // We send a message using the bundle - encryptionResponse1, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) + response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext) s.Require().NoError(err) + encryptionResponse1 := response1.Message.GetDirectMessage() + installationResponse1 := encryptionResponse1[bobInstallationID] s.Require().NotNil(installationResponse1) @@ -186,7 +231,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() { s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain") // Bob is able to decrypt it using the bundle - decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) + decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID) s.Require().NoError(err) s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH") } @@ -209,7 +254,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -217,12 +262,13 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() { s.Require().NoError(err) // We send a message using the bundle - _, _, err = s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext1) + _, err = s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext1) s.Require().NoError(err) // We send another message using the bundle - encryptionResponse, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext2) + response, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext2) s.Require().NoError(err) + encryptionResponse := response.Message.GetDirectMessage() installationResponse := encryptionResponse[bobInstallationID] s.Require().NotNil(installationResponse) @@ -250,7 +296,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() { s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain") // Bob is able to decrypt it using the bundle - decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) + decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID) s.Require().NoError(err) s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") @@ -274,11 +320,11 @@ func (s *EncryptionServiceTestSuite) TestConversation() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add bob bundle @@ -290,24 +336,25 @@ func (s *EncryptionServiceTestSuite) TestConversation() { s.Require().NoError(err) // Alice sends a message - encryptionResponse, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext1) + response, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext1) s.Require().NoError(err) // Bob receives the message - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID) s.Require().NoError(err) // Bob replies to the message - encryptionResponse, _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, cleartext1) + response, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, cleartext1) s.Require().NoError(err) // Alice receives the message - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, response.Message, defaultMessageID) s.Require().NoError(err) // We send another message using the bundle - encryptionResponse, _, err = s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext2) + response, err = s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext2) s.Require().NoError(err) + encryptionResponse := response.Message.GetDirectMessage() installationResponse := encryptionResponse[bobInstallationID] s.Require().NotNil(installationResponse) @@ -333,7 +380,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() { s.Equal(uint32(1), drHeader.GetPn(), "It adds the correct length of the message chain") // Bob is able to decrypt it using the bundle - decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) + decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID) s.Require().NoError(err) s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") @@ -354,7 +401,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -362,7 +409,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -371,30 +418,30 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() { // Bob sends a message - for i := 0; i < s.alice.config.MaxSkip; i++ { - _, _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + for i := 0; i < s.alice.encryption.config.MaxSkip; i++ { + _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) } // Bob sends a message - bobMessage1, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) // Alice receives the message - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID) s.Require().NoError(err) // Bob sends a message - _, _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) // Bob sends a message - bobMessage2, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) // Alice receives the message, we should have maxSkip + 1 keys in the db, but // we should not throw an error - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, defaultMessageID) s.Require().NoError(err) } @@ -409,7 +456,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -417,7 +464,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -426,17 +473,17 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() { // Bob sends a message - for i := 0; i < s.alice.config.MaxSkip+1; i++ { - _, _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + for i := 0; i < s.alice.encryption.config.MaxSkip+1; i++ { + _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) } // Bob sends a message - bobMessage1, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) // Alice receives the message - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID) s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err) } @@ -457,7 +504,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -465,7 +512,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -474,27 +521,27 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() { // We create just enough messages so that the first key should be deleted - nMessages := s.alice.config.MaxMessageKeysPerSession - messages := make([]map[string]*DirectMessageProtocol, nMessages) + nMessages := s.alice.encryption.config.MaxMessageKeysPerSession + messages := make([]*protobuf.ProtocolMessage, nMessages) for i := 0; i < nMessages; i++ { - m, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) - messages[i] = m + messages[i] = m.Message } // Another message to trigger the deletion - m, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) s.Require().NoError(err) - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, m.Message, defaultMessageID) s.Require().NoError(err) // We decrypt the first message, and it should fail - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[0], defaultMessageID) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) // We decrypt the second message, and it should be decrypted - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[1], defaultMessageID) s.Require().NoError(err) } @@ -514,7 +561,7 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -522,7 +569,7 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -530,15 +577,15 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() { s.Require().NoError(err) // We decrypt all messages but 1 & 2 - messages := make([]map[string]*DirectMessageProtocol, s.alice.config.MaxKeep) - for i := 0; i < s.alice.config.MaxKeep; i++ { - m, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) - messages[i] = m + messages := make([]*protobuf.ProtocolMessage, s.alice.encryption.config.MaxKeep) + for i := 0; i < s.alice.encryption.config.MaxKeep; i++ { + m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText) + messages[i] = m.Message s.Require().NoError(err) if i != 0 && i != 1 { messageID := []byte(fmt.Sprintf("%d", i)) - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, messageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, m.Message, messageID) s.Require().NoError(err) err = s.alice.ConfirmMessagesProcessed([][]byte{messageID}) s.Require().NoError(err) @@ -547,11 +594,11 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() { } // We decrypt the first message, and it should fail, as it should have been removed - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[0], defaultMessageID) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) // We decrypt the second message, and it should be decrypted - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[1], defaultMessageID) s.Require().NoError(err) } @@ -576,7 +623,7 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -584,7 +631,7 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -592,44 +639,44 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() { s.Require().NoError(err) // Alice sends a message - aliceMessage1, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText1) + aliceMessage1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText1) s.Require().NoError(err) // Bob sends a message - bobMessage1, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) + bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1) s.Require().NoError(err) // Bob receives the message - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage1.Message, defaultMessageID) s.Require().NoError(err) // Alice receives the message - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID) s.Require().NoError(err) // Bob replies to the message - bobMessage2, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText2) + bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText2) s.Require().NoError(err) // Alice sends a message - aliceMessage2, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText2) + aliceMessage2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText2) s.Require().NoError(err) // Alice receives the message - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, defaultMessageID) s.Require().NoError(err) // Bob receives the message - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage2.Message, defaultMessageID) s.Require().NoError(err) } func publisher( - e *EncryptionService, + e *ProtocolService, privateKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, errChan chan error, - output chan map[string]*DirectMessageProtocol, + output chan *protobuf.ProtocolMessage, ) { var wg sync.WaitGroup @@ -642,13 +689,13 @@ func publisher( go func() { defer wg.Done() time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) - response, _, err := e.EncryptPayload(publicKey, privateKey, cleartext) + response, err := e.BuildDirectMessage(privateKey, publicKey, cleartext) if err != nil { errChan <- err return } - output <- response + output <- response.Message }() } } @@ -658,17 +705,16 @@ func publisher( } func receiver( - s *EncryptionService, + s *ProtocolService, privateKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, - installationID string, errChan chan error, - input chan map[string]*DirectMessageProtocol, + input chan *protobuf.ProtocolMessage, ) { i := 0 for payload := range input { - actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload, defaultMessageID) + actualCleartext, err := s.HandleMessage(privateKey, publicKey, payload, defaultMessageID) if err != nil { errChan <- err return @@ -697,7 +743,7 @@ func (s *EncryptionServiceTestSuite) TestRandomised() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -705,15 +751,15 @@ func (s *EncryptionServiceTestSuite) TestRandomised() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) s.Require().NoError(err) - aliceChan := make(chan map[string]*DirectMessageProtocol, 100) - bobChan := make(chan map[string]*DirectMessageProtocol, 100) + aliceChan := make(chan *protobuf.ProtocolMessage, 100) + bobChan := make(chan *protobuf.ProtocolMessage, 100) alicePublisherErrChan := make(chan error, 1) bobPublisherErrChan := make(chan error, 1) @@ -727,10 +773,10 @@ func (s *EncryptionServiceTestSuite) TestRandomised() { go publisher(s.bob, bobKey, &aliceKey.PublicKey, bobPublisherErrChan, aliceChan) // Set up bob receiver - go receiver(s.bob, bobKey, &aliceKey.PublicKey, aliceInstallationID, bobReceiverErrChan, bobChan) + go receiver(s.bob, bobKey, &aliceKey.PublicKey, bobReceiverErrChan, bobChan) // Set up alice receiver - go receiver(s.alice, aliceKey, &bobKey.PublicKey, bobInstallationID, aliceReceiverErrChan, aliceChan) + go receiver(s.alice, aliceKey, &bobKey.PublicKey, aliceReceiverErrChan, aliceChan) aliceErr := <-alicePublisherErrChan s.Require().NoError(aliceErr) @@ -771,11 +817,11 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() { s.Require().NoError(err) // Alice sends a message - aliceMessage, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText) + aliceMessage, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText) s.Require().NoError(err) // Bob receives the message, and returns a bundlenotfound error - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage.Message, defaultMessageID) s.Require().Error(err) s.Equal(ErrSessionNotFound, err) } @@ -804,11 +850,11 @@ func (s *EncryptionServiceTestSuite) TestDeviceNotIncluded() { s.Require().NoError(err) // Alice sends a message - aliceMessage, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("does not matter")) + aliceMessage, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("does not matter")) s.Require().NoError(err) // Bob receives the message, and returns a bundlenotfound error - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage.Message, defaultMessageID) s.Require().Error(err) s.Equal(ErrDeviceNotFound, err) } @@ -829,7 +875,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Require().NoError(err) // Create bundles - bobBundle1, err := s.bob.CreateBundle(bobKey) + bobBundle1, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion()) @@ -837,7 +883,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond) // Create bundles - bobBundle2, err := s.bob.CreateBundle(bobKey) + bobBundle2, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion()) @@ -846,8 +892,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Require().NoError(err) // Alice sends a message - encryptionResponse1, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("anything")) + response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("anything")) s.Require().NoError(err) + encryptionResponse1 := response1.Message.GetDirectMessage() installationResponse1 := encryptionResponse1[bobInstallationID] s.Require().NotNil(installationResponse1) @@ -859,7 +906,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId()) // Bob decrypts the message - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID) s.Require().NoError(err) // We add the second bob bundle @@ -867,8 +914,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Require().NoError(err) // Alice sends a message - encryptionResponse2, _, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("anything")) + response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("anything")) s.Require().NoError(err) + encryptionResponse2 := response2.Message.GetDirectMessage() installationResponse2 := encryptionResponse2[bobInstallationID] s.Require().NotNil(installationResponse2) @@ -880,7 +928,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId()) // Bob decrypts the message - _, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID) + _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID) s.Require().NoError(err) } @@ -894,7 +942,7 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() { s.Require().NoError(err) // Create a bundle - bobBundle, err := s.bob.CreateBundle(bobKey) + bobBundle, err := s.bob.GetBundle(bobKey) s.Require().NoError(err) // We add bob bundle @@ -902,7 +950,7 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() { s.Require().NoError(err) // Create a bundle - aliceBundle, err := s.alice.CreateBundle(aliceKey) + aliceBundle, err := s.alice.GetBundle(aliceKey) s.Require().NoError(err) // We add alice bundle @@ -910,16 +958,16 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() { s.Require().NoError(err) // Bob sends a message - bobMessage1, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) + bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1) s.Require().NoError(err) bobMessage1ID := []byte("bob-message-1-id") // Alice receives the message once - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID) s.Require().NoError(err) // Alice receives the message twice - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID) s.Require().NoError(err) // Alice confirms the message @@ -927,33 +975,33 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() { s.Require().NoError(err) // Alice decrypts it again, it should fail - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) // Bob sends a message - bobMessage2, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) + bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1) s.Require().NoError(err) bobMessage2ID := []byte("bob-message-2-id") // Bob sends a message - bobMessage3, _, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) + bobMessage3, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1) s.Require().NoError(err) bobMessage3ID := []byte("bob-message-3-id") // Alice receives message 3 once - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID) s.Require().NoError(err) // Alice receives message 3 twice - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID) s.Require().NoError(err) // Alice receives message 2 once - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID) s.Require().NoError(err) // Alice receives message 2 twice - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID) s.Require().NoError(err) // Alice confirms the messages @@ -961,10 +1009,10 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() { s.Require().NoError(err) // Alice decrypts it again, it should fail - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) // Alice decrypts it again, it should fail - _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) + _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) } diff --git a/services/shhext/chat/multidevice/persistence.go b/services/shhext/chat/multidevice/persistence.go new file mode 100644 index 00000000000..27017a7456b --- /dev/null +++ b/services/shhext/chat/multidevice/persistence.go @@ -0,0 +1,12 @@ +package multidevice + +type Persistence interface { + // GetActiveInstallations returns the active installations for a given identity. + GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) + // EnableInstallation enables the installation. + EnableInstallation(identity []byte, installationID string) error + // DisableInstallation disable the installation. + DisableInstallation(identity []byte, installationID string) error + // AddInstallations adds the installations for a given identity, maintaining the enabled flag + AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error +} diff --git a/services/shhext/chat/multidevice/service.go b/services/shhext/chat/multidevice/service.go new file mode 100644 index 00000000000..0a2aa7e99ec --- /dev/null +++ b/services/shhext/chat/multidevice/service.go @@ -0,0 +1,94 @@ +package multidevice + +import ( + "crypto/ecdsa" + "fmt" + "github.com/ethereum/go-ethereum/crypto" + "github.com/status-im/status-go/services/shhext/chat/protobuf" +) + +type Installation struct { + ID string + Version uint32 +} + +type Config struct { + MaxInstallations int + ProtocolVersion uint32 + InstallationID string +} + +func New(config *Config, persistence Persistence) *Service { + return &Service{ + config: config, + persistence: persistence, + } +} + +type Service struct { + persistence Persistence + config *Config +} + +type IdentityAndIDPair [2]string + +func (s *Service) GetActiveInstallations(identity *ecdsa.PublicKey) ([]*Installation, error) { + identityC := crypto.CompressPubkey(identity) + return s.persistence.GetActiveInstallations(s.config.MaxInstallations, identityC) +} + +func (s *Service) GetOurActiveInstallations(identity *ecdsa.PublicKey) ([]*Installation, error) { + identityC := crypto.CompressPubkey(identity) + installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations-1, identityC) + if err != nil { + return nil, err + } + // Move to layer above + installations = append(installations, &Installation{ + ID: s.config.InstallationID, + Version: s.config.ProtocolVersion, + }) + + return installations, nil + +} + +func (s *Service) EnableInstallation(identity *ecdsa.PublicKey, installationID string) error { + identityC := crypto.CompressPubkey(identity) + return s.persistence.EnableInstallation(identityC, installationID) +} + +func (s *Service) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { + myIdentityKeyC := crypto.CompressPubkey(myIdentityKey) + return s.persistence.DisableInstallation(myIdentityKeyC, installationID) +} + +// ProcessPublicBundle persists a bundle and returns a list of tuples identity/installationID +func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, theirIdentity *ecdsa.PublicKey, b *protobuf.Bundle) ([]IdentityAndIDPair, error) { + signedPreKeys := b.GetSignedPreKeys() + var response []IdentityAndIDPair + var installations []*Installation + + myIdentityStr := fmt.Sprintf("0x%x", crypto.FromECDSAPub(&myIdentityKey.PublicKey)) + theirIdentityStr := fmt.Sprintf("0x%x", crypto.FromECDSAPub(theirIdentity)) + + // Any device from other peers will be considered enabled, ours needs to + // be explicitly enabled + fromOurIdentity := theirIdentityStr != myIdentityStr + + for installationID, signedPreKey := range signedPreKeys { + if installationID != s.config.InstallationID { + installations = append(installations, &Installation{ + ID: installationID, + Version: signedPreKey.GetProtocolVersion(), + }) + response = append(response, IdentityAndIDPair{theirIdentityStr, installationID}) + } + } + + if err := s.persistence.AddInstallations(b.GetIdentity(), b.GetTimestamp(), installations, fromOurIdentity); err != nil { + return nil, err + } + + return response, nil +} diff --git a/services/shhext/chat/multidevice/sql_lite_persistence.go b/services/shhext/chat/multidevice/sql_lite_persistence.go new file mode 100644 index 00000000000..ef270091e35 --- /dev/null +++ b/services/shhext/chat/multidevice/sql_lite_persistence.go @@ -0,0 +1,168 @@ +package multidevice + +import ( + "database/sql" +) + +// SQLLitePersistence represents a persistence service tied to an SQLite database +type SQLLitePersistence struct { + db *sql.DB +} + +// NewSQLLitePersistence creates a new SQLLitePersistence instance, given a path and a key +func NewSQLLitePersistence(db *sql.DB) *SQLLitePersistence { + return &SQLLitePersistence{db: db} +} + +// GetActiveInstallations returns the active installations for a given identity +func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) { + stmt, err := s.db.Prepare(`SELECT installation_id, version + FROM installations + WHERE enabled = 1 AND identity = ? + ORDER BY timestamp DESC + LIMIT ?`) + if err != nil { + return nil, err + } + + var installations []*Installation + rows, err := stmt.Query(identity, maxInstallations) + if err != nil { + return nil, err + } + + for rows.Next() { + var installationID string + var version uint32 + err = rows.Scan( + &installationID, + &version, + ) + if err != nil { + return nil, err + } + installations = append(installations, &Installation{ + ID: installationID, + Version: version, + }) + + } + + return installations, nil + +} + +// AddInstallations adds the installations for a given identity, maintaining the enabled flag +func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error { + tx, err := s.db.Begin() + if err != nil { + return nil + } + + for _, installation := range installations { + stmt, err := tx.Prepare(`SELECT enabled, version + FROM installations + WHERE identity = ? AND installation_id = ? + LIMIT 1`) + if err != nil { + return err + } + defer stmt.Close() + + var oldEnabled bool + // We don't override version once we saw one + var oldVersion uint32 + latestVersion := installation.Version + + err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion) + if err != nil && err != sql.ErrNoRows { + return err + } + + // We update timestamp if present without changing enabled, only if this is a new bundle + // and we set the version to the latest we ever saw + if err != sql.ErrNoRows { + if oldVersion > installation.Version { + latestVersion = oldVersion + } + + stmt, err = tx.Prepare(`UPDATE installations + SET timestamp = ?, enabled = ?, version = ? + WHERE identity = ? + AND installation_id = ? + AND timestamp < ?`) + if err != nil { + return err + } + + _, err = stmt.Exec( + timestamp, + oldEnabled, + latestVersion, + identity, + installation.ID, + timestamp, + ) + if err != nil { + return err + } + defer stmt.Close() + + } else { + stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version) + VALUES (?, ?, ?, ?, ?)`) + if err != nil { + return err + } + + _, err = stmt.Exec( + identity, + installation.ID, + timestamp, + defaultEnabled, + latestVersion, + ) + if err != nil { + return err + } + defer stmt.Close() + } + + } + + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return err + } + + return nil + +} + +// EnableInstallation enables the installation +func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error { + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 1 + WHERE identity = ? AND installation_id = ?`) + if err != nil { + return err + } + + _, err = stmt.Exec(identity, installationID) + return err + +} + +// DisableInstallation disable the installation +func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error { + + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 0 + WHERE identity = ? AND installation_id = ?`) + if err != nil { + return err + } + + _, err = stmt.Exec(identity, installationID) + return err +} diff --git a/services/shhext/chat/multidevice/sql_lite_persistence_test.go b/services/shhext/chat/multidevice/sql_lite_persistence_test.go new file mode 100644 index 00000000000..5ca62e06bcd --- /dev/null +++ b/services/shhext/chat/multidevice/sql_lite_persistence_test.go @@ -0,0 +1,243 @@ +package multidevice + +import ( + "database/sql" + "os" + "testing" + + appDB "github.com/status-im/status-go/services/shhext/chat/db" + "github.com/stretchr/testify/suite" +) + +const ( + dbPath = "/tmp/status-key-store.db" +) + +func TestSQLLitePersistenceTestSuite(t *testing.T) { + suite.Run(t, new(SQLLitePersistenceTestSuite)) +} + +type SQLLitePersistenceTestSuite struct { + suite.Suite + // nolint: structcheck, megacheck + db *sql.DB + service Persistence +} + +func (s *SQLLitePersistenceTestSuite) SetupTest() { + os.Remove(dbPath) + + db, err := appDB.Open(dbPath, "", 0) + s.Require().NoError(err) + + s.service = NewSQLLitePersistence(db) +} + +func (s *SQLLitePersistenceTestSuite) TestAddInstallations() { + identity := []byte("alice") + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + err := s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + + s.Require().NoError(err) + + enabledInstallations, err := s.service.GetActiveInstallations(5, identity) + s.Require().NoError(err) + + s.Require().Equal(installations, enabledInstallations) +} + +func (s *SQLLitePersistenceTestSuite) TestAddInstallationVersions() { + identity := []byte("alice") + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + } + err := s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + + s.Require().NoError(err) + + enabledInstallations, err := s.service.GetActiveInstallations(5, identity) + s.Require().NoError(err) + + s.Require().Equal(installations, enabledInstallations) + + installationsWithDowngradedVersion := []*Installation{ + {ID: "alice-1", Version: 0}, + } + + err = s.service.AddInstallations( + identity, + 3, + installationsWithDowngradedVersion, + true, + ) + s.Require().NoError(err) + + enabledInstallations, err = s.service.GetActiveInstallations(5, identity) + s.Require().NoError(err) + s.Require().Equal(installations, enabledInstallations) +} + +func (s *SQLLitePersistenceTestSuite) TestAddInstallationsLimit() { + identity := []byte("alice") + + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + + err := s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + s.Require().NoError(err) + + installations = []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-3", Version: 3}, + } + + err = s.service.AddInstallations( + identity, + 2, + installations, + true, + ) + s.Require().NoError(err) + + installations = []*Installation{ + {ID: "alice-2", Version: 2}, + {ID: "alice-3", Version: 3}, + {ID: "alice-4", Version: 4}, + } + + err = s.service.AddInstallations( + identity, + 3, + installations, + true, + ) + s.Require().NoError(err) + + enabledInstallations, err := s.service.GetActiveInstallations(3, identity) + s.Require().NoError(err) + + s.Require().Equal(installations, enabledInstallations) +} + +func (s *SQLLitePersistenceTestSuite) TestAddInstallationsDisabled() { + identity := []byte("alice") + + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + + err := s.service.AddInstallations( + identity, + 1, + installations, + false, + ) + s.Require().NoError(err) + + actualInstallations, err := s.service.GetActiveInstallations(3, identity) + s.Require().NoError(err) + + s.Require().Nil(actualInstallations) +} + +func (s *SQLLitePersistenceTestSuite) TestDisableInstallation() { + identity := []byte("alice") + + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + + err := s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + s.Require().NoError(err) + + err = s.service.DisableInstallation(identity, "alice-1") + s.Require().NoError(err) + + // We add the installations again + installations = []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + + err = s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + s.Require().NoError(err) + + actualInstallations, err := s.service.GetActiveInstallations(3, identity) + s.Require().NoError(err) + + expected := []*Installation{{ID: "alice-2", Version: 2}} + s.Require().Equal(expected, actualInstallations) +} + +func (s *SQLLitePersistenceTestSuite) TestEnableInstallation() { + identity := []byte("alice") + + installations := []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + + err := s.service.AddInstallations( + identity, + 1, + installations, + true, + ) + s.Require().NoError(err) + + err = s.service.DisableInstallation(identity, "alice-1") + s.Require().NoError(err) + + actualInstallations, err := s.service.GetActiveInstallations(3, identity) + s.Require().NoError(err) + + expected := []*Installation{{ID: "alice-2", Version: 2}} + s.Require().Equal(expected, actualInstallations) + + err = s.service.EnableInstallation(identity, "alice-1") + s.Require().NoError(err) + + actualInstallations, err = s.service.GetActiveInstallations(3, identity) + s.Require().NoError(err) + + expected = []*Installation{ + {ID: "alice-1", Version: 1}, + {ID: "alice-2", Version: 2}, + } + s.Require().Equal(expected, actualInstallations) + +} + +// TODO: Add test for MarkBundleExpired diff --git a/services/shhext/chat/persistence.go b/services/shhext/chat/persistence.go index 669784638f6..b8a1fe3f7b8 100644 --- a/services/shhext/chat/persistence.go +++ b/services/shhext/chat/persistence.go @@ -4,13 +4,10 @@ import ( "crypto/ecdsa" dr "github.com/status-im/doubleratchet" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" ) -type Installation struct { - ID string - Version uint32 -} - // RatchetInfo holds the current ratchet state type RatchetInfo struct { ID []byte @@ -30,17 +27,17 @@ type PersistenceService interface { // GetSessionStorage returns the associated double ratchet SessionStorage object. GetSessionStorage() dr.SessionStorage - // GetPublicBundle retrieves an existing Bundle for the specified public key & installationIDs. - GetPublicBundle(*ecdsa.PublicKey, []*Installation) (*Bundle, error) + // GetPublicBundle retrieves an existing Bundle for the specified public key & installations + GetPublicBundle(*ecdsa.PublicKey, []*multidevice.Installation) (*protobuf.Bundle, error) // AddPublicBundle persists a specified Bundle - AddPublicBundle(*Bundle) error + AddPublicBundle(*protobuf.Bundle) error - // GetAnyPrivateBundle retrieves any bundle for our identity & installationIDs - GetAnyPrivateBundle([]byte, []*Installation) (*BundleContainer, error) + // GetAnyPrivateBundle retrieves any bundle for our identity & installations + GetAnyPrivateBundle([]byte, []*multidevice.Installation) (*protobuf.BundleContainer, error) // GetPrivateKeyBundle retrieves a BundleContainer with the specified signed prekey. GetPrivateKeyBundle([]byte) ([]byte, error) // AddPrivateBundle persists a BundleContainer. - AddPrivateBundle(*BundleContainer) error + AddPrivateBundle(*protobuf.BundleContainer) error // MarkBundleExpired marks a private bundle as expired, not to be used for encryption anymore. MarkBundleExpired([]byte) error @@ -53,13 +50,4 @@ type PersistenceService interface { // RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo // associated with the specified bundle ID and interlocutor identity public key. RatchetInfoConfirmed([]byte, []byte, string) error - - // GetActiveInstallations returns the active installations for a given identity. - GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) - // AddInstallations adds the installations for a given identity. - AddInstallations(identity []byte, timestamp int64, installations []*Installation, enabled bool) error - // EnableInstallation enables the installation. - EnableInstallation(identity []byte, installationID string) error - // DisableInstallation disable the installation. - DisableInstallation(identity []byte, installationID string) error } diff --git a/services/shhext/chat/encryption.pb.go b/services/shhext/chat/protobuf/encryption.pb.go similarity index 78% rename from services/shhext/chat/encryption.pb.go rename to services/shhext/chat/protobuf/encryption.pb.go index 88e0a90c640..760de5cdca5 100644 --- a/services/shhext/chat/encryption.pb.go +++ b/services/shhext/chat/protobuf/encryption.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // source: encryption.proto -package chat +package protobuf import ( fmt "fmt" @@ -491,56 +491,56 @@ func (m *ProtocolMessage) GetPublicMessage() []byte { } func init() { - proto.RegisterType((*SignedPreKey)(nil), "chat.SignedPreKey") - proto.RegisterType((*Bundle)(nil), "chat.Bundle") - proto.RegisterMapType((map[string]*SignedPreKey)(nil), "chat.Bundle.SignedPreKeysEntry") - proto.RegisterType((*BundleContainer)(nil), "chat.BundleContainer") - proto.RegisterType((*DRHeader)(nil), "chat.DRHeader") - proto.RegisterType((*DHHeader)(nil), "chat.DHHeader") - proto.RegisterType((*X3DHHeader)(nil), "chat.X3DHHeader") - proto.RegisterType((*DirectMessageProtocol)(nil), "chat.DirectMessageProtocol") - proto.RegisterType((*ProtocolMessage)(nil), "chat.ProtocolMessage") - proto.RegisterMapType((map[string]*DirectMessageProtocol)(nil), "chat.ProtocolMessage.DirectMessageEntry") + proto.RegisterType((*SignedPreKey)(nil), "protobuf.SignedPreKey") + proto.RegisterType((*Bundle)(nil), "protobuf.Bundle") + proto.RegisterMapType((map[string]*SignedPreKey)(nil), "protobuf.Bundle.SignedPreKeysEntry") + proto.RegisterType((*BundleContainer)(nil), "protobuf.BundleContainer") + proto.RegisterType((*DRHeader)(nil), "protobuf.DRHeader") + proto.RegisterType((*DHHeader)(nil), "protobuf.DHHeader") + proto.RegisterType((*X3DHHeader)(nil), "protobuf.X3DHHeader") + proto.RegisterType((*DirectMessageProtocol)(nil), "protobuf.DirectMessageProtocol") + proto.RegisterType((*ProtocolMessage)(nil), "protobuf.ProtocolMessage") + proto.RegisterMapType((map[string]*DirectMessageProtocol)(nil), "protobuf.ProtocolMessage.DirectMessageEntry") } func init() { proto.RegisterFile("encryption.proto", fileDescriptor_8293a649ce9418c6) } var fileDescriptor_8293a649ce9418c6 = []byte{ - // 562 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x54, 0x61, 0x8b, 0xd3, 0x4c, - 0x10, 0x26, 0x49, 0xef, 0xda, 0x4e, 0xd3, 0xb4, 0xec, 0xcb, 0x2b, 0xa1, 0x1e, 0x58, 0xc2, 0xa9, - 0x11, 0xa1, 0x70, 0xad, 0x1f, 0xc4, 0x8f, 0x5a, 0xb1, 0x9e, 0x88, 0xc7, 0x2a, 0xe2, 0x17, 0x09, - 0xdb, 0x66, 0xbd, 0x5b, 0x4c, 0x93, 0xb0, 0xbb, 0x2d, 0xe4, 0xcf, 0xf9, 0x57, 0xfc, 0x29, 0x4a, - 0x76, 0xb3, 0xed, 0xb6, 0x77, 0x07, 0x7e, 0xeb, 0xcc, 0x3c, 0xfb, 0xcc, 0x33, 0xcf, 0x74, 0x02, - 0x43, 0x9a, 0xaf, 0x78, 0x55, 0x4a, 0x56, 0xe4, 0x93, 0x92, 0x17, 0xb2, 0x40, 0xad, 0xd5, 0x0d, - 0x91, 0x51, 0x05, 0xfe, 0x67, 0x76, 0x9d, 0xd3, 0xf4, 0x8a, 0xd3, 0x0f, 0xb4, 0x42, 0xe7, 0x10, - 0x08, 0x15, 0x27, 0x25, 0xa7, 0xc9, 0x4f, 0x5a, 0x85, 0xce, 0xd8, 0x89, 0x7d, 0xec, 0x0b, 0x1b, - 0x15, 0x42, 0x7b, 0x4b, 0xb9, 0x60, 0x45, 0x1e, 0xba, 0x63, 0x27, 0xee, 0x63, 0x13, 0xa2, 0x67, - 0x30, 0x54, 0xf4, 0xab, 0x22, 0x4b, 0x0c, 0xc4, 0x53, 0x90, 0x81, 0xc9, 0x7f, 0xd5, 0xe9, 0xe8, - 0x8f, 0x03, 0xa7, 0xaf, 0x37, 0x79, 0x9a, 0x51, 0x34, 0x82, 0x0e, 0x4b, 0x69, 0x2e, 0x99, 0x34, - 0xfd, 0x76, 0x31, 0x7a, 0x07, 0x83, 0x43, 0x45, 0x22, 0x74, 0xc7, 0x5e, 0xdc, 0x9b, 0x3e, 0x9a, - 0xd4, 0x13, 0x4c, 0x34, 0xc5, 0xc4, 0x9e, 0x42, 0xbc, 0xcd, 0x25, 0xaf, 0x70, 0xdf, 0xd6, 0x2c, - 0xd0, 0x19, 0x74, 0xeb, 0x04, 0x91, 0x1b, 0x4e, 0xc3, 0x96, 0xea, 0xb2, 0x4f, 0xd4, 0x55, 0xc9, - 0xd6, 0x54, 0x48, 0xb2, 0x2e, 0xc3, 0x93, 0xb1, 0x13, 0x7b, 0x78, 0x9f, 0x18, 0x7d, 0x01, 0x74, - 0xbb, 0x01, 0x1a, 0x82, 0x67, 0x1c, 0xea, 0xe2, 0xfa, 0x27, 0x8a, 0xe1, 0x64, 0x4b, 0xb2, 0x0d, - 0x55, 0xb6, 0xf4, 0xa6, 0x48, 0x4b, 0xb4, 0x9f, 0x62, 0x0d, 0x78, 0xe5, 0xbe, 0x74, 0x22, 0x0e, - 0x03, 0xad, 0xfe, 0x4d, 0x91, 0x4b, 0xc2, 0x72, 0xca, 0xd1, 0x39, 0x9c, 0x2e, 0x55, 0x4a, 0xb1, - 0xf6, 0xa6, 0xbe, 0x3d, 0x24, 0x6e, 0x6a, 0x68, 0x06, 0x0f, 0x4a, 0xce, 0xb6, 0x44, 0xd2, 0xe4, - 0x68, 0x5b, 0xae, 0x9a, 0xeb, 0xbf, 0xa6, 0x6a, 0x37, 0xbe, 0x6c, 0x75, 0xbc, 0x61, 0x2b, 0xba, - 0x84, 0xce, 0x1c, 0x2f, 0x28, 0x49, 0x29, 0xb7, 0xf5, 0xfb, 0x5a, 0xbf, 0x0f, 0x8e, 0x59, 0xa9, - 0x93, 0xa3, 0x00, 0xdc, 0xd2, 0xac, 0xcf, 0x2d, 0x55, 0xcc, 0xd2, 0xc6, 0x3a, 0x97, 0xa5, 0xd1, - 0x19, 0x74, 0xe6, 0x8b, 0xfb, 0xb8, 0xa2, 0x17, 0x00, 0xdf, 0x66, 0xf7, 0xd7, 0x8f, 0xd9, 0x1a, - 0x7d, 0xbf, 0x1c, 0xf8, 0x7f, 0xce, 0x38, 0x5d, 0xc9, 0x8f, 0x54, 0x08, 0x72, 0x4d, 0xaf, 0x9a, - 0xbf, 0x0d, 0xba, 0x80, 0x5e, 0xcd, 0x97, 0xdc, 0x28, 0xc2, 0xc6, 0x9f, 0xa1, 0xf6, 0x67, 0xdf, - 0x08, 0xdb, 0x4d, 0x9f, 0x43, 0x77, 0x8e, 0xcd, 0x03, 0xbd, 0x92, 0x40, 0x3f, 0x30, 0x1e, 0xe0, - 0xbd, 0x1b, 0x35, 0x78, 0xc7, 0x4e, 0x0f, 0xc0, 0x8b, 0x1d, 0xd8, 0x30, 0x87, 0xd0, 0x2e, 0x49, - 0x95, 0x15, 0x24, 0x55, 0xfe, 0xf8, 0xd8, 0x84, 0xd1, 0x6f, 0x17, 0x06, 0x46, 0x73, 0x33, 0xc2, - 0x3f, 0x6e, 0xf5, 0x29, 0x0c, 0x58, 0x2e, 0x24, 0xc9, 0x32, 0x52, 0xdf, 0x69, 0xc2, 0x52, 0xa5, - 0xb9, 0x8b, 0x03, 0x3b, 0xfd, 0x3e, 0x45, 0x4f, 0xa0, 0xad, 0x9f, 0x88, 0xd0, 0x53, 0xa7, 0x70, - 0xc8, 0x67, 0x8a, 0xe8, 0x13, 0x04, 0xa9, 0xb2, 0x32, 0x59, 0x6b, 0x21, 0x21, 0x55, 0xf0, 0x58, - 0xc3, 0x8f, 0x54, 0x4e, 0x0e, 0x6c, 0x6f, 0x4e, 0x28, 0xb5, 0x73, 0xe8, 0x31, 0x04, 0xe5, 0x66, - 0x99, 0xb1, 0xd5, 0x8e, 0xf0, 0x87, 0x1a, 0xbe, 0xaf, 0xb3, 0x0d, 0x6c, 0xf4, 0x1d, 0xd0, 0x6d, - 0xae, 0x3b, 0xae, 0xe5, 0xe2, 0xf0, 0x5a, 0x1e, 0x36, 0x6e, 0xdf, 0xb5, 0x7d, 0xeb, 0x6c, 0x96, - 0xa7, 0xea, 0x4b, 0x32, 0xfb, 0x1b, 0x00, 0x00, 0xff, 0xff, 0x9e, 0x75, 0x6d, 0x59, 0xd4, 0x04, - 0x00, 0x00, + // 566 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x53, 0xdd, 0x6a, 0xdb, 0x4c, + 0x10, 0x45, 0x52, 0xe2, 0x9f, 0xb1, 0xfc, 0xc3, 0x7e, 0x5f, 0x83, 0x30, 0x81, 0x1a, 0xb5, 0xa5, + 0x6e, 0x09, 0x2e, 0xd8, 0x0d, 0x94, 0x5e, 0xb6, 0x2e, 0xb8, 0x09, 0x85, 0xb0, 0x81, 0x92, 0x3b, + 0xb1, 0xb6, 0x36, 0xe9, 0x52, 0x79, 0x25, 0x76, 0xd7, 0x06, 0x3d, 0x41, 0xdf, 0xad, 0x2f, 0xd3, + 0x57, 0x28, 0x5a, 0x69, 0xad, 0xb5, 0x9d, 0x5c, 0xf4, 0xca, 0x9e, 0xb3, 0x73, 0xce, 0xcc, 0x9c, + 0xd1, 0xc0, 0x80, 0xf2, 0x95, 0xc8, 0x33, 0xc5, 0x52, 0x3e, 0xc9, 0x44, 0xaa, 0x52, 0xd4, 0xd2, + 0x3f, 0xcb, 0xcd, 0x7d, 0x98, 0x83, 0x7f, 0xcb, 0x1e, 0x38, 0x8d, 0x6f, 0x04, 0xbd, 0xa6, 0x39, + 0x7a, 0x09, 0x3d, 0xa9, 0xe3, 0x28, 0x13, 0x34, 0xfa, 0x49, 0xf3, 0xc0, 0x19, 0x39, 0x63, 0x1f, + 0xfb, 0xd2, 0xce, 0x0a, 0xa0, 0xb9, 0xa5, 0x42, 0xb2, 0x94, 0x07, 0xee, 0xc8, 0x19, 0x77, 0xb1, + 0x09, 0xd1, 0x1b, 0x18, 0x68, 0xed, 0x55, 0x9a, 0x44, 0x26, 0xc5, 0xd3, 0x29, 0x7d, 0x83, 0x7f, + 0x2f, 0xe1, 0xf0, 0x97, 0x0b, 0x8d, 0x4f, 0x1b, 0x1e, 0x27, 0x14, 0x0d, 0xa1, 0xc5, 0x62, 0xca, + 0x15, 0x53, 0xa6, 0xde, 0x2e, 0x46, 0xd7, 0xd0, 0xdf, 0xef, 0x48, 0x06, 0xee, 0xc8, 0x1b, 0x77, + 0xa6, 0x2f, 0x26, 0x66, 0x8a, 0x49, 0x29, 0x33, 0xb1, 0x27, 0x91, 0x5f, 0xb8, 0x12, 0x39, 0xee, + 0xda, 0x7d, 0x4b, 0x74, 0x0e, 0xed, 0x02, 0x20, 0x6a, 0x23, 0x68, 0x70, 0xa2, 0x2b, 0xd5, 0x40, + 0xf1, 0xaa, 0xd8, 0x9a, 0x4a, 0x45, 0xd6, 0x59, 0x70, 0x3a, 0x72, 0xc6, 0x1e, 0xae, 0x81, 0xe1, + 0x1d, 0xa0, 0xe3, 0x02, 0x68, 0x00, 0x9e, 0x71, 0xa9, 0x8d, 0x8b, 0xbf, 0xe8, 0x02, 0x4e, 0xb7, + 0x24, 0xd9, 0x50, 0x6d, 0x4d, 0x67, 0x7a, 0x56, 0xb7, 0x69, 0xd3, 0x71, 0x99, 0xf4, 0xd1, 0xfd, + 0xe0, 0x84, 0x5b, 0xe8, 0x97, 0x13, 0x7c, 0x4e, 0xb9, 0x22, 0x8c, 0x53, 0x81, 0xc6, 0xd0, 0x58, + 0x6a, 0x48, 0x2b, 0x77, 0xa6, 0x83, 0xc3, 0x61, 0x71, 0xf5, 0x8e, 0x66, 0x70, 0x96, 0x09, 0xb6, + 0x25, 0x8a, 0x46, 0x07, 0x9b, 0x73, 0xf5, 0x7c, 0xff, 0x55, 0xaf, 0x76, 0xf1, 0xab, 0x93, 0x96, + 0x37, 0x38, 0x09, 0xaf, 0xa0, 0x35, 0xc7, 0x0b, 0x4a, 0x62, 0x2a, 0xec, 0x39, 0xfc, 0x72, 0x0e, + 0x1f, 0x1c, 0xb3, 0x5e, 0x87, 0xa3, 0x1e, 0xb8, 0x99, 0x59, 0xa5, 0x9b, 0xe9, 0x98, 0xc5, 0x95, + 0x85, 0x2e, 0x8b, 0xc3, 0x73, 0x68, 0xcd, 0x17, 0x4f, 0x69, 0x85, 0xef, 0x01, 0xee, 0x66, 0x4f, + 0xbf, 0x1f, 0xaa, 0x55, 0xfd, 0xfd, 0x76, 0xe0, 0xd9, 0x9c, 0x09, 0xba, 0x52, 0xdf, 0xa8, 0x94, + 0xe4, 0x81, 0xde, 0x54, 0x9f, 0x10, 0xba, 0x84, 0x4e, 0xa1, 0x17, 0xfd, 0xd0, 0x82, 0x95, 0x47, + 0xff, 0xd7, 0x1e, 0xd5, 0xc5, 0xb0, 0x5d, 0xf8, 0x1d, 0xb4, 0xe7, 0xd8, 0x90, 0xca, 0xf5, 0xa0, + 0x9a, 0x64, 0xbc, 0xc0, 0xb5, 0x2b, 0x05, 0x61, 0x57, 0x85, 0x1e, 0x11, 0x16, 0x3b, 0x82, 0xa9, + 0x10, 0x40, 0x33, 0x23, 0x79, 0x92, 0x92, 0x58, 0x7b, 0xe5, 0x63, 0x13, 0x86, 0x7f, 0x5c, 0xe8, + 0x9b, 0xfe, 0xab, 0x71, 0xfe, 0x61, 0xcb, 0xaf, 0xa1, 0xcf, 0xb8, 0x54, 0x24, 0x49, 0x48, 0x71, + 0xc7, 0x11, 0x8b, 0x75, 0xff, 0x6d, 0xdc, 0xb3, 0xe1, 0xaf, 0x31, 0x7a, 0x0b, 0xcd, 0x92, 0x22, + 0x03, 0x4f, 0x9f, 0xc9, 0xb1, 0xa6, 0x49, 0x40, 0xb7, 0xd0, 0x8b, 0xb5, 0xbd, 0xd1, 0xba, 0x6c, + 0x28, 0xa0, 0x9a, 0x72, 0x51, 0x53, 0x0e, 0x3a, 0x9e, 0xec, 0xad, 0xa3, 0x3a, 0xb1, 0xd8, 0xc6, + 0xd0, 0x2b, 0xe8, 0x65, 0x9b, 0x65, 0xc2, 0x56, 0x3b, 0xd1, 0x7b, 0x6d, 0x44, 0xb7, 0x44, 0xab, + 0xb4, 0x21, 0x01, 0x74, 0xac, 0xf5, 0xc8, 0x35, 0x5d, 0xee, 0x5f, 0xd3, 0x73, 0xcb, 0xfd, 0xc7, + 0xbe, 0x0c, 0xeb, 0xac, 0x96, 0x0d, 0x9d, 0x3a, 0xfb, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xcd, 0x2d, + 0x0e, 0xc8, 0x00, 0x05, 0x00, 0x00, } diff --git a/services/shhext/chat/encryption.proto b/services/shhext/chat/protobuf/encryption.proto similarity index 98% rename from services/shhext/chat/encryption.proto rename to services/shhext/chat/protobuf/encryption.proto index 54e374489ab..c8b10277d44 100644 --- a/services/shhext/chat/encryption.proto +++ b/services/shhext/chat/protobuf/encryption.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package chat; +package protobuf; message SignedPreKey { bytes signed_pre_key = 1; diff --git a/services/shhext/chat/protocol.go b/services/shhext/chat/protocol.go index 5eb661ff531..51065ecb976 100644 --- a/services/shhext/chat/protocol.go +++ b/services/shhext/chat/protocol.go @@ -5,10 +5,12 @@ import ( "errors" "github.com/ethereum/go-ethereum/log" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" "github.com/status-im/status-go/services/shhext/chat/topic" ) -const protocolCurrentVersion = 1 +const ProtocolVersion = 1 const topicNegotiationVersion = 1 const partitionedTopicMinVersion = 1 @@ -16,7 +18,8 @@ type ProtocolService struct { log log.Logger encryption *EncryptionService topic *topic.Service - addedBundlesHandler func([]IdentityAndIDPair) + multidevice *multidevice.Service + addedBundlesHandler func([]multidevice.IdentityAndIDPair) onNewTopicHandler func([]*topic.Secret) Enabled bool } @@ -24,19 +27,26 @@ type ProtocolService struct { var ErrNotProtocolMessage = errors.New("Not a protocol message") // NewProtocolService creates a new ProtocolService instance -func NewProtocolService(encryption *EncryptionService, topic *topic.Service, addedBundlesHandler func([]IdentityAndIDPair), onNewTopicHandler func([]*topic.Secret)) *ProtocolService { +func NewProtocolService(encryption *EncryptionService, topic *topic.Service, multidevice *multidevice.Service, addedBundlesHandler func([]multidevice.IdentityAndIDPair), onNewTopicHandler func([]*topic.Secret)) *ProtocolService { return &ProtocolService{ log: log.New("package", "status-go/services/sshext.chat"), encryption: encryption, topic: topic, + multidevice: multidevice, addedBundlesHandler: addedBundlesHandler, onNewTopicHandler: onNewTopicHandler, } } -func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *ProtocolMessage, sendSingle bool) (*ProtocolMessage, error) { +func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *protobuf.ProtocolMessage, sendSingle bool) (*protobuf.ProtocolMessage, error) { + // Get a bundle - bundle, err := p.encryption.CreateBundle(myIdentityKey) + installations, err := p.multidevice.GetOurActiveInstallations(&myIdentityKey.PublicKey) + if err != nil { + return nil, err + } + + bundle, err := p.encryption.CreateBundle(myIdentityKey, installations) if err != nil { p.log.Error("encryption-service", "error creating bundle", err) return nil, err @@ -47,16 +57,16 @@ func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *Protoc // an issue anymore msg.Bundle = bundle } else { - msg.Bundles = []*Bundle{bundle} + msg.Bundles = []*protobuf.Bundle{bundle} } return msg, nil } // BuildPublicMessage marshals a public chat message given the user identity private key and a payload -func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) (*ProtocolMessage, error) { +func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) (*protobuf.ProtocolMessage, error) { // Build message not encrypted - protocolMessage := &ProtocolMessage{ + protocolMessage := &protobuf.ProtocolMessage{ InstallationId: p.encryption.config.InstallationID, PublicMessage: payload, } @@ -65,9 +75,9 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa } type ProtocolMessageSpec struct { - Message *ProtocolMessage + Message *protobuf.ProtocolMessage // Installations is the targeted devices - Installations []*Installation + Installations []*multidevice.Installation // SharedSecret is a shared secret established among the installations SharedSecret []byte } @@ -93,15 +103,20 @@ func (p *ProtocolMessageSpec) PartitionedTopic() bool { // BuildDirectMessage returns a 1:1 chat message and optionally a negotiated topic given the user identity private key, the recipient's public key, and a payload func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, payload []byte) (*ProtocolMessageSpec, error) { + installations, err := p.multidevice.GetActiveInstallations(publicKey) + if err != nil { + return nil, err + } + // Encrypt payload - encryptionResponse, installations, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload) + encryptionResponse, installations, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, installations, payload) if err != nil { p.log.Error("encryption-service", "error encrypting payload", err) return nil, err } // Build message - protocolMessage := &ProtocolMessage{ + protocolMessage := &protobuf.ProtocolMessage{ InstallationId: p.encryption.config.InstallationID, DirectMessage: encryptionResponse, } @@ -144,7 +159,7 @@ func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, pu } // BuildDHMessage builds a message with DH encryption so that it can be decrypted by any other device. -func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destination *ecdsa.PublicKey, payload []byte) (*ProtocolMessage, error) { +func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destination *ecdsa.PublicKey, payload []byte) (*protobuf.ProtocolMessage, error) { // Encrypt payload encryptionResponse, err := p.encryption.EncryptPayloadWithDH(destination, payload) if err != nil { @@ -153,7 +168,7 @@ func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destin } // Build message - protocolMessage := &ProtocolMessage{ + protocolMessage := &protobuf.ProtocolMessage{ InstallationId: p.encryption.config.InstallationID, DirectMessage: encryptionResponse, } @@ -167,28 +182,46 @@ func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destin } // ProcessPublicBundle processes a received X3DH bundle. -func (p *ProtocolService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *Bundle) ([]IdentityAndIDPair, error) { - return p.encryption.ProcessPublicBundle(myIdentityKey, bundle) +func (p *ProtocolService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *protobuf.Bundle) ([]multidevice.IdentityAndIDPair, error) { + if err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle); err != nil { + return nil, err + } + + theirIdentityKey, err := ExtractIdentity(bundle) + if err != nil { + return nil, err + } + + return p.multidevice.ProcessPublicBundle(myIdentityKey, theirIdentityKey, bundle) } // GetBundle retrieves or creates a X3DH bundle, given a private identity key. -func (p *ProtocolService) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*Bundle, error) { - return p.encryption.CreateBundle(myIdentityKey) +func (p *ProtocolService) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*protobuf.Bundle, error) { + installations, err := p.multidevice.GetOurActiveInstallations(&myIdentityKey.PublicKey) + if err != nil { + return nil, err + } + + return p.encryption.CreateBundle(myIdentityKey, installations) } // EnableInstallation enables an installation for multi-device sync. func (p *ProtocolService) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { - return p.encryption.EnableInstallation(myIdentityKey, installationID) + return p.multidevice.EnableInstallation(myIdentityKey, installationID) } // DisableInstallation disables an installation for multi-device sync. func (p *ProtocolService) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { - return p.encryption.DisableInstallation(myIdentityKey, installationID) + return p.multidevice.DisableInstallation(myIdentityKey, installationID) } // GetPublicBundle retrieves a public bundle given an identity -func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*Bundle, error) { - return p.encryption.GetPublicBundle(theirIdentityKey) +func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*protobuf.Bundle, error) { + installations, err := p.multidevice.GetActiveInstallations(theirIdentityKey) + if err != nil { + return nil, err + } + return p.encryption.GetPublicBundle(theirIdentityKey, installations) } // ConfirmMessagesProcessed confirms and deletes message keys for the given messages @@ -197,7 +230,7 @@ func (p *ProtocolService) ConfirmMessagesProcessed(messageIDs [][]byte) error { } // HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message. -func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, protocolMessage *ProtocolMessage, messageID []byte) ([]byte, error) { +func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, protocolMessage *protobuf.ProtocolMessage, messageID []byte) ([]byte, error) { if p.encryption == nil { return nil, errors.New("encryption service not initialized") } @@ -205,7 +238,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu // Process bundle, deprecated, here for backward compatibility if bundle := protocolMessage.GetBundle(); bundle != nil { // Should we stop processing if the bundle cannot be verified? - addedBundles, err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle) + addedBundles, err := p.ProcessPublicBundle(myIdentityKey, bundle) if err != nil { return nil, err } @@ -216,7 +249,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu // Process bundles for _, bundle := range protocolMessage.GetBundles() { // Should we stop processing if the bundle cannot be verified? - addedBundles, err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle) + addedBundles, err := p.ProcessPublicBundle(myIdentityKey, bundle) if err != nil { return nil, err } @@ -257,7 +290,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu return nil, errors.New("no payload") } -func getProtocolVersion(bundles []*Bundle, installationID string) uint32 { +func getProtocolVersion(bundles []*protobuf.Bundle, installationID string) uint32 { if installationID == "" { return 0 } diff --git a/services/shhext/chat/protocol_test.go b/services/shhext/chat/protocol_test.go index 59a70d56fed..bad4ce08181 100644 --- a/services/shhext/chat/protocol_test.go +++ b/services/shhext/chat/protocol_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/crypto" + "github.com/status-im/status-go/services/shhext/chat/multidevice" "github.com/status-im/status-go/services/shhext/chat/topic" "github.com/stretchr/testify/suite" ) @@ -38,19 +39,31 @@ func (s *ProtocolServiceTestSuite) SetupTest() { panic(err) } - addedBundlesHandler := func(addedBundles []IdentityAndIDPair) {} + addedBundlesHandler := func(addedBundles []multidevice.IdentityAndIDPair) {} onNewTopicHandler := func(topic []*topic.Secret) {} + aliceMultideviceConfig := &multidevice.Config{ + MaxInstallations: 3, + InstallationID: "1", + } + s.alice = NewProtocolService( NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig("1")), topic.NewService(alicePersistence.GetTopicStorage()), + multidevice.New(aliceMultideviceConfig, alicePersistence.GetMultideviceStorage()), addedBundlesHandler, onNewTopicHandler, ) + bobMultideviceConfig := &multidevice.Config{ + MaxInstallations: 3, + InstallationID: "2", + } + s.bob = NewProtocolService( NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig("2")), topic.NewService(bobPersistence.GetTopicStorage()), + multidevice.New(bobMultideviceConfig, bobPersistence.GetMultideviceStorage()), addedBundlesHandler, onNewTopicHandler, ) diff --git a/services/shhext/chat/sql_lite_persistence.go b/services/shhext/chat/sql_lite_persistence.go index fda8a06bedf..c21b6898ad5 100644 --- a/services/shhext/chat/sql_lite_persistence.go +++ b/services/shhext/chat/sql_lite_persistence.go @@ -10,6 +10,8 @@ import ( dr "github.com/status-im/doubleratchet" ecrypto "github.com/status-im/status-go/services/shhext/chat/crypto" appDB "github.com/status-im/status-go/services/shhext/chat/db" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" "github.com/status-im/status-go/services/shhext/chat/topic" ) @@ -18,10 +20,11 @@ const maxNumberOfRows = 100000000 // SQLLitePersistence represents a persistence service tied to an SQLite database type SQLLitePersistence struct { - db *sql.DB - keysStorage dr.KeysStorage - sessionStorage dr.SessionStorage - topicStorage topic.PersistenceService + db *sql.DB + keysStorage dr.KeysStorage + sessionStorage dr.SessionStorage + topicStorage topic.PersistenceService + multideviceStorage multidevice.Persistence } // SQLLiteKeysStorage represents a keys persistence service tied to an SQLite database @@ -48,6 +51,8 @@ func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error) s.topicStorage = topic.NewSQLLitePersistence(s.db) + s.multideviceStorage = multidevice.NewSQLLitePersistence(s.db) + return s, nil } @@ -80,6 +85,11 @@ func (s *SQLLitePersistence) GetTopicStorage() topic.PersistenceService { return s.topicStorage } +// GetMultideviceStorage returns the associated multideviceStorage +func (s *SQLLitePersistence) GetMultideviceStorage() multidevice.Persistence { + return s.multideviceStorage +} + // Open opens a file at the specified path func (s *SQLLitePersistence) Open(path string, key string) error { db, err := appDB.Open(path, key, appDB.KdfIterationsNumber) @@ -93,7 +103,7 @@ func (s *SQLLitePersistence) Open(path string, key string) error { } // AddPrivateBundle adds the specified BundleContainer to the database -func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error { +func (s *SQLLitePersistence) AddPrivateBundle(bc *protobuf.BundleContainer) error { tx, err := s.db.Begin() if err != nil { return err @@ -147,7 +157,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error { } // AddPublicBundle adds the specified Bundle to the database -func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { +func (s *SQLLitePersistence) AddPublicBundle(b *protobuf.Bundle) error { tx, err := s.db.Begin() if err != nil { @@ -200,7 +210,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { } // GetAnyPrivateBundle retrieves any bundle from the database containing a private key -func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*Installation) (*BundleContainer, error) { +func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*protobuf.BundleContainer, error) { versions := make(map[string]uint32) /* #nosec */ @@ -236,11 +246,11 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat defer rows.Close() - bundle := &Bundle{ - SignedPreKeys: make(map[string]*SignedPreKey), + bundle := &protobuf.Bundle{ + SignedPreKeys: make(map[string]*protobuf.SignedPreKey), } - bundleContainer := &BundleContainer{ + bundleContainer := &protobuf.BundleContainer{ Bundle: bundle, } @@ -264,7 +274,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat bundle.Timestamp = timestamp } - bundle.SignedPreKeys[installationID] = &SignedPreKey{ + bundle.SignedPreKeys[installationID] = &protobuf.SignedPreKey{ SignedPreKey: signedPreKey, Version: version, ProtocolVersion: versions[installationID], @@ -320,7 +330,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error { } // GetPublicBundle retrieves an existing Bundle for the specified public key from the database -func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*Installation) (*Bundle, error) { +func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) { if len(installations) == 0 { return nil, nil @@ -357,9 +367,9 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install defer rows.Close() - bundle := &Bundle{ + bundle := &protobuf.Bundle{ Identity: identity, - SignedPreKeys: make(map[string]*SignedPreKey), + SignedPreKeys: make(map[string]*protobuf.SignedPreKey), } for rows.Next() { @@ -376,7 +386,7 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install return nil, err } - bundle.SignedPreKeys[installationID] = &SignedPreKey{ + bundle.SignedPreKeys[installationID] = &protobuf.SignedPreKey{ SignedPreKey: signedPreKey, Version: version, ProtocolVersion: versions[installationID], @@ -751,159 +761,6 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { } } -// GetActiveInstallations returns the active installations for a given identity -func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) { - stmt, err := s.db.Prepare(`SELECT installation_id, version - FROM installations - WHERE enabled = 1 AND identity = ? - ORDER BY timestamp DESC - LIMIT ?`) - if err != nil { - return nil, err - } - - var installations []*Installation - rows, err := stmt.Query(identity, maxInstallations) - if err != nil { - return nil, err - } - - for rows.Next() { - var installationID string - var version uint32 - err = rows.Scan( - &installationID, - &version, - ) - if err != nil { - return nil, err - } - installations = append(installations, &Installation{ - ID: installationID, - Version: version, - }) - - } - - return installations, nil - -} - -// AddInstallations adds the installations for a given identity, maintaining the enabled flag -func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error { - tx, err := s.db.Begin() - if err != nil { - return nil - } - - for _, installation := range installations { - stmt, err := tx.Prepare(`SELECT enabled, version - FROM installations - WHERE identity = ? AND installation_id = ? - LIMIT 1`) - if err != nil { - return err - } - defer stmt.Close() - - var oldEnabled bool - // We don't override version once we saw one - var oldVersion uint32 - latestVersion := installation.Version - - err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion) - if err != nil && err != sql.ErrNoRows { - return err - } - - // We update timestamp if present without changing enabled, only if this is a new bundle - // and we set the version to the latest we ever saw - if err != sql.ErrNoRows { - if oldVersion > installation.Version { - latestVersion = oldVersion - } - - stmt, err = tx.Prepare(`UPDATE installations - SET timestamp = ?, enabled = ?, version = ? - WHERE identity = ? - AND installation_id = ? - AND timestamp < ?`) - if err != nil { - return err - } - - _, err = stmt.Exec( - timestamp, - oldEnabled, - latestVersion, - identity, - installation.ID, - timestamp, - ) - if err != nil { - return err - } - defer stmt.Close() - - } else { - stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version) - VALUES (?, ?, ?, ?, ?)`) - if err != nil { - return err - } - - _, err = stmt.Exec( - identity, - installation.ID, - timestamp, - defaultEnabled, - latestVersion, - ) - if err != nil { - return err - } - defer stmt.Close() - } - - } - - if err := tx.Commit(); err != nil { - _ = tx.Rollback() - return err - } - - return nil - -} - -// EnableInstallation enables the installation -func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error { - stmt, err := s.db.Prepare(`UPDATE installations - SET enabled = 1 - WHERE identity = ? AND installation_id = ?`) - if err != nil { - return err - } - - _, err = stmt.Exec(identity, installationID) - return err - -} - -// DisableInstallation disable the installation -func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error { - - stmt, err := s.db.Prepare(`UPDATE installations - SET enabled = 0 - WHERE identity = ? AND installation_id = ?`) - if err != nil { - return err - } - - _, err = stmt.Exec(identity, installationID) - return err -} - func toKey(a []byte) dr.Key { var k [32]byte copy(k[:], a) diff --git a/services/shhext/chat/sql_lite_persistence_test.go b/services/shhext/chat/sql_lite_persistence_test.go index 097ffc66c3b..e7dbfac3075 100644 --- a/services/shhext/chat/sql_lite_persistence_test.go +++ b/services/shhext/chat/sql_lite_persistence_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/crypto" + "github.com/status-im/status-go/services/shhext/chat/multidevice" "github.com/stretchr/testify/suite" ) @@ -53,7 +54,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() { s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualKey) - anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"), []*Installation{{ID: installationID, Version: 1}}) + anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"), []*multidevice.Installation{{ID: installationID, Version: 1}}) s.Require().NoError(err) s.Nil(anyPrivateBundle) @@ -70,7 +71,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() { s.Equal(bundle.GetPrivateSignedPreKey(), actualKey, "It returns the same key") identity := crypto.CompressPubkey(&key.PublicKey) - anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []*Installation{{ID: installationID, Version: 1}}) + anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []*multidevice.Installation{{ID: installationID, Version: 1}}) s.Require().NoError(err) s.NotNil(anyPrivateBundle) s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle") @@ -80,7 +81,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() { key, err := crypto.GenerateKey() s.Require().NoError(err) - actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) @@ -91,7 +92,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() { err = s.service.AddPublicBundle(bundle) s.Require().NoError(err) - actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err) s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") @@ -101,7 +102,7 @@ func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() { key, err := crypto.GenerateKey() s.Require().NoError(err) - actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) @@ -123,7 +124,7 @@ func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() { err = s.service.AddPublicBundle(bundle) s.Require().NoError(err) - actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err) s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") @@ -133,7 +134,7 @@ func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() { key, err := crypto.GenerateKey() s.Require().NoError(err) - actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) @@ -160,7 +161,7 @@ func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() { err = s.service.AddPublicBundle(bundle1) s.Require().NoError(err) - actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err) s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1)) @@ -171,7 +172,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() { key, err := crypto.GenerateKey() s.Require().NoError(err) - actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) @@ -197,7 +198,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() { s.Require().NoError(err) // Returns the most recent bundle - actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err) s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the identity") @@ -209,7 +210,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() { key, err := crypto.GenerateKey() s.Require().NoError(err) - actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}}) s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) @@ -234,7 +235,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() { // Returns the most recent bundle actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, - []*Installation{ + []*multidevice.Installation{ {ID: "1", Version: 1}, {ID: "2", Version: 1}, }) @@ -347,211 +348,4 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() { s.Nil(ratchetInfo, "It returns nil when no bundle is there") } -func (s *SQLLitePersistenceTestSuite) TestAddInstallations() { - identity := []byte("alice") - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - err := s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - - s.Require().NoError(err) - - enabledInstallations, err := s.service.GetActiveInstallations(5, identity) - s.Require().NoError(err) - - s.Require().Equal(installations, enabledInstallations) -} - -func (s *SQLLitePersistenceTestSuite) TestAddInstallationVersions() { - identity := []byte("alice") - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - } - err := s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - - s.Require().NoError(err) - - enabledInstallations, err := s.service.GetActiveInstallations(5, identity) - s.Require().NoError(err) - - s.Require().Equal(installations, enabledInstallations) - - installationsWithDowngradedVersion := []*Installation{ - {ID: "alice-1", Version: 0}, - } - - err = s.service.AddInstallations( - identity, - 3, - installationsWithDowngradedVersion, - true, - ) - s.Require().NoError(err) - - enabledInstallations, err = s.service.GetActiveInstallations(5, identity) - s.Require().NoError(err) - s.Require().Equal(installations, enabledInstallations) -} - -func (s *SQLLitePersistenceTestSuite) TestAddInstallationsLimit() { - identity := []byte("alice") - - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - - err := s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - s.Require().NoError(err) - - installations = []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-3", Version: 3}, - } - - err = s.service.AddInstallations( - identity, - 2, - installations, - true, - ) - s.Require().NoError(err) - - installations = []*Installation{ - {ID: "alice-2", Version: 2}, - {ID: "alice-3", Version: 3}, - {ID: "alice-4", Version: 4}, - } - - err = s.service.AddInstallations( - identity, - 3, - installations, - true, - ) - s.Require().NoError(err) - - enabledInstallations, err := s.service.GetActiveInstallations(3, identity) - s.Require().NoError(err) - - s.Require().Equal(installations, enabledInstallations) -} - -func (s *SQLLitePersistenceTestSuite) TestAddInstallationsDisabled() { - identity := []byte("alice") - - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - - err := s.service.AddInstallations( - identity, - 1, - installations, - false, - ) - s.Require().NoError(err) - - actualInstallations, err := s.service.GetActiveInstallations(3, identity) - s.Require().NoError(err) - - s.Require().Nil(actualInstallations) -} - -func (s *SQLLitePersistenceTestSuite) TestDisableInstallation() { - identity := []byte("alice") - - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - - err := s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - s.Require().NoError(err) - - err = s.service.DisableInstallation(identity, "alice-1") - s.Require().NoError(err) - - // We add the installations again - installations = []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - - err = s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - s.Require().NoError(err) - - actualInstallations, err := s.service.GetActiveInstallations(3, identity) - s.Require().NoError(err) - - expected := []*Installation{{ID: "alice-2", Version: 2}} - s.Require().Equal(expected, actualInstallations) -} - -func (s *SQLLitePersistenceTestSuite) TestEnableInstallation() { - identity := []byte("alice") - - installations := []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - - err := s.service.AddInstallations( - identity, - 1, - installations, - true, - ) - s.Require().NoError(err) - - err = s.service.DisableInstallation(identity, "alice-1") - s.Require().NoError(err) - - actualInstallations, err := s.service.GetActiveInstallations(3, identity) - s.Require().NoError(err) - - expected := []*Installation{{ID: "alice-2", Version: 2}} - s.Require().Equal(expected, actualInstallations) - - err = s.service.EnableInstallation(identity, "alice-1") - s.Require().NoError(err) - - actualInstallations, err = s.service.GetActiveInstallations(3, identity) - s.Require().NoError(err) - - expected = []*Installation{ - {ID: "alice-1", Version: 1}, - {ID: "alice-2", Version: 2}, - } - s.Require().Equal(expected, actualInstallations) - -} - // TODO: Add test for MarkBundleExpired diff --git a/services/shhext/chat/x3dh.go b/services/shhext/chat/x3dh.go index 082410676b7..5f77533a883 100644 --- a/services/shhext/chat/x3dh.go +++ b/services/shhext/chat/x3dh.go @@ -2,16 +2,14 @@ package chat import ( "crypto/ecdsa" - "encoding/base64" "errors" - "fmt" "sort" "strconv" "time" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/golang/protobuf/proto" + "github.com/status-im/status-go/services/shhext/chat/protobuf" ) const ( @@ -19,33 +17,7 @@ const ( sskLen = 16 ) -// ToBase64 returns a Base64 encoding representation of the protobuf Bundle message -func (bundle *Bundle) ToBase64() (string, error) { - marshaledMessage, err := proto.Marshal(bundle) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(marshaledMessage), nil -} - -// FromBase64 unmarshals a Bundle from a Base64 encoding representation of the protobuf Bundle message -func FromBase64(str string) (*Bundle, error) { - bundle := &Bundle{} - - decodedBundle, err := base64.StdEncoding.DecodeString(str) - if err != nil { - return nil, err - } - - if err := proto.Unmarshal(decodedBundle, bundle); err != nil { - return nil, err - } - - return bundle, nil -} - -func buildSignatureMaterial(bundle *Bundle) []byte { +func buildSignatureMaterial(bundle *protobuf.Bundle) []byte { signedPreKeys := bundle.GetSignedPreKeys() timestamp := bundle.GetTimestamp() var keys []string @@ -73,7 +45,7 @@ func buildSignatureMaterial(bundle *Bundle) []byte { } -func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *BundleContainer) error { +func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *protobuf.BundleContainer) error { signatureMaterial := buildSignatureMaterial(bundleContainer.GetBundle()) signature, err := crypto.Sign(crypto.Keccak256(signatureMaterial), identity) @@ -85,7 +57,7 @@ func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *BundleContainer) er } // NewBundleContainer creates a new BundleContainer from an identity private key -func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*BundleContainer, error) { +func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*protobuf.BundleContainer, error) { preKey, err := crypto.GenerateKey() if err != nil { return nil, err @@ -95,35 +67,35 @@ func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*Bun compressedIdentityKey := crypto.CompressPubkey(&identity.PublicKey) encodedPreKey := crypto.FromECDSA(preKey) - signedPreKeys := make(map[string]*SignedPreKey) - signedPreKeys[installationID] = &SignedPreKey{ - ProtocolVersion: protocolCurrentVersion, + signedPreKeys := make(map[string]*protobuf.SignedPreKey) + signedPreKeys[installationID] = &protobuf.SignedPreKey{ + ProtocolVersion: ProtocolVersion, SignedPreKey: compressedPreKey, } - bundle := Bundle{ + bundle := protobuf.Bundle{ Timestamp: time.Now().UnixNano(), Identity: compressedIdentityKey, SignedPreKeys: signedPreKeys, } - return &BundleContainer{ + return &protobuf.BundleContainer{ Bundle: &bundle, PrivateSignedPreKey: encodedPreKey, }, nil } // VerifyBundle checks that a bundle is valid -func VerifyBundle(bundle *Bundle) error { +func VerifyBundle(bundle *protobuf.Bundle) error { _, err := ExtractIdentity(bundle) return err } // ExtractIdentity extracts the identity key from a given bundle -func ExtractIdentity(bundle *Bundle) (string, error) { +func ExtractIdentity(bundle *protobuf.Bundle) (*ecdsa.PublicKey, error) { bundleIdentityKey, err := crypto.DecompressPubkey(bundle.GetIdentity()) if err != nil { - return "", err + return nil, err } signatureMaterial := buildSignatureMaterial(bundle) @@ -133,14 +105,14 @@ func ExtractIdentity(bundle *Bundle) (string, error) { bundle.GetSignature(), ) if err != nil { - return "", err + return nil, err } if crypto.PubkeyToAddress(*recoveredKey) != crypto.PubkeyToAddress(*bundleIdentityKey) { - return "", errors.New("identity key and signature mismatch") + return nil, errors.New("identity key and signature mismatch") } - return fmt.Sprintf("0x%x", crypto.FromECDSAPub(recoveredKey)), nil + return recoveredKey, nil } // PerformDH generates a shared key given a private and a public key diff --git a/services/shhext/chat/x3dh_test.go b/services/shhext/chat/x3dh_test.go index 5222c3973de..e2a9ac5879a 100644 --- a/services/shhext/chat/x3dh_test.go +++ b/services/shhext/chat/x3dh_test.go @@ -6,6 +6,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/status-im/status-go/services/shhext/chat/protobuf" "github.com/stretchr/testify/require" ) @@ -14,12 +15,11 @@ const ( aliceEphemeralKey = "11111111111111111111111111111111" bobPrivateKey = "22222222222222222222222222222222" bobSignedPreKey = "33333333333333333333333333333333" - base64Bundle = "CiECkJmdu/QwNL/7HdU+rB60wzpOocT0i6WFz944MIQPBVUSKAoBMhIjCiECPHKt20/fCa+U8MlNf+kqOGp+cM+KHYWRY4a7JTXHsbEiQT9Wse3UkJgo6/1HzxQfHcZNBaMH0j+0eylfBf1ropsLZ7yZM98k/qDQ3ZW5uHXQ4zhY8E1Q7HDytqm62k5JIPYA" ) var sharedKey = []byte{0xa4, 0xe9, 0x23, 0xd0, 0xaf, 0x8f, 0xe7, 0x8a, 0x5, 0x63, 0x63, 0xbe, 0x20, 0xe7, 0x1c, 0xa, 0x58, 0xe5, 0x69, 0xea, 0x8f, 0xc1, 0xf7, 0x92, 0x89, 0xec, 0xa1, 0xd, 0x9f, 0x68, 0x13, 0x3a} -func bobBundle() (*Bundle, error) { +func bobBundle() (*protobuf.Bundle, error) { privateKey, err := crypto.ToECDSA([]byte(bobPrivateKey)) if err != nil { return nil, err @@ -37,10 +37,10 @@ func bobBundle() (*Bundle, error) { return nil, err } - signedPreKeys := make(map[string]*SignedPreKey) - signedPreKeys[bobInstallationID] = &SignedPreKey{SignedPreKey: compressedPreKey} + signedPreKeys := make(map[string]*protobuf.SignedPreKey) + signedPreKeys[bobInstallationID] = &protobuf.SignedPreKey{SignedPreKey: compressedPreKey} - bundle := Bundle{ + bundle := protobuf.Bundle{ Identity: crypto.CompressPubkey(&privateKey.PublicKey), SignedPreKeys: signedPreKeys, Signature: signature, @@ -76,8 +76,8 @@ func TestNewBundleContainer(t *testing.T) { require.Equal( t, - &privateKey.PublicKey, - recoveredPublicKey, + privateKey.PublicKey, + *recoveredPublicKey, "The correct public key should be recovered", ) } @@ -94,7 +94,7 @@ func TestSignBundle(t *testing.T) { // We add a signed pre key signedPreKeys := bundle1.GetSignedPreKeys() - signedPreKeys["2"] = &SignedPreKey{SignedPreKey: []byte("key")} + signedPreKeys["2"] = &protobuf.SignedPreKey{SignedPreKey: []byte("key")} err = SignBundle(privateKey, bundleContainer1) require.NoError(t, err) @@ -115,40 +115,12 @@ func TestSignBundle(t *testing.T) { require.Equal( t, - &privateKey.PublicKey, - recoveredPublicKey, + privateKey.PublicKey, + *recoveredPublicKey, "The correct public key should be recovered", ) } -func TestToBase64(t *testing.T) { - bundle, err := bobBundle() - require.NoError(t, err, "Test bundle should be generated without errors") - - actualBase64Bundle, err := bundle.ToBase64() - require.NoError(t, err, "No error should be reported") - require.Equal( - t, - base64Bundle, - actualBase64Bundle, - "The correct bundle should be generated", - ) -} - -func TestFromBase64(t *testing.T) { - expectedBundle, err := bobBundle() - require.NoError(t, err, "Test bundle should be generated without errors") - - actualBundle, err := FromBase64(base64Bundle) - require.NoError(t, err, "Bundle should be unmarshaled without errors") - require.Equal( - t, - expectedBundle, - actualBundle, - "The correct bundle should be generated", - ) -} - func TestExtractIdentity(t *testing.T) { privateKey, err := crypto.ToECDSA([]byte(alicePrivateKey)) require.NoError(t, err, "Private key should be generated without errors") @@ -168,8 +140,8 @@ func TestExtractIdentity(t *testing.T) { require.Equal( t, - "0x042ed557f5ad336b31a49857e4e9664954ac33385aa20a93e2d64bfe7f08f51277bcb27c1259f802a52ed3ea7ac939043f0cc864e27400294bf121f23877995852", - recoveredPublicKey, + privateKey.PublicKey, + *recoveredPublicKey, "The correct public key should be recovered", ) } diff --git a/services/shhext/filter/service.go b/services/shhext/filter/service.go index 31b86f26dd3..0b90db76868 100644 --- a/services/shhext/filter/service.go +++ b/services/shhext/filter/service.go @@ -45,7 +45,6 @@ type Chat struct { } type Service struct { - keyID string whisper *whisper.Whisper topic *topic.Service chats map[string]*Chat @@ -53,9 +52,8 @@ type Service struct { } // New returns a new filter service -func New(k string, w *whisper.Whisper, t *topic.Service) *Service { +func New(w *whisper.Whisper, t *topic.Service) *Service { return &Service{ - keyID: k, whisper: w, topic: t, mutex: sync.Mutex{}, @@ -184,7 +182,12 @@ func (s *Service) LoadPartitioned(myKey *ecdsa.PrivateKey, theirPublicKey *ecdsa // Load creates filters for a given chat, and returns all the created filters func (s *Service) Load(chat *Chat) ([]*Chat, error) { - myKey, err := s.whisper.GetPrivateKey(s.keyID) + keyID := s.whisper.SelectedKeyPairID() + if keyID == "" { + return nil, errors.New("no key selected") + } + myKey, err := s.whisper.GetPrivateKey(keyID) + if err != nil { return nil, err } diff --git a/services/shhext/filter/service_test.go b/services/shhext/filter/service_test.go index 4ca4dc2ec95..0e96daa9c35 100644 --- a/services/shhext/filter/service_test.go +++ b/services/shhext/filter/service_test.go @@ -69,10 +69,10 @@ func (s *ServiceTestSuite) SetupTest() { // Build services topicService := topic.NewService(topic.NewSQLLitePersistence(db)) whisper := whisper.New(nil) - keyID, err := whisper.AddKeyPair(s.keys[0].privateKey) + _, err = whisper.AddKeyPair(s.keys[0].privateKey) s.Require().NoError(err) - s.service = New(keyID, whisper, topicService) + s.service = New(whisper, topicService) } func (s *ServiceTestSuite) TearDownTest() { @@ -177,6 +177,7 @@ func (s *ServiceTestSuite) TestLoadChat() { response1, err := s.service.Load(&Chat{ChatID: "status"}) + s.Require().NoError(err) s.Require().Equal(1, len(response1)) s.Require().Equal("status", response1[0].ChatID) s.Require().True(response1[0].Listen) diff --git a/services/shhext/service.go b/services/shhext/service.go index 45864f8e193..cc318cf7b01 100644 --- a/services/shhext/service.go +++ b/services/shhext/service.go @@ -18,6 +18,8 @@ import ( "github.com/status-im/status-go/params" "github.com/status-im/status-go/services/shhext/chat" appDB "github.com/status-im/status-go/services/shhext/chat/db" + "github.com/status-im/status-go/services/shhext/chat/multidevice" + "github.com/status-im/status-go/services/shhext/chat/protobuf" "github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/dedup" "github.com/status-im/status-go/services/shhext/filter" @@ -33,6 +35,7 @@ const ( defaultConnectionsTarget = 1 // defaultTimeoutWaitAdded is a timeout to use to establish initial connections. defaultTimeoutWaitAdded = 5 * time.Second + maxInstallations = 3 ) var errProtocolNotInitialized = errors.New("procotol is not initialized") @@ -188,7 +191,7 @@ func (s *Service) initProtocol(address, encKey, password string) error { return err } - addedBundlesHandler := func(addedBundles []chat.IdentityAndIDPair) { + addedBundlesHandler := func(addedBundles []multidevice.IdentityAndIDPair) { handler := EnvelopeSignalHandler{} for _, bundle := range addedBundles { handler.BundleAdded(bundle[0], bundle[1]) @@ -197,21 +200,31 @@ func (s *Service) initProtocol(address, encKey, password string) error { // Initialize topics topicService := topic.NewService(persistence.GetTopicStorage()) - filterService := filter.New(s.config.AsymKeyID, s.w, topicService) + // Initialize filter + filterService := filter.New(s.w, topicService) s.filter = filterService + // Initialize multidevice + multideviceConfig := &multidevice.Config{ + InstallationID: s.installationID, + ProtocolVersion: chat.ProtocolVersion, + MaxInstallations: maxInstallations, + } + multideviceService := multidevice.New(multideviceConfig, persistence.GetMultideviceStorage()) + s.protocol = chat.NewProtocolService( chat.NewEncryptionService( persistence, chat.DefaultEncryptionServiceConfig(s.installationID)), topicService, + multideviceService, addedBundlesHandler, s.onNewTopicHandler) return nil } -func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *chat.Bundle) ([]chat.IdentityAndIDPair, error) { +func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *protobuf.Bundle) ([]multidevice.IdentityAndIDPair, error) { if s.protocol == nil { return nil, errProtocolNotInitialized } @@ -219,7 +232,7 @@ func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *c return s.protocol.ProcessPublicBundle(myIdentityKey, bundle) } -func (s *Service) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*chat.Bundle, error) { +func (s *Service) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*protobuf.Bundle, error) { if s.protocol == nil { return nil, errProtocolNotInitialized } @@ -236,7 +249,7 @@ func (s *Service) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installatio return s.protocol.EnableInstallation(myIdentityKey, installationID) } -func (s *Service) GetPublicBundle(identityKey *ecdsa.PublicKey) (*chat.Bundle, error) { +func (s *Service) GetPublicBundle(identityKey *ecdsa.PublicKey) (*protobuf.Bundle, error) { if s.protocol == nil { return nil, errProtocolNotInitialized }