diff --git a/go.mod b/go.mod index 6a1fd27..09271f0 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.1 github.com/trisacrypto/directory v1.3.1 - github.com/trisacrypto/trisa v0.3.5 + github.com/trisacrypto/trisa v0.3.6 github.com/urfave/cli v1.22.5 google.golang.org/grpc v1.45.0 google.golang.org/protobuf v1.28.0 diff --git a/go.sum b/go.sum index ab1a3ee..f6cb6f6 100644 --- a/go.sum +++ b/go.sum @@ -633,6 +633,8 @@ github.com/trisacrypto/trisa v0.3.4 h1:4GF5cpHY9Pg3Qupbp+mzIUMRwbxUFNXed82V7UB0u github.com/trisacrypto/trisa v0.3.4/go.mod h1:5VC6uJBIyiPreZfR67Mri6+CVXIuVuFZwTPromBPi9g= github.com/trisacrypto/trisa v0.3.5 h1:aRn6vSJ/LFc6e0lFy1vxdZfQHIak9VQpu9Ny//q99aE= github.com/trisacrypto/trisa v0.3.5/go.mod h1:5VC6uJBIyiPreZfR67Mri6+CVXIuVuFZwTPromBPi9g= +github.com/trisacrypto/trisa v0.3.6 h1:AaD4O/0gNLUbdf3A6zP420Jq19JWtt/31WicR9oj+Xc= +github.com/trisacrypto/trisa v0.3.6/go.mod h1:5VC6uJBIyiPreZfR67Mri6+CVXIuVuFZwTPromBPi9g= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= diff --git a/pkg/rvasp/transfer.go b/pkg/rvasp/transfer.go index 2cf891d..d8f64a4 100644 --- a/pkg/rvasp/transfer.go +++ b/pkg/rvasp/transfer.go @@ -160,12 +160,6 @@ func parsePayload(payload *protocol.Payload, response bool) (identity *ivms101.I return nil, nil, nil, protocol.Errorf(protocol.UnparseableIdentity, "could non unmarshal identity: %s", err) } - // Validate identity fields - if identity.Originator == nil || identity.OriginatingVasp == nil || identity.BeneficiaryVasp == nil || identity.Beneficiary == nil { - log.Warn().Msg("incomplete identity payload") - return nil, nil, nil, protocol.Errorf(protocol.IncompleteIdentity, "incomplete identity payload") - } - // Parse the transaction message type var msgTx proto.Message if msgTx, err = payload.Transaction.UnmarshalNew(); err != nil { @@ -191,3 +185,39 @@ func parsePayload(payload *protocol.Payload, response bool) (identity *ivms101.I } return identity, transaction, pending, nil } + +// Validate an identity payload, returning an error if the payload is not valid. +func validateIdentityPayload(identity *ivms101.IdentityPayload, requireBeneficiary bool) (err *protocol.Error) { + // Verify the identity payload is not nil + if identity == nil { + log.Warn().Msg("identity payload is nil") + return protocol.Errorf(protocol.InternalError, "identity payload is nil") + } + + // Validate that the originator is present + if identity.Originator == nil { + log.Warn().Msg("identity payload missing originator") + return protocol.Errorf(protocol.IncompleteIdentity, "missing originator") + } + + // Validate that the originator vasp is present + if identity.OriginatingVasp == nil { + log.Warn().Msg("identity payload missing originating vasp") + return protocol.Errorf(protocol.IncompleteIdentity, "missing originating vasp") + } + + if requireBeneficiary { + // Validate that the beneficiary is present + if identity.Beneficiary == nil { + log.Warn().Msg("identity payload missing beneficiary") + return protocol.Errorf(protocol.IncompleteIdentity, "missing beneficiary") + } + + // Validate that the beneficiary vasp is present + if identity.BeneficiaryVasp == nil { + log.Warn().Msg("identity payload missing beneficiary vasp") + return protocol.Errorf(protocol.IncompleteIdentity, "missing beneficiary vasp") + } + } + return nil +} diff --git a/pkg/rvasp/trisa.go b/pkg/rvasp/trisa.go index 9f2a26d..2d1eb0a 100644 --- a/pkg/rvasp/trisa.go +++ b/pkg/rvasp/trisa.go @@ -157,8 +157,8 @@ func (s *TRISA) Transfer(ctx context.Context, in *protocol.SecureEnvelope) (out } var transferError *protocol.Error - if out, transferError = s.handleTransaction(ctx, peer, in); err != nil { - log.Warn().Err(err).Msg("could not complete transfer") + if out, transferError = s.handleTransaction(ctx, peer, in); transferError != nil { + log.Warn().Err(transferError).Msg("could not complete transfer") var msg *protocol.SecureEnvelope if msg, err = envelope.Reject(transferError, envelope.WithEnvelopeID(in.Id)); err != nil { log.Error().Err(err).Msg("could not create TRISA error envelope") @@ -478,19 +478,13 @@ func (s *TRISA) respondTransfer(in *protocol.SecureEnvelope, peer *peers.Peer, i return nil, protocol.Errorf(protocol.InternalError, "request could not be processed") } - if requireBeneficiary { - // TODO: Validate the actual fields in the beneficiary identity - if identity.BeneficiaryVasp == nil || identity.BeneficiaryVasp.BeneficiaryVasp == nil { - log.Warn().Msg("TRISA protocol error: missing beneficiary vasp identity") - reject := protocol.Errorf(protocol.Rejected, "missing beneficiary vasp identity") - if out, err = envelope.Reject(reject, envelope.WithEnvelopeID(in.Id)); err != nil { - log.Error().Err(err).Msg("could not create reject envelope") - return nil, protocol.Errorf(protocol.InternalError, "request coould not be processed: %s", err) - } - xfer.SetState(pb.TransactionState_REJECTED) - return out, nil - } - } else { + if transferError = validateIdentityPayload(identity, requireBeneficiary); transferError != nil { + log.Warn().Str("message", transferError.Message).Msg("could not validate identity payload") + xfer.SetState(pb.TransactionState_REJECTED) + return nil, transferError + } + + if !requireBeneficiary { // Fill in the beneficiary identity information for the repair policy s.repairBeneficiary(identity, account) } @@ -665,13 +659,6 @@ func (s *TRISA) sendAsync(tx *db.Transaction) (err error) { return fmt.Errorf("could not fetch originator peer: %s", err) } - // Fetch the signing key from the remote peer - var signKey *rsa.PublicKey - if signKey, err = s.parent.fetchSigningKey(peer); err != nil { - log.Warn().Err(err).Msg("could not fetch signing key from originator peer") - return fmt.Errorf("could not fetch signing key from originator peer: %s", err) - } - // Create the identity for the payload identity := &ivms101.IdentityPayload{} if err = protojson.Unmarshal([]byte(tx.Identity), identity); err != nil { @@ -681,6 +668,31 @@ func (s *TRISA) sendAsync(tx *db.Transaction) (err error) { // Repair the beneficiary information if this is the first handshake if tx.State == pb.TransactionState_PENDING_SENT { + var validationError *protocol.Error + if validationError = validateIdentityPayload(identity, false); validationError != nil { + log.Warn().Str("message", validationError.Message).Msg("could not validate identity payload") + var reject *protocol.SecureEnvelope + if reject, err = envelope.Reject(validationError, envelope.WithEnvelopeID(tx.Envelope)); err != nil { + log.Error().Err(err).Msg("TRISA protocol error while creating reject envelope") + return fmt.Errorf("TRISA protocol error: %s", err) + } + + // Conduct the TRISA exchange, handle errors + if reject, err = peer.Transfer(reject); err != nil { + log.Warn().Err(err).Msg("could not perform TRISA exchange") + return fmt.Errorf("could not perform TRISA exchange: %s", err) + } + + // Check for the TRISA rejection error + rejectErr, isErr := envelope.Check(reject) + if !isErr || rejectErr == nil { + state := envelope.Status(reject) + log.Warn().Str("state", state.String()).Msg("unexpected TRISA response, expected reject envelope") + return fmt.Errorf("expected TRISA rejection error, received envelope in state %s", state.String()) + } + tx.SetState(pb.TransactionState_REJECTED) + } + var account *db.Account if account, err = tx.GetAccount(s.parent.db); err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -697,6 +709,13 @@ func (s *TRISA) sendAsync(tx *db.Transaction) (err error) { } } + // Fetch the signing key from the remote peer + var signKey *rsa.PublicKey + if signKey, err = s.parent.fetchSigningKey(peer); err != nil { + log.Warn().Err(err).Msg("could not fetch signing key from originator peer") + return fmt.Errorf("could not fetch signing key from originator peer: %s", err) + } + // Create the generic.Transaction for the payload transaction := &generic.Transaction{} if err = protojson.Unmarshal([]byte(tx.Transaction), transaction); err != nil {