diff --git a/internal/keyservice/gpg/keyfile.go b/internal/keyservice/gpg/keyfile.go index 8c496f5..1e058b4 100644 --- a/internal/keyservice/gpg/keyfile.go +++ b/internal/keyservice/gpg/keyfile.go @@ -11,7 +11,7 @@ import ( ) // keyfilePrivateKeys reads the given path and returns any private keys found. -func keyfilePrivateKeys(p string) ([]*packet.PrivateKey, error) { +func keyfilePrivateKeys(p string) ([]privateKeyfile, error) { f, err := os.Open(p) if err != nil { return nil, fmt.Errorf("couldn't open path %s: %v", p, err) @@ -23,7 +23,8 @@ func keyfilePrivateKeys(p string) ([]*packet.PrivateKey, error) { } switch { case fileInfo.Mode().IsRegular(): - return keysFromFile(f) + pk, err := keysFromFile(f) + return []privateKeyfile{*pk}, err case fileInfo.IsDir(): // enumerate files in directory dirents, err := f.ReadDir(0) @@ -31,7 +32,7 @@ func keyfilePrivateKeys(p string) ([]*packet.PrivateKey, error) { return nil, fmt.Errorf("couldn't read directory") } // get any private keys from each file - var privKeys []*packet.PrivateKey + var privKeys []privateKeyfile for _, dirent := range dirents { direntInfo, err := dirent.Info() if err != nil { @@ -49,7 +50,7 @@ func keyfilePrivateKeys(p string) ([]*packet.PrivateKey, error) { return nil, fmt.Errorf("couldn't get keys from file %s: %v", subPath, err) } - privKeys = append(privKeys, subPrivKeys...) + privKeys = append(privKeys, *subPrivKeys) } } return privKeys, nil @@ -59,9 +60,10 @@ func keyfilePrivateKeys(p string) ([]*packet.PrivateKey, error) { } // keysFromFile read a file and return any private keys found -func keysFromFile(f *os.File) ([]*packet.PrivateKey, error) { +func keysFromFile(f *os.File) (*privateKeyfile, error) { var err error var pkt packet.Packet + var uid *packet.UserId var privKeys []*packet.PrivateKey reader := packet.NewReader(f) for pkt, err = reader.Next(); err != io.EOF; pkt, err = reader.Next() { @@ -71,11 +73,20 @@ func keysFromFile(f *os.File) ([]*packet.PrivateKey, error) { if err != nil { return nil, fmt.Errorf("couldn't get next packet: %v", err) } - k, ok := pkt.(*packet.PrivateKey) - if !ok { + switch k := pkt.(type) { + case *packet.PrivateKey: + privKeys = append(privKeys, k) + case *packet.UserId: + uid = k + default: continue } - privKeys = append(privKeys, k) } - return privKeys, nil + if uid == nil { + uid = packet.NewUserId("n/a", "n/a", "n/a") + } + return &privateKeyfile{ + uid: uid, + keys: privKeys, + }, nil } diff --git a/internal/keyservice/gpg/keyservice.go b/internal/keyservice/gpg/keyservice.go index efdb1bd..a5792cf 100644 --- a/internal/keyservice/gpg/keyservice.go +++ b/internal/keyservice/gpg/keyservice.go @@ -14,7 +14,12 @@ import ( // PINEntryService provides an interface to talk to a pinentry program. type PINEntryService interface { - GetPGPPassphrase(string) ([]byte, error) + GetPGPPassphrase(string, string) ([]byte, error) +} + +type privateKeyfile struct { + uid *packet.UserId + keys []*packet.PrivateKey } // KeyService implements an interface for getting cryptographic keys from @@ -22,15 +27,14 @@ type PINEntryService interface { type KeyService struct { // cache passphrases used for decryption passphrases [][]byte - privKeys []*packet.PrivateKey + privKeys []privateKeyfile log *zap.Logger pinentry PINEntryService } // New returns a keyservice initialised with keys found at path. // Path can be a file or directory. -func New(l *zap.Logger, pe PINEntryService, - path string) (*KeyService, error) { +func New(l *zap.Logger, pe PINEntryService, path string) (*KeyService, error) { p, err := keyfilePrivateKeys(path) if err != nil { return nil, err @@ -67,47 +71,49 @@ func (g *KeyService) HaveKey(keygrips [][]byte) (bool, []byte, error) { func (g *KeyService) getKey(keygrip []byte) (*rsa.PrivateKey, error) { var pass []byte var err error - for _, k := range g.privKeys { - pubKey, ok := k.PublicKey.PublicKey.(*rsa.PublicKey) - if !ok { - continue - } - if !bytes.Equal(keygrip, keygripRSA(pubKey)) { - continue - } - if k.Encrypted { - // try existing passphrases - for _, pass := range g.passphrases { - if err = k.Decrypt(pass); err == nil { - g.log.Debug("decrypted using cached passphrase", - zap.String("fingerprint", k.KeyIdString())) - break + for _, pk := range g.privKeys { + for _, k := range pk.keys { + pubKey, ok := k.PublicKey.PublicKey.(*rsa.PublicKey) + if !ok { + continue + } + if !bytes.Equal(keygrip, keygripRSA(pubKey)) { + continue + } + if k.Encrypted { + // try existing passphrases + for _, pass := range g.passphrases { + if err = k.Decrypt(pass); err == nil { + g.log.Debug("decrypted using cached passphrase", + zap.String("fingerprint", k.KeyIdString())) + break + } } } - } - if k.Encrypted { - // ask for a passphrase - pass, err = g.pinentry.GetPGPPassphrase( - fmt.Sprintf("%X %X %X %X", k.Fingerprint[:5], k.Fingerprint[5:10], - k.Fingerprint[10:15], k.Fingerprint[15:])) - if err != nil { - return nil, fmt.Errorf("couldn't get passphrase for key %s: %v", - k.KeyIdString(), err) + if k.Encrypted { + // ask for a passphrase + pass, err = g.pinentry.GetPGPPassphrase( + fmt.Sprintf("%s (%s) <%s>", pk.uid.Name, pk.uid.Comment, pk.uid.Email), + fmt.Sprintf("%X %X %X %X", k.Fingerprint[:5], k.Fingerprint[5:10], k.Fingerprint[10:15], k.Fingerprint[15:])) + if err != nil { + return nil, fmt.Errorf("couldn't get passphrase for key %s: %v", + k.KeyIdString(), err) + } + g.passphrases = append(g.passphrases, pass) + if err = k.Decrypt(pass); err != nil { + return nil, fmt.Errorf("couldn't decrypt key %s: %v", + k.KeyIdString(), err) + } + g.log.Debug("decrypted using passphrase", + zap.String("fingerprint", k.KeyIdString())) } - g.passphrases = append(g.passphrases, pass) - if err = k.Decrypt(pass); err != nil { - return nil, fmt.Errorf("couldn't decrypt key %s: %v", + privKey, ok := k.PrivateKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("not an RSA key %s: %v", k.KeyIdString(), err) } - g.log.Debug("decrypted using passphrase", - zap.String("fingerprint", k.KeyIdString())) - } - privKey, ok := k.PrivateKey.(*rsa.PrivateKey) - if !ok { - return nil, fmt.Errorf("not an RSA key %s: %v", - k.KeyIdString(), err) + return privKey, nil } - return privKey, nil } return nil, nil } diff --git a/internal/keyservice/gpg/keyservice_test.go b/internal/keyservice/gpg/keyservice_test.go index f5739e9..a4c2646 100644 --- a/internal/keyservice/gpg/keyservice_test.go +++ b/internal/keyservice/gpg/keyservice_test.go @@ -44,7 +44,7 @@ func TestGetSigner(t *testing.T) { defer ctrl.Finish() var mockPES = mock.NewMockPINEntryService(ctrl) if tc.protected { - mockPES.EXPECT().GetPGPPassphrase(gomock.Any()). + mockPES.EXPECT().GetPGPPassphrase(gomock.Any(), gomock.Any()). Return([]byte("trustno1"), nil) } ks, err := gpg.New(log, mockPES, tc.path) diff --git a/internal/mock/mock_keyservice.go b/internal/mock/mock_keyservice.go index 9e29ead..c22e200 100644 --- a/internal/mock/mock_keyservice.go +++ b/internal/mock/mock_keyservice.go @@ -34,16 +34,16 @@ func (m *MockPINEntryService) EXPECT() *MockPINEntryServiceMockRecorder { } // GetPGPPassphrase mocks base method. -func (m *MockPINEntryService) GetPGPPassphrase(arg0 string) ([]byte, error) { +func (m *MockPINEntryService) GetPGPPassphrase(arg0, arg1 string) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPGPPassphrase", arg0) + ret := m.ctrl.Call(m, "GetPGPPassphrase", arg0, arg1) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetPGPPassphrase indicates an expected call of GetPGPPassphrase. -func (mr *MockPINEntryServiceMockRecorder) GetPGPPassphrase(arg0 interface{}) *gomock.Call { +func (mr *MockPINEntryServiceMockRecorder) GetPGPPassphrase(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPGPPassphrase", reflect.TypeOf((*MockPINEntryService)(nil).GetPGPPassphrase), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPGPPassphrase", reflect.TypeOf((*MockPINEntryService)(nil).GetPGPPassphrase), arg0, arg1) } diff --git a/internal/pinentry/pinentry.go b/internal/pinentry/pinentry.go index f32a3cf..24cc6e5 100644 --- a/internal/pinentry/pinentry.go +++ b/internal/pinentry/pinentry.go @@ -16,8 +16,8 @@ type SecurityKey interface { type PINEntry struct{} // GetPGPPassphrase uses pinentry to get the passphrase of the key with the -// given keygrip. -func (*PINEntry) GetPGPPassphrase(fingerprint string) ([]byte, error) { +// given fingerprint. +func (*PINEntry) GetPGPPassphrase(userID, fingerprint string) ([]byte, error) { p, err := pinentry.New() if err != nil { return []byte{}, fmt.Errorf("couldn't get pinentry client: %w", err) @@ -33,7 +33,8 @@ func (*PINEntry) GetPGPPassphrase(fingerprint string) ([]byte, error) { return nil, fmt.Errorf("couldn't set prompt on passphrase pinentry: %w", err) } - err = p.Set("desc", fmt.Sprintf("PGP key fingerprint: %s", fingerprint)) + err = p.Set("desc", fmt.Sprintf("UserID: %s, Fingerprint: %s", userID, + fingerprint)) if err != nil { return nil, fmt.Errorf("couldn't set desc on passphrase pinentry: %w", err)