diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml new file mode 100644 index 0000000..b54309f --- /dev/null +++ b/.github/workflows/main.yaml @@ -0,0 +1,38 @@ +name: GitHub Actions Workflow +on: [push] +jobs: + fmt: + name: Fmt + runs-on: ubuntu-latest + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: '1.14.4' + - name: Fmt + run: go fmt github.com/pavel-v-chernykh/keystore-go/... + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Lint + uses: golangci/golangci-lint-action@v1.2.1 + with: + args: --timeout=5m0s -c .golangci.yaml + version: v1.27 + test: + name: Test + runs-on: ubuntu-latest + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: '1.14.4' + - name: Test + run: go test -cover -count=1 -v github.com/pavel-v-chernykh/keystore-go/... diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..bec624b --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,28 @@ +modules-download-mode: readonly + +linters: + enable-all: true + disable: + - gochecknoglobals + - funlen + - goerr113 + +linters-settings: + gomnd: + settings: + mnd: + checks: case,condition,return + +issues: + exclude-rules: + - path: _test\.go + linters: + - testpackage + - maligned + - dupl + - linters: + - gosec + text: "G401: " + - linters: + - gosec + text: "G505: " diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4bb6855 --- /dev/null +++ b/Makefile @@ -0,0 +1,13 @@ +fmt: + go fmt github.com/pavel-v-chernykh/keystore-go/... + +lint: + golangci-lint run -c .golangci.yaml + +test: + go test -cover -count=1 -v ./... + +all: fmt lint test + +.PHONY: fmt lint test all +.DEFAULT_GOAL := all diff --git a/common.go b/common.go index b9c9655..2052c6b 100644 --- a/common.go +++ b/common.go @@ -25,6 +25,7 @@ func passwordBytes(password []byte) []byte { for _, b := range password { result = append(result, 0, b) } + return result } diff --git a/common_test.go b/common_test.go index 5abcd6a..aba895a 100644 --- a/common_test.go +++ b/common_test.go @@ -7,22 +7,27 @@ import ( ) func TestZeroing(t *testing.T) { - type zeroingItem struct { - input []byte - } - type zeroingTable []zeroingItem + type ( + zeroingItem struct { + input []byte + } + zeroingTable []zeroingItem + ) var table zeroingTable + for i := 0; i < 20; i++ { buf := make([]byte, 4096) if _, err := rand.Read(buf); err != nil { t.Errorf("read random bytes: %v", err) } + table = append(table, zeroingItem{input: buf}) } for _, tt := range table { zeroing(tt.input) + for i := range tt.input { if tt.input[i] != 0 { t.Errorf("fill input with zeros '%v'", tt.input) @@ -36,19 +41,25 @@ func TestPasswordBytes(t *testing.T) { input []byte output []byte } + var table []passwordBytesItem + for i := 0; i < 20; i++ { input := make([]byte, 1024) if _, err := rand.Read(input); err != nil { t.Errorf("read random bytes: %v", err) } + output := make([]byte, len(input)*2) + for j, k := 0, 0; j < len(output); j, k = j+2, k+1 { output[j] = 0 output[j+1] = input[k] } + table = append(table, passwordBytesItem{input: input, output: output}) } + for _, tt := range table { output := passwordBytes(tt.input) if !reflect.DeepEqual(output, tt.output) { diff --git a/decoder.go b/decoder.go index 9b76c24..161cca6 100644 --- a/decoder.go +++ b/decoder.go @@ -19,53 +19,67 @@ type keyStoreDecoder struct { func (ksd *keyStoreDecoder) readUint16() (uint16, error) { const blockSize = 2 + if _, err := io.ReadFull(ksd.r, ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("read uint16: %w", err) } + if _, err := ksd.md.Write(ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("update digest: %w", err) } + return byteOrder.Uint16(ksd.b[:blockSize]), nil } func (ksd *keyStoreDecoder) readUint32() (uint32, error) { const blockSize = 4 + if _, err := io.ReadFull(ksd.r, ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("read uint32: %w", err) } + if _, err := ksd.md.Write(ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("update digest: %w", err) } + return byteOrder.Uint32(ksd.b[:blockSize]), nil } func (ksd *keyStoreDecoder) readUint64() (uint64, error) { const blockSize = 8 + if _, err := io.ReadFull(ksd.r, ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("read uint64: %w", err) } + if _, err := ksd.md.Write(ksd.b[:blockSize]); err != nil { return 0, fmt.Errorf("update digest: %w", err) } + return byteOrder.Uint64(ksd.b[:blockSize]), nil } func (ksd *keyStoreDecoder) readBytes(num uint32) ([]byte, error) { var result []byte + for lenToRead := num; lenToRead > 0; { blockSize := lenToRead if blockSize > bufSize { blockSize = bufSize } + if _, err := io.ReadFull(ksd.r, ksd.b[:blockSize]); err != nil { return result, fmt.Errorf("read %d bytes: %w", num, err) } + result = append(result, ksd.b[:blockSize]...) lenToRead -= blockSize } + if _, err := ksd.md.Write(result); err != nil { return nil, fmt.Errorf("update digest: %w", err) } + return result, nil } @@ -74,15 +88,18 @@ func (ksd *keyStoreDecoder) readString() (string, error) { if err != nil { return "", fmt.Errorf("read length: %w", err) } + strBody, err := ksd.readBytes(uint32(strLen)) if err != nil { return "", fmt.Errorf("read body: %w", err) } + return string(strBody), nil } func (ksd *keyStoreDecoder) readCertificate(version uint32) (*Certificate, error) { var certType string + switch version { case version01: certType = defaultCertificateType @@ -91,22 +108,27 @@ func (ksd *keyStoreDecoder) readCertificate(version uint32) (*Certificate, error if err != nil { return nil, fmt.Errorf("read type: %w", err) } + certType = readCertType default: return nil, errors.New("got unknown version") } + certLen, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read length: %w", err) } + certContent, err := ksd.readBytes(certLen) if err != nil { return nil, fmt.Errorf("read content: %w", err) } + certificate := Certificate{ Type: certType, Content: certContent, } + return &certificate, nil } @@ -115,30 +137,38 @@ func (ksd *keyStoreDecoder) readPrivateKeyEntry(version uint32, password []byte) if err != nil { return nil, fmt.Errorf("read creation timestamp: %w", err) } + length, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read length: %w", err) } + encryptedPrivateKey, err := ksd.readBytes(length) if err != nil { return nil, fmt.Errorf("read encrypted private key: %w", err) } + certNum, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read number of certificates: %w", err) } + chain := make([]Certificate, 0, certNum) + for i := uint32(0); i < certNum; i++ { cert, err := ksd.readCertificate(version) if err != nil { return nil, fmt.Errorf("read %d certificate: %w", i, err) } + chain = append(chain, *cert) } + decryptedPrivateKey, err := decrypt(encryptedPrivateKey, password) if err != nil { return nil, fmt.Errorf("decrypt content: %w", err) } + creationDateTime := millisecondsToTime(int64(creationTimeStamp)) privateKeyEntry := PrivateKeyEntry{ Entry: Entry{ @@ -147,6 +177,7 @@ func (ksd *keyStoreDecoder) readPrivateKeyEntry(version uint32, password []byte) PrivateKey: decryptedPrivateKey, CertificateChain: chain, } + return &privateKeyEntry, nil } @@ -155,10 +186,12 @@ func (ksd *keyStoreDecoder) readTrustedCertificateEntry(version uint32) (*Truste if err != nil { return nil, fmt.Errorf("read creation timestamp: %w", err) } + certificate, err := ksd.readCertificate(version) if err != nil { return nil, fmt.Errorf("read certificate: %w", err) } + creationDateTime := millisecondsToTime(int64(creationTimeStamp)) trustedCertificateEntry := TrustedCertificateEntry{ Entry: Entry{ @@ -166,6 +199,7 @@ func (ksd *keyStoreDecoder) readTrustedCertificateEntry(version uint32) (*Truste }, Certificate: *certificate, } + return &trustedCertificateEntry, nil } @@ -174,22 +208,26 @@ func (ksd *keyStoreDecoder) readEntry(version uint32, password []byte) (string, if err != nil { return "", nil, fmt.Errorf("read tag: %w", err) } + alias, err := ksd.readString() if err != nil { return "", nil, fmt.Errorf("read alias: %w", err) } + switch tag { case privateKeyTag: entry, err := ksd.readPrivateKeyEntry(version, password) if err != nil { return "", nil, fmt.Errorf("read private key entry: %w", err) } + return alias, entry, nil case trustedCertificateTag: entry, err := ksd.readTrustedCertificateEntry(version) if err != nil { return "", nil, fmt.Errorf("read trusted certificate entry: %w", err) } + return alias, entry, nil default: return "", nil, errors.New("got unknown entry tag") @@ -197,51 +235,63 @@ func (ksd *keyStoreDecoder) readEntry(version uint32, password []byte) (string, } // Decode reads keystore representation from r then decrypts and check signature using password -// It is strongly recommended to fill password slice with zero after usage +// It is strongly recommended to fill password slice with zero after usage. func Decode(r io.Reader, password []byte) (KeyStore, error) { ksd := keyStoreDecoder{ r: r, md: sha1.New(), } + passwordBytes := passwordBytes(password) defer zeroing(passwordBytes) + if _, err := ksd.md.Write(passwordBytes); err != nil { return nil, fmt.Errorf("update digest with password: %w", err) } + if _, err := ksd.md.Write(whitenerMessage); err != nil { return nil, fmt.Errorf("update digest with whitener message: %w", err) } + readMagic, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read magic: %w", err) } + if readMagic != magic { return nil, errors.New("got invalid magic") } + version, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read version: %w", err) } + entryNum, err := ksd.readUint32() if err != nil { return nil, fmt.Errorf("read number of entries: %w", err) } + keyStore := make(KeyStore, entryNum) + for i := uint32(0); i < entryNum; i++ { alias, entry, err := ksd.readEntry(version, password) if err != nil { return nil, fmt.Errorf("read %d entry: %w", i, err) } + keyStore[alias] = entry } - computedDigest := ksd.md.Sum(nil) actualDigest, err := ksd.readBytes(uint32(ksd.md.Size())) if err != nil { return nil, fmt.Errorf("read digest: %w", err) } + + computedDigest := ksd.md.Sum(nil) if !bytes.Equal(actualDigest, computedDigest) { return nil, errors.New("got invalid digest") } + return keyStore, nil } diff --git a/decoder_test.go b/decoder_test.go index 6b89f9c..e2b58fe 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -19,6 +19,7 @@ func TestReadUint16(t *testing.T) { err error hash [sha1.Size]byte } + var readUint32Table = func() []readUint16Item { var table []readUint16Item table = append(table, readUint16Item{ @@ -65,13 +66,16 @@ func TestReadUint16(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + number, err := ksd.readUint16() if !reflect.DeepEqual(err, tt.err) { t.Errorf("invalid error '%v' '%v'", err, tt.err) } + if number != tt.number { t.Errorf("invalid number '%v' '%v'", number, tt.number) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) @@ -86,6 +90,7 @@ func TestReadUint32(t *testing.T) { err error hash [sha1.Size]byte } + var readUint32Table = func() []readUint32Item { var table []readUint32Item table = append(table, readUint32Item{ @@ -132,13 +137,16 @@ func TestReadUint32(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + number, err := ksd.readUint32() if !reflect.DeepEqual(err, tt.err) { t.Errorf("invalid error '%v' '%v'", err, tt.err) } + if number != tt.number { t.Errorf("invalid uint32 '%v' '%v'", number, tt.number) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) @@ -153,7 +161,8 @@ func TestReadUint64(t *testing.T) { err error hash [sha1.Size]byte } - var readUint64Table = func() []readUint64Item { + + readUint64Table := func() []readUint64Item { var table []readUint64Item table = append(table, readUint64Item{ input: nil, @@ -174,8 +183,11 @@ func TestReadUint64(t *testing.T) { hash: sha1.Sum(nil), }) buf := make([]byte, 8) + var number uint64 = 10 + binary.BigEndian.PutUint64(buf, number) + table = append(table, readUint64Item{ input: buf, number: number, @@ -185,12 +197,14 @@ func TestReadUint64(t *testing.T) { buf = make([]byte, 8) number = 0 binary.BigEndian.PutUint64(buf, number) + table = append(table, readUint64Item{ input: buf, number: number, err: nil, hash: sha1.Sum(buf), }) + return table }() @@ -199,13 +213,16 @@ func TestReadUint64(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + number, err := ksd.readUint64() if !reflect.DeepEqual(err, tt.err) { t.Errorf("invalid error '%v' '%v'", err, tt.err) } + if number != tt.number { t.Errorf("invalid uint64 '%v' '%v'", number, tt.number) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) @@ -220,7 +237,8 @@ func TestReadBytes(t *testing.T) { bytes []byte hash [sha1.Size]byte } - var readUint32Table = func() []readBytesItem { + + readUint32Table := func() []readBytesItem { var table []readBytesItem table = append(table, readBytesItem{ input: nil, @@ -245,14 +263,17 @@ func TestReadBytes(t *testing.T) { if _, err := rand.Read(buf); err != nil { t.Errorf("read random bytes: %v", err) } + return buf }() + table = append(table, readBytesItem{ input: buf, readLen: 9 * 1024, bytes: buf[:9*1024], hash: sha1.Sum(buf[:9*1024]), }) + return table }() @@ -261,13 +282,16 @@ func TestReadBytes(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + bts, err := ksd.readBytes(tt.readLen) if err != nil { t.Errorf("got error '%v'", err) } + if !reflect.DeepEqual(bts, tt.bytes) { t.Errorf("invalid bytes '%v' '%v'", bts, tt.bytes) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) @@ -282,7 +306,8 @@ func TestReadString(t *testing.T) { err error hash [sha1.Size]byte } - var readUint32Table = func() []readStringItem { + + readUint32Table := func() []readStringItem { var table []readStringItem table = append(table, readStringItem{ input: nil, @@ -318,6 +343,7 @@ func TestReadString(t *testing.T) { err: nil, hash: sha1.Sum(buf), }) + return table }() @@ -326,13 +352,16 @@ func TestReadString(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + str, err := ksd.readString() if !reflect.DeepEqual(err, tt.err) { t.Errorf("invalid error '%v' '%v'", err, tt.err) } + if str != tt.string { t.Errorf("invalid string '%v' '%v'", str, tt.string) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) @@ -430,13 +459,16 @@ func TestReadCertificate(t *testing.T) { r: bytes.NewReader(tt.input), md: sha1.New(), } + cert, err := ksd.readCertificate(tt.version) if !reflect.DeepEqual(err, tt.err) { t.Errorf("invalid error '%v' '%v'", err, tt.err) } + if cert != nil && tt.cert != nil && !reflect.DeepEqual(cert, tt.cert) { t.Errorf("invalid certificate '%v' '%v'", cert, tt.cert) } + hash := ksd.md.Sum(nil) if !reflect.DeepEqual(hash, tt.hash[:]) { t.Errorf("invalid hash '%v' '%v'", hash, tt.hash) diff --git a/encoder.go b/encoder.go index b80129c..beae5db 100644 --- a/encoder.go +++ b/encoder.go @@ -19,37 +19,49 @@ type keyStoreEncoder struct { func (kse *keyStoreEncoder) writeUint16(value uint16) error { const blockSize = 2 + byteOrder.PutUint16(kse.b[:blockSize], value) + if _, err := kse.w.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("write uint16: %w", err) } + if _, err := kse.md.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("update digest: %w", err) } + return nil } func (kse *keyStoreEncoder) writeUint32(value uint32) error { const blockSize = 4 + byteOrder.PutUint32(kse.b[:blockSize], value) + if _, err := kse.w.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("write uint32: %w", err) } + if _, err := kse.md.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("update digest: %w", err) } + return nil } func (kse *keyStoreEncoder) writeUint64(value uint64) error { const blockSize = 8 + byteOrder.PutUint64(kse.b[:blockSize], value) + if _, err := kse.w.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("write uint64: %w", err) } + if _, err := kse.md.Write(kse.b[:blockSize]); err != nil { return fmt.Errorf("update digest: %w", err) } + return nil } @@ -57,9 +69,11 @@ func (kse *keyStoreEncoder) writeBytes(value []byte) error { if _, err := kse.w.Write(value); err != nil { return fmt.Errorf("write %d bytes: %w", len(value), err) } + if _, err := kse.md.Write(value); err != nil { return fmt.Errorf("update digest: %w", err) } + return nil } @@ -68,12 +82,15 @@ func (kse *keyStoreEncoder) writeString(value string) error { if strLen > math.MaxUint16 { return fmt.Errorf("got string %d bytes long, max length is %d", strLen, math.MaxUint16) } + if err := kse.writeUint16(uint16(strLen)); err != nil { return fmt.Errorf("write length: %w", err) } + if err := kse.writeBytes([]byte(value)); err != nil { return fmt.Errorf("write body: %w", err) } + return nil } @@ -81,16 +98,20 @@ func (kse *keyStoreEncoder) writeCertificate(cert Certificate) error { if err := kse.writeString(cert.Type); err != nil { return fmt.Errorf("write type: %w", err) } + certLen := uint64(len(cert.Content)) if certLen > math.MaxUint32 { return fmt.Errorf("got certificate %d bytes long, max length is %d", certLen, math.MaxUint32) } + if err := kse.writeUint32(uint32(certLen)); err != nil { return fmt.Errorf("write length: %w", err) } + if err := kse.writeBytes(cert.Content); err != nil { return fmt.Errorf("write content: %w", err) } + return nil } @@ -98,38 +119,48 @@ func (kse *keyStoreEncoder) writePrivateKeyEntry(alias string, pke *PrivateKeyEn if err := kse.writeUint32(privateKeyTag); err != nil { return fmt.Errorf("write tag: %w", err) } + if err := kse.writeString(alias); err != nil { return fmt.Errorf("write alias: %w", err) } + if err := kse.writeUint64(uint64(timeToMilliseconds(pke.CreationTime))); err != nil { return fmt.Errorf("write creation timestamp: %w", err) } + encryptedContent, err := encrypt(kse.rand, pke.PrivateKey, password) if err != nil { return fmt.Errorf("encrypt content: %w", err) } + length := uint64(len(encryptedContent)) if length > math.MaxUint32 { return fmt.Errorf("got encrypted content %d bytes long, max length is %d", length, math.MaxUint32) } + if err := kse.writeUint32(uint32(length)); err != nil { return fmt.Errorf("filed to write length: %w", err) } + if err := kse.writeBytes(encryptedContent); err != nil { return fmt.Errorf("write content: %w", err) } + certNum := uint64(len(pke.CertificateChain)) if certNum > math.MaxUint32 { return fmt.Errorf("got certificate chain %d entries long, max number of entries is %d", certNum, math.MaxUint32) } + if err := kse.writeUint32(uint32(certNum)); err != nil { return fmt.Errorf("write number of certificates: %w", err) } + for i, cert := range pke.CertificateChain { if err := kse.writeCertificate(cert); err != nil { return fmt.Errorf("write %d certificate: %w", i, err) } } + return nil } @@ -137,41 +168,49 @@ func (kse *keyStoreEncoder) writeTrustedCertificateEntry(alias string, tce *Trus if err := kse.writeUint32(trustedCertificateTag); err != nil { return fmt.Errorf("write tag: %w", err) } + if err := kse.writeString(alias); err != nil { return fmt.Errorf("write alias: %w", err) } + if err := kse.writeUint64(uint64(timeToMilliseconds(tce.CreationTime))); err != nil { return fmt.Errorf("write creation timestamp: %w", err) } + if err := kse.writeCertificate(tce.Certificate); err != nil { return fmt.Errorf("write certificate: %w", err) } + return nil } // Encode encrypts and signs keystore using password and writes its representation into w -// It is strongly recommended to fill password slice with zero after usage +// It is strongly recommended to fill password slice with zero after usage. func Encode(w io.Writer, ks KeyStore, password []byte) error { return EncodeWithRand(rand.Reader, w, ks, password) } // Encode encrypts and signs keystore using password and writes its representation into w // Random bytes are read from rand, which must be a cryptographically secure source of randomness -// It is strongly recommended to fill password slice with zero after usage +// It is strongly recommended to fill password slice with zero after usage. func EncodeWithRand(rand io.Reader, w io.Writer, ks KeyStore, password []byte) error { kse := keyStoreEncoder{ w: w, md: sha1.New(), rand: rand, } + passwordBytes := passwordBytes(password) defer zeroing(passwordBytes) + if _, err := kse.md.Write(passwordBytes); err != nil { return fmt.Errorf("update digest with password: %w", err) } + if _, err := kse.md.Write(whitenerMessage); err != nil { return fmt.Errorf("update digest with whitener message: %w", err) } + if err := kse.writeUint32(magic); err != nil { return fmt.Errorf("write magic: %w", err) } @@ -179,9 +218,11 @@ func EncodeWithRand(rand io.Reader, w io.Writer, ks KeyStore, password []byte) e if err := kse.writeUint32(version02); err != nil { return fmt.Errorf("write version: %w", err) } + if err := kse.writeUint32(uint32(len(ks))); err != nil { return fmt.Errorf("write number of entries: %w", err) } + for alias, entry := range ks { switch typedEntry := entry.(type) { case *PrivateKeyEntry: @@ -196,8 +237,10 @@ func EncodeWithRand(rand io.Reader, w io.Writer, ks KeyStore, password []byte) e return errors.New("got invalid entry") } } + if err := kse.writeBytes(kse.md.Sum(nil)); err != nil { return fmt.Errorf("write digest: %w", err) } + return nil } diff --git a/go.mod b/go.mod index 4ee214c..9e272a5 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/pavel-v-chernykh/keystore-go -go 1.13 - -require github.com/magefile/mage v1.9.0 +go 1.14 diff --git a/go.sum b/go.sum index 7eb0fc9..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +0,0 @@ -github.com/magefile/mage v1.9.0 h1:t3AU2wNwehMCW97vuqQLtw6puppWXHO+O2MHo5a50XE= -github.com/magefile/mage v1.9.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= diff --git a/keyprotector.go b/keyprotector.go index 454d63f..bca71d2 100644 --- a/keyprotector.go +++ b/keyprotector.go @@ -21,20 +21,25 @@ type keyInfo struct { func decrypt(data []byte, password []byte) ([]byte, error) { var keyInfo keyInfo + asn1Rest, err := asn1.Unmarshal(data, &keyInfo) if err != nil { return nil, fmt.Errorf("unmarshal encrypted key: %w", err) } + if len(asn1Rest) > 0 { return nil, errors.New("got extra data in encrypted key") } + if !keyInfo.Algo.Algorithm.Equal(supportedPrivateKeyAlgorithmOid) { return nil, errors.New("got unsupported private key encryption algorithm") } md := sha1.New() + passwordBytes := passwordBytes(password) defer zeroing(passwordBytes) + salt := make([]byte, saltLen) copy(salt, keyInfo.PrivateKey) encryptedKeyLen := len(keyInfo.PrivateKey) - saltLen - md.Size() @@ -50,13 +55,16 @@ func decrypt(data []byte, password []byte) ([]byte, error) { xorKey := make([]byte, encryptedKeyLen) digest := salt + for i, xorOffset := 0, 0; i < numRounds; i++ { if _, err := md.Write(passwordBytes); err != nil { return nil, fmt.Errorf("update digest with password on %d round: %w", i, err) } + if _, err := md.Write(digest); err != nil { return nil, fmt.Errorf("update digest with digest from previous round on %d round: %w", i, err) } + digest = md.Sum(nil) md.Reset() copy(xorKey[xorOffset:], digest) @@ -71,9 +79,11 @@ func decrypt(data []byte, password []byte) ([]byte, error) { if _, err := md.Write(passwordBytes); err != nil { return nil, fmt.Errorf("update digest with password: %w", err) } + if _, err := md.Write(plainKey); err != nil { return nil, fmt.Errorf("update digest with plain key: %w", err) } + digest = md.Sum(nil) md.Reset() @@ -81,13 +91,16 @@ func decrypt(data []byte, password []byte) ([]byte, error) { if !bytes.Equal(digest, keyInfo.PrivateKey[digestOffset:digestOffset+len(digest)]) { return nil, errors.New("got invalid digest") } + return plainKey, nil } func encrypt(rand io.Reader, plainKey []byte, password []byte) ([]byte, error) { md := sha1.New() + passwordBytes := passwordBytes(password) defer zeroing(passwordBytes) + plainKeyLen := len(plainKey) numRounds := plainKeyLen / md.Size() @@ -103,13 +116,16 @@ func encrypt(rand io.Reader, plainKey []byte, password []byte) ([]byte, error) { xorKey := make([]byte, plainKeyLen) digest := salt + for i, xorOffset := 0, 0; i < numRounds; i++ { if _, err := md.Write(passwordBytes); err != nil { return nil, fmt.Errorf("update digest with password on %d round: %w", i, err) } + if _, err := md.Write(digest); err != nil { return nil, fmt.Errorf("update digest with digest from prevous round on %d round: %w", i, err) } + digest = md.Sum(nil) md.Reset() copy(xorKey[xorOffset:], digest) @@ -131,12 +147,15 @@ func encrypt(rand io.Reader, plainKey []byte, password []byte) ([]byte, error) { if _, err := md.Write(passwordBytes); err != nil { return nil, fmt.Errorf("update digest with password: %w", err) } + if _, err := md.Write(plainKey); err != nil { return nil, fmt.Errorf("udpate digest with plain key: %w", err) } + digest = md.Sum(nil) md.Reset() copy(encryptedKey[encryptedKeyOffset:], digest) + keyInfo := keyInfo{ Algo: pkix.AlgorithmIdentifier{ Algorithm: supportedPrivateKeyAlgorithmOid, @@ -144,9 +163,11 @@ func encrypt(rand io.Reader, plainKey []byte, password []byte) ([]byte, error) { }, PrivateKey: encryptedKey, } + encodedKey, err := asn1.Marshal(keyInfo) if err != nil { return nil, fmt.Errorf("marshal encrypted key: %w", err) } + return encodedKey, nil } diff --git a/keystore.go b/keystore.go index 4e248f3..4d51b36 100644 --- a/keystore.go +++ b/keystore.go @@ -4,28 +4,28 @@ import ( "time" ) -// KeyStore is a mapping of alias to pointer to PrivateKeyEntry or TrustedCertificateEntry +// KeyStore is a mapping of alias to pointer to PrivateKeyEntry or TrustedCertificateEntry. type KeyStore map[string]interface{} -// Certificate describes type of certificate +// Certificate describes type of certificate. type Certificate struct { Type string Content []byte } -// Entry is a basis of entries types supported by keystore +// Entry is a basis of entries types supported by keystore. type Entry struct { CreationTime time.Time } -// PrivateKeyEntry is an entry for private keys and associated certificates +// PrivateKeyEntry is an entry for private keys and associated certificates. type PrivateKeyEntry struct { Entry PrivateKey []byte CertificateChain []Certificate } -// TrustedCertificateEntry is an entry for certificates only +// TrustedCertificateEntry is an entry for certificates only. type TrustedCertificateEntry struct { Entry Certificate Certificate diff --git a/magefile.go b/magefile.go deleted file mode 100644 index b39cb50..0000000 --- a/magefile.go +++ /dev/null @@ -1,42 +0,0 @@ -// +build mage - -package main - -import ( - "fmt" - - "github.com/magefile/mage/sh" -) - -var Default = All - -func Fmt() error { - if err := sh.Run("go", "fmt", "github.com/pavel-v-chernykh/keystore-go/..."); err != nil { - return fmt.Errorf("go fmt: %w", err) - } - return nil -} - -func Test() error { - if err := sh.Run("go", "test", "-cover", "-count=1", "-v", "./..."); err != nil { - return fmt.Errorf("go test: %w", err) - } - return nil -} - -func Lint() error { - if err := sh.Run("golangci-lint", "run"); err != nil { - return fmt.Errorf("golangci-lint run: %w", err) - } - return nil -} - -func All() error { - if err := Fmt(); err != nil { - return err - } - if err := Test(); err != nil { - return err - } - return Lint() -}