diff --git a/crypto/encrypt_decrypt_test.go b/crypto/encrypt_decrypt_test.go index 5eae4c6c..27b52dc7 100644 --- a/crypto/encrypt_decrypt_test.go +++ b/crypto/encrypt_decrypt_test.go @@ -924,6 +924,32 @@ func TestEncryptDecryptKey(t *testing.T) { } } +func TestEncryptDecryptStreamTrimmedLines(t *testing.T) { + for _, material := range testMaterialForProfiles { + t.Run(material.profileName, func(t *testing.T) { + encHandle, _ := material.pgp.Encryption(). + Recipients(material.keyRingTestPublic). + SigningKeys(material.keyRingTestPrivate). + TrimLines(). + New() + decHandle, _ := material.pgp.Decryption(). + DecryptionKeys(material.keyRingTestPrivate). + VerificationKeys(material.keyRingTestPublic). + New() + testEncryptDecryptStreamWithExpected( + t, + []byte("text with \r \t \n trimmed\n \t"), + []byte("text with\n trimmed\n"), + nil, + encHandle, + decHandle, + len(material.keyRingTestPrivate.entities), + Bytes, + ) + }) + } +} + func TestEncryptCompressionApplied(t *testing.T) { const numReplicas = 10 builder := strings.Builder{} @@ -1169,6 +1195,28 @@ func testEncryptDecryptStream( decHandle PGPDecryption, numberOfSigsToVerify int, encoding int8, +) { + testEncryptDecryptStreamWithExpected( + t, + messageBytes, + messageBytes, + metadata, + encHandle, + decHandle, + numberOfSigsToVerify, + encoding, + ) +} + +func testEncryptDecryptStreamWithExpected( + t *testing.T, + messageBytes []byte, + expected []byte, + metadata *LiteralMetadata, + encHandle PGPEncryption, + decHandle PGPDecryption, + numberOfSigsToVerify int, + encoding int8, ) { messageReader := bytes.NewReader(messageBytes) var ciphertextBuf bytes.Buffer @@ -1200,7 +1248,7 @@ func testEncryptDecryptStream( if err != nil { t.Fatal("Expected no error while reading the decrypted data, got:", err) } - if !bytes.Equal(decryptedBytes, messageBytes) { + if !bytes.Equal(decryptedBytes, expected) { t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) } if numberOfSigsToVerify > 0 { diff --git a/crypto/encryption_handle.go b/crypto/encryption_handle.go index b58851e8..b698713b 100644 --- a/crypto/encryption_handle.go +++ b/crypto/encryption_handle.go @@ -55,6 +55,9 @@ type encryptionHandle struct { // Is only considered if DetachedSignature is not set. PlainDetachedSignature bool IsUTF8 bool + // TrimLines trims each end of the line in the input message before encryption. + // Remove trailing spaces, carriage returns and tabs from each line (separated by \n characters). + TrimLines bool // ExternalSignature allows to include an external signature into // the encrypted message. ExternalSignature []byte @@ -332,6 +335,9 @@ func (eh *encryptionHandle) encryptingWriters(keys, data, detachedSignature Writ openpgp.NewCanonicalTextWriteCloser(messageWriter), ) } + if eh.TrimLines { + messageWriter = internal.NewTrimWriteCloser(messageWriter) + } return messageWriter, nil } diff --git a/crypto/encryption_handle_builder.go b/crypto/encryption_handle_builder.go index 9f8b18d9..d6e94bdf 100644 --- a/crypto/encryption_handle_builder.go +++ b/crypto/encryption_handle_builder.go @@ -148,6 +148,13 @@ func (ehb *EncryptionHandleBuilder) Utf8() *EncryptionHandleBuilder { return ehb } +// TrimLines enables that each line in the input message is trimmed before encryption. +// Trim removes trailing spaces, carriage returns and tabs from each line (separated by \n characters). +func (ehb *EncryptionHandleBuilder) TrimLines() *EncryptionHandleBuilder { + ehb.handle.TrimLines = true + return ehb +} + // DetachedSignature indicates that the message should be signed, // but the signature should not be included in the same pgp message as the input data. // Instead the detached signature is encrypted in a separate pgp message. diff --git a/crypto/sign_handle.go b/crypto/sign_handle.go index 70b8468f..2e0cf54b 100644 --- a/crypto/sign_handle.go +++ b/crypto/sign_handle.go @@ -16,10 +16,13 @@ import ( ) type signatureHandle struct { - SignKeyRing *KeyRing - SignContext *SigningContext - IsUTF8 bool - Detached bool + SignKeyRing *KeyRing + SignContext *SigningContext + IsUTF8 bool + Detached bool + // TrimLines trims each end of the line in the input message before encryption. + // Remove trailing spaces, carriage returns and tabs from each line (separated by \n characters). + TrimLines bool ArmorHeaders map[string]string profile SignProfile clock Clock @@ -87,6 +90,9 @@ func (sh *signatureHandle) SigningWriter(outputWriter Writer, encoding int8) (me openpgp.NewCanonicalTextWriteCloser(messageWriter), ) } + if sh.TrimLines { + messageWriter = internal.NewTrimWriteCloser(messageWriter) + } return messageWriter, nil } diff --git a/crypto/sign_handle_builder.go b/crypto/sign_handle_builder.go index f4460756..1d653504 100644 --- a/crypto/sign_handle_builder.go +++ b/crypto/sign_handle_builder.go @@ -68,6 +68,13 @@ func (shb *SignHandleBuilder) Utf8() *SignHandleBuilder { return shb } +// TrimLines enables that each line in the input message is trimmed before encryption. +// Trim removes trailing spaces, carriage returns and tabs from each line (separated by \n characters). +func (shb *SignHandleBuilder) TrimLines() *SignHandleBuilder { + shb.handle.TrimLines = true + return shb +} + // SignTime sets the internal clock to always return // the supplied unix time for signing instead of the device time. func (shb *SignHandleBuilder) SignTime(unixTime int64) *SignHandleBuilder { diff --git a/internal/trim_lines_writer.go b/internal/trim_lines_writer.go new file mode 100644 index 00000000..d78c57f1 --- /dev/null +++ b/internal/trim_lines_writer.go @@ -0,0 +1,89 @@ +package internal + +import ( + "bytes" + "io" +) + +func trim(p []byte) []byte { + return bytes.TrimRight(p, " \t\r") +} + +func NewTrimWriteCloser(internal io.WriteCloser) *TrimWriteCloser { + return NewTrimWriteCloserWithBufferSize(internal, 256) +} + +func NewTrimWriteCloserWithBufferSize(internal io.WriteCloser, size int) *TrimWriteCloser { + return &TrimWriteCloser{ + internal: internal, + whitespace: bytes.NewBuffer(make([]byte, 0, size)), + err: nil, + } +} + +type TrimWriteCloser struct { + internal io.WriteCloser + whitespace *bytes.Buffer + err error +} + +func (w *TrimWriteCloser) Write(p []byte) (n int, err error) { + n = len(p) + if w.err != nil { + return 0, err + } + for index := bytes.IndexByte(p, '\n'); index != -1; index = bytes.IndexByte(p, '\n') { + trimmedSuffixLine := trim(p[:index]) + bufferWhitespace := w.whitespace.Bytes() + if len(bufferWhitespace) > 0 { + if len(trimmedSuffixLine) != 0 { + if _, err = w.internal.Write(bufferWhitespace); err != nil { + w.err = err + return 0, err + } + } + w.whitespace.Reset() + } + if len(trimmedSuffixLine) < len(p[:index]) { + if _, err = w.internal.Write(trimmedSuffixLine); err != nil { + w.err = err + return index, err + } + if _, err = w.internal.Write([]byte("\n")); err != nil { + w.err = err + return index + 1, err + } + } else { + if _, err = w.internal.Write(p[:index+1]); err != nil { + w.err = err + return index + 1, err + } + } + p = p[index+1:] + } + + if len(p) > 0 { + nonWhitespace := trim(p) + if len(nonWhitespace) > 0 && w.whitespace.Len() > 0 { + if _, err = w.internal.Write(w.whitespace.Bytes()); err != nil { + w.err = err + return n - len(p), err + } + w.whitespace.Reset() + } + if _, err = w.internal.Write(nonWhitespace); err != nil { + w.err = err + return n - len(p), err + } + + if _, err = w.whitespace.Write(p[len(nonWhitespace):]); err != nil { + w.err = err + return n - len(p), err + } + } + return n, nil +} + +func (w *TrimWriteCloser) Close() error { + return w.internal.Close() +} diff --git a/internal/trim_lines_writer_test.go b/internal/trim_lines_writer_test.go new file mode 100644 index 00000000..0ac60e10 --- /dev/null +++ b/internal/trim_lines_writer_test.go @@ -0,0 +1,41 @@ +package internal + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testTrimWriteCloser(t *testing.T, test string) { + testBytes := []byte(test) + for _, batchSize := range []int{1, 2, 3, 7, 11} { + var buff bytes.Buffer + w := &noOpWriteCloser{ + writer: &buff, + } + trimWriter := NewTrimWriteCloser(w) + for ind := 0; ind < len(testBytes); ind += batchSize { + end := ind + batchSize + if end > len(testBytes) { + end = len(testBytes) + } + if _, err := trimWriter.Write(testBytes[ind:end]); err != nil { + t.Fatal(err) + } + } + if err := trimWriter.Close(); err != nil { + t.Fatal(err) + } + assert.Equal(t, TrimEachLine(test), buff.String()) + } +} + +func TestTrimWriteCloser(t *testing.T) { + testTrimWriteCloser(t, "\n \t \r") + testTrimWriteCloser(t, "this is a test \n \t \n\n") + testTrimWriteCloser(t, "saf\n \t sdf\n \r\rsd \t fsdf\n") + testTrimWriteCloser(t, "BEGIN:VCARD\r\nVERSION:4.0\r\nFN;PREF=1: \r\nEND:VCARD") + testTrimWriteCloser(t, strings.Repeat("\r \nthis is a test \n \t \n\n)", 10000)) +} diff --git a/internal/utf8.go b/internal/utf8.go index 1802e359..e0f3c29b 100644 --- a/internal/utf8.go +++ b/internal/utf8.go @@ -164,19 +164,15 @@ func NewUtf8CheckWriteCloser(wrap io.WriteCloser) *Utf8CheckWriteCloser { } func (cw *Utf8CheckWriteCloser) Write(p []byte) (n int, err error) { - err = cw.check(p) - if err != nil { - return + if err = cw.check(p); err != nil { + return 0, err } - n, err = cw.internal.Write(p) - return + return cw.internal.Write(p) } func (cw *Utf8CheckWriteCloser) Close() (err error) { - err = cw.close() - if err != nil { - return + if err = cw.close(); err != nil { + return err } - err = cw.internal.Close() - return + return cw.internal.Close() }