diff --git a/examples/cmd/encrypt.go b/examples/cmd/encrypt.go index e81277188..f57841acb 100644 --- a/examples/cmd/encrypt.go +++ b/examples/cmd/encrypt.go @@ -17,6 +17,7 @@ var ( nanoFormat bool autoconfigure bool noKIDInKAO bool + noKIDInNano bool outputName string dataAttributes []string ) @@ -32,6 +33,7 @@ func init() { encryptCmd.Flags().BoolVar(&nanoFormat, "nano", false, "Output in nanoTDF format") encryptCmd.Flags().BoolVar(&autoconfigure, "autoconfigure", true, "Use attribute grants to select kases") encryptCmd.Flags().BoolVar(&noKIDInKAO, "no-kid-in-kao", false, "[deprecated] Disable storing key identifiers in TDF KAOs") + encryptCmd.Flags().BoolVar(&noKIDInNano, "no-kid-in-nano", true, "Disable storing key identifiers in nanoTDF KAS ResourceLocator") encryptCmd.Flags().StringVarP(&outputName, "output", "o", "sensitive.txt.tdf", "name or path of output file; - for stdout") ExamplesCmd.AddCommand(&encryptCmd) @@ -54,6 +56,10 @@ func encrypt(cmd *cobra.Command, args []string) error { if noKIDInKAO { opts = append(opts, sdk.WithNoKIDInKAO()) } + // double negative always gets me + if !noKIDInNano { + opts = append(opts, sdk.WithNoKIDInNano()) + } // Create new offline client client, err := newSDK() diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index bbde871a7..e87f00296 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -64,6 +64,10 @@ type NanoTDFHeader struct { ecdsaPolicyBindingS []byte } +func (header *NanoTDFHeader) GetKasURL() ResourceLocator { + return header.kasURL +} + // GetCipher -- get the cipher from the nano tdf header func (header *NanoTDFHeader) GetCipher() CipherMode { return header.sigCfg.cipher @@ -659,6 +663,12 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) // CreateNanoTDF - reads plain text from the given reader and saves it to the writer, subject to the given options func (s SDK) CreateNanoTDF(writer io.Writer, reader io.Reader, config NanoTDFConfig) (uint32, error) { + if writer == nil { + return 0, fmt.Errorf("writer is nil") + } + if reader == nil { + return 0, fmt.Errorf("reader is nil") + } var totalSize uint32 buf := bytes.Buffer{} size, err := buf.ReadFrom(reader) @@ -670,18 +680,30 @@ func (s SDK) CreateNanoTDF(writer io.Writer, reader io.Reader, config NanoTDFCon return 0, errors.New("exceeds max size for nano tdf") } - kasURL, err := config.kasURL.getURL() + kasURL, err := config.kasURL.GetURL() if err != nil { return 0, fmt.Errorf("config.kasURL failed:%w", err) } - - kasPublicKey, err := getECPublicKey(kasURL, s.dialOptions...) + if kasURL == "https://" || kasURL == "http://" { + return 0, errors.New("config.kasUrl is empty") + } + kasPublicKey, kid, err := getECPublicKeyKid(kasURL, s.dialOptions...) if err != nil { return 0, fmt.Errorf("getECPublicKey failed:%w", err) } - slog.Debug("CreateNanoTDF", slog.String("header size", kasPublicKey)) + // kid from kasPublicKey endpoint + slog.Debug("kasPublicKey", slog.String("kid", kid)) + + // update KAS URL with kid if set + if kid != "" && !s.nanoFeatures.noKID { + err = config.kasURL.setURLWithIdentifier(kasURL, kid) + if err != nil { + return 0, fmt.Errorf("getECPublicKey setURLWithIdentifier failed:%w", err) + } + } + config.kasPublicKey, err = ocrypto.ECPubKeyFromPem([]byte(kasPublicKey)) if err != nil { return 0, fmt.Errorf("ocrypto.ECPubKeyFromPem failed: %w", err) @@ -771,7 +793,7 @@ func (s SDK) ReadNanoTDFContext(ctx context.Context, writer io.Writer, reader io return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) } - kasURL, err := header.kasURL.getURL() + kasURL, err := header.kasURL.GetURL() if err != nil { return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) } @@ -844,17 +866,17 @@ func (s SDK) ReadNanoTDFContext(ctx context.Context, writer io.Writer, reader io return uint32(writeLen), nil } -// getECPublicKey - Contact the specified KAS and get its public key -func getECPublicKey(kasURL string, opts ...grpc.DialOption) (string, error) { +// getECPublicKeyKid - Contact the specified KAS and get its public key +func getECPublicKeyKid(kasURL string, opts ...grpc.DialOption) (string, string, error) { req := kas.PublicKeyRequest{} req.Algorithm = "ec:secp256r1" grpcAddress, err := getGRPCAddress(kasURL) if err != nil { - return "", err + return "", "", err } conn, err := grpc.Dial(grpcAddress, opts...) if err != nil { - return "", fmt.Errorf("error connecting to grpc service at %s: %w", kasURL, err) + return "", "", fmt.Errorf("error connecting to grpc service at %s: %w", kasURL, err) } defer conn.Close() @@ -864,10 +886,10 @@ func getECPublicKey(kasURL string, opts ...grpc.DialOption) (string, error) { resp, err := serviceClient.PublicKey(ctx, &req) if err != nil { - return "", fmt.Errorf("error making request to KAS: %w", err) + return "", "", fmt.Errorf("error making request to KAS: %w", err) } - return resp.GetPublicKey(), nil + return resp.GetPublicKey(), resp.GetKid(), nil } type requestBody struct { diff --git a/sdk/nanotdf_config_test.go b/sdk/nanotdf_config_test.go index 01d076cc7..39e19f013 100644 --- a/sdk/nanotdf_config_test.go +++ b/sdk/nanotdf_config_test.go @@ -45,7 +45,7 @@ func TestNanoTDFConfig2(t *testing.T) { t.Fatal(err) } - readKasURL, err := conf.kasURL.getURL() + readKasURL, err := conf.kasURL.GetURL() if err != nil { t.Fatal(err) } diff --git a/sdk/nanotdf_policy_test.go b/sdk/nanotdf_policy_test.go index df7a0de34..dd204261c 100644 --- a/sdk/nanotdf_policy_test.go +++ b/sdk/nanotdf_policy_test.go @@ -36,7 +36,7 @@ func TestNanoTDFPolicy(t *testing.T) { t.Fatal(err) } - fullURL, err := pb2.rp.url.getURL() + fullURL, err := pb2.rp.url.GetURL() if err != nil { t.Fatal(err) } diff --git a/sdk/nanotdf_test.go b/sdk/nanotdf_test.go index ac060e657..844e480d4 100644 --- a/sdk/nanotdf_test.go +++ b/sdk/nanotdf_test.go @@ -3,6 +3,7 @@ package sdk import ( "bytes" "encoding/gob" + "errors" "io" "os" "testing" @@ -239,3 +240,89 @@ func NotTestCreateNanoTDF(t *testing.T) { t.Fatal(err) } } + +func TestGetECPublicKeyKid(t *testing.T) { + var tests = []struct { + name string + kasURL string + dialOption grpc.DialOption + shouldFail bool + }{ + { + name: "Valid URL, Unreachable gRPC server", + kasURL: "http://localhost", + dialOption: grpc.WithBlock(), + shouldFail: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, _, err := getECPublicKeyKid(test.kasURL, test.dialOption) + if (err != nil) != test.shouldFail { + t.Errorf("Error does not match the expected outcome. Error: %v", err) + } + }) + } +} + +func TestCreateNanoTDF(t *testing.T) { + tests := []struct { + name string + writer io.Writer + reader io.Reader + config NanoTDFConfig + expectedError error + }{ + { + name: "Nil writer", + writer: nil, + reader: bytes.NewReader([]byte("test data")), + config: NanoTDFConfig{}, + expectedError: errors.New("writer is nil"), + }, + { + name: "Nil reader", + writer: new(bytes.Buffer), + reader: nil, + config: NanoTDFConfig{}, + expectedError: errors.New("reader is nil"), + }, + { + name: "Empty NanoTDFConfig", + writer: new(bytes.Buffer), + reader: bytes.NewReader([]byte("test data")), + config: NanoTDFConfig{}, + expectedError: errors.New("config.kasUrl is empty"), + }, + { + name: "KAS Identifier NanoTDFConfig", + writer: new(bytes.Buffer), + reader: bytes.NewReader([]byte("test data")), + config: NanoTDFConfig{ + kasURL: ResourceLocator{ + protocol: 1, + body: "kas.com", + identifier: "e0", + }, + }, + expectedError: errors.New("getECPublicKey failed:error connecting to grpc service at https://kas.com: grpc: no transport security set (use grpc.WithTransportCredentials(insecure.NewCredentials()) explicitly or set credentials)"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s SDK + _, err := s.CreateNanoTDF(tt.writer, tt.reader, tt.config) + if err != nil { + if tt.expectedError == nil { + t.Errorf("unexpected error: %v", err) + } else if err.Error() != tt.expectedError.Error() { + t.Errorf("expected error: %v, got: %v", tt.expectedError, err) + } + } else if tt.expectedError != nil { + t.Errorf("expected error: %v, got nil", tt.expectedError) + } + }) + } +} diff --git a/sdk/options.go b/sdk/options.go index 40df5fb1f..cd1cba416 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -31,6 +31,7 @@ type config struct { dpopKey *ocrypto.RsaKeyPair ipc bool tdfFeatures tdfFeatures + nanoFeatures nanoFeatures customAccessTokenSource auth.AccessTokenSource oauthAccessTokenSource oauth2.TokenSource coreConn *grpc.ClientConn @@ -42,6 +43,12 @@ type tdfFeatures struct { noKID bool } +// Options specific to NanoTDF protocol features +type nanoFeatures struct { + // noKID For backward compatibility, don't store the KID in the KAS ResourceLocator. + noKID bool +} + type PlatformConfiguration map[string]interface{} func (c *config) build() []grpc.DialOption { @@ -200,3 +207,11 @@ func WithCustomCoreConnection(conn *grpc.ClientConn) Option { c.coreConn = conn } } + +// WithNoKIDInNano disables storing the KID in the KAS ResourceLocator. +// This allows generating NanoTDF files that are compatible with legacy file formats (no KID). +func WithNoKIDInNano() Option { + return func(c *config) { + c.nanoFeatures.noKID = true + } +} diff --git a/sdk/options_test.go b/sdk/options_test.go new file mode 100644 index 000000000..44b957e87 --- /dev/null +++ b/sdk/options_test.go @@ -0,0 +1,39 @@ +package sdk + +import ( + "testing" +) + +func TestWithKIDInNano(t *testing.T) { + tests := []struct { + name string + kid bool + want bool + }{ + { + name: "noKID to be true", + kid: false, + want: true, + }, + { + name: "noKID to be false", + kid: true, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &config{} + + if !tt.kid { + option := WithNoKIDInNano() + option(c) + } + + if c.nanoFeatures.noKID != tt.want { + t.Errorf("WithKIDInNano() = %v, want %v", c.nanoFeatures.noKID, tt.want) + } + }) + } +} diff --git a/sdk/resource_locator.go b/sdk/resource_locator.go index 2bf746dcd..f9b9dc707 100644 --- a/sdk/resource_locator.go +++ b/sdk/resource_locator.go @@ -19,20 +19,51 @@ import ( // ResourceLocator - structure to contain a protocol + body comprising an URL type ResourceLocator struct { - protocol urlProtocol // See urlProtocol values below - body string // Body of url + protocol protocolHeader // See protocolHeader values below + // body URL without protocol scheme + body string + // identifier unique to this URL + identifier string } -// urlProtocol - shorthand for protocol prefix on fully qualified url -type urlProtocol uint8 +// protocolHeader - shorthand for protocol prefix on fully qualified url +// also specifies the optional resource identifier - current usage is a key identifier +type protocolHeader uint8 + +func (h protocolHeader) identifierLength() int { + switch h & 0xF0 { //nolint:nolintlint,exhaustive // overloaded + case identifierNone, urlProtocolHTTPS: + return identifierNoneLength + case identifier2Byte: + return identifier2ByteLength + case identifier8Byte: + return identifier8ByteLength + case identifier32Byte: + return identifier32ByteLength + default: + return 0 + } +} const ( - kMaxBodyLen int = 255 - kPrefixHTTPS string = "https://" - kPrefixHTTP string = "http://" - urlProtocolHTTP urlProtocol = 0 - urlProtocolHTTPS urlProtocol = 1 - // urlProtocolShared urlProtocol = 255 // TODO - how is this handled/parsed/rendered? + kMaxBodyLen int = 255 + // kPrefixHTTPS identifier field is size of 0 bytes (not present) + kPrefixHTTPS string = "https://" + kPrefixHTTP string = "http://" + urlProtocolHTTP protocolHeader = 0x0 + urlProtocolHTTPS protocolHeader = 0x1 + // urlProtocolUnreserved protocolHeader = 0x2 + // urlProtocolSharedRes protocolHeader = 0xf + // identifier + identifierNone protocolHeader = 0 << 4 + identifier2Byte protocolHeader = 1 << 4 + identifier8Byte protocolHeader = 2 << 4 + identifier32Byte protocolHeader = 3 << 4 + // length + identifierNoneLength int = 0 + identifier2ByteLength int = 2 + identifier8ByteLength int = 8 + identifier32ByteLength int = 32 ) func NewResourceLocator(url string) (*ResourceLocator, error) { @@ -49,6 +80,7 @@ func NewResourceLocator(url string) (*ResourceLocator, error) { func NewResourceLocatorFromReader(reader io.Reader) (*ResourceLocator, error) { rl := &ResourceLocator{} err := rl.readResourceLocator(reader) + if err != nil { return nil, err } @@ -57,7 +89,79 @@ func NewResourceLocatorFromReader(reader io.Reader) (*ResourceLocator, error) { // getLength - return the serialized length (in bytes) of this object func (rl ResourceLocator) getLength() uint16 { - return uint16(1 /* protocol byte */ + 1 /* length byte */ + len(rl.body)) + return uint16(1 /* protocol byte */ + 1 /* length byte */ + len(rl.body) + len(rl.identifier)) +} + +// setURL - Store a fully qualified protocol+body string into a ResourceLocator as a protocol value and a body string +func (rl *ResourceLocator) setURLWithIdentifier(url string, identifier string) error { + if identifier == "" { + return errors.New("identifier is empty") + } + lowerURL := strings.ToLower(url) + if strings.HasPrefix(lowerURL, kPrefixHTTPS) { + urlBody := url[len(kPrefixHTTPS):] + if len(urlBody) > kMaxBodyLen { + return errors.New("URL too long") + } + identifierLen := len(identifier) + switch { + case identifierLen == 0: + rl.protocol = urlProtocolHTTPS | identifierNone + case identifierLen >= 1 && identifierLen <= 2: + rl.protocol = urlProtocolHTTPS | identifier2Byte + case identifierLen >= 3 && identifierLen <= 8: + rl.protocol = urlProtocolHTTPS | identifier8Byte + case identifierLen >= 9 && identifierLen <= 32: + rl.protocol = urlProtocolHTTPS | identifier32Byte + default: + return fmt.Errorf("unsupported identifier length: %d", identifierLen) + } + rl.body = urlBody + rl.identifier = identifier + return nil + } + if strings.HasPrefix(lowerURL, kPrefixHTTP) { + urlBody := url[len(kPrefixHTTP):] + if len(urlBody) > kMaxBodyLen { + return errors.New("URL too long") + } + identifierLen := len(identifier) + padding := "" + switch { + case identifierLen == 0: + rl.protocol = urlProtocolHTTP | identifierNone + case identifierLen >= 1 && identifierLen <= identifier2ByteLength: + padding = strings.Repeat("\x00", identifier2ByteLength-identifierLen) + rl.protocol = urlProtocolHTTP | identifier2Byte + case identifierLen >= 3 && identifierLen <= identifier8ByteLength: + padding = strings.Repeat("\x00", identifier8ByteLength-identifierLen) + rl.protocol = urlProtocolHTTP | identifier8Byte + case identifierLen >= 9 && identifierLen <= identifier32ByteLength: + padding = strings.Repeat("\x00", identifier32ByteLength-identifierLen) + rl.protocol = urlProtocolHTTP | identifier32Byte + default: + return fmt.Errorf("unsupported identifier length: %d", identifierLen) + } + rl.body = urlBody + rl.identifier = identifier + padding + return nil + } + return errors.New("unsupported protocol with identifier: " + url) +} + +// GetIdentifier - identifier is returned if the correct protocol enum is set else error +func (rl ResourceLocator) GetIdentifier() (string, error) { + // read the identifier if it exists + switch rl.protocol & 0xf0 { + case identifierNone, urlProtocolHTTPS: + return "", fmt.Errorf("legacy resource locator identifer: %x", rl.protocol) + case identifier2Byte, identifier8Byte, identifier32Byte: + if rl.identifier == "" { + return "", fmt.Errorf("no resource locator identifer: %d", rl.protocol) + } + return rl.identifier, nil + } + return "", fmt.Errorf("unsupported identifer protocol: %x", rl.protocol) } // setURL - Store a fully qualified protocol+body string into a ResourceLocator as a protocol value and a body string @@ -84,15 +188,16 @@ func (rl *ResourceLocator) setURL(url string) error { return errors.New("unsupported protocol: " + url) } -// getURL - Retrieve a fully qualified protocol+body URL string from a ResourceLocator struct -func (rl ResourceLocator) getURL() (string, error) { - if rl.protocol == urlProtocolHTTPS { +// GetURL - Retrieve a fully qualified protocol+body URL string from a ResourceLocator struct +func (rl ResourceLocator) GetURL() (string, error) { + switch rl.protocol & 0xF { // use bitwise AND to get first 4 bits + case urlProtocolHTTPS, identifier2Byte, identifier8Byte, identifier32Byte: return kPrefixHTTPS + rl.body, nil - } - if rl.protocol == urlProtocolHTTP { + case urlProtocolHTTP: return kPrefixHTTP + rl.body, nil + default: + return "", fmt.Errorf("unsupported protocol: %x", rl.protocol) } - return "", fmt.Errorf("unsupported protocol: %d", rl.protocol) } // writeResourceLocator - writes the content of the resource locator to the supplied writer @@ -108,16 +213,24 @@ func (rl ResourceLocator) writeResourceLocator(writer io.Writer) error { if _, err := writer.Write([]byte(rl.body)); err != nil { return err } + // identifier + if len(rl.identifier) > 0 { + if _, err := writer.Write([]byte(rl.identifier)); err != nil { + return err + } + } return nil } +const protocolSharedRes = 0x4 + // readResourceLocator - read the encoded protocol and body string into a ResourceLocator func (rl *ResourceLocator) readResourceLocator(reader io.Reader) error { if err := binary.Read(reader, binary.BigEndian, &rl.protocol); err != nil { return errors.Join(Error("Error reading ResourceLocator protocol value"), err) } - if (rl.protocol != urlProtocolHTTP) && (rl.protocol != urlProtocolHTTPS) { // TODO - support 'shared' protocol? + if (rl.protocol&0x0f != urlProtocolHTTP) && (rl.protocol&0x0f != urlProtocolHTTPS) { return errors.New("Unsupported protocol: " + strconv.Itoa(int(rl.protocol))) } var lengthBody byte @@ -129,5 +242,32 @@ func (rl *ResourceLocator) readResourceLocator(reader io.Reader) error { return errors.Join(Error("Error reading ResourceLocator body value"), err) } rl.body = string(body) // TODO - normalize to lowercase? + // read the identifier if it exists + switch rl.protocol & 0xf0 { + case identifierNone, urlProtocolHTTPS: + // noop and exhaustive for linter + case identifier2Byte: + identifier := make([]byte, 2) //nolint:mnd // 2 bytes + if err := binary.Read(reader, binary.BigEndian, &identifier); err != nil { + return errors.New("Error reading ResourceLocator identifier value: " + err.Error()) + } + rl.identifier = string(identifier) + case identifier8Byte: + identifier := make([]byte, 8) //nolint:mnd // 8 bytes + if err := binary.Read(reader, binary.BigEndian, &identifier); err != nil { + return errors.New("Error reading ResourceLocator identifier value: " + err.Error()) + } + rl.identifier = string(identifier) + case identifier32Byte: + identifier := make([]byte, 32) //nolint:mnd // 32 bytes + if err := binary.Read(reader, binary.BigEndian, &identifier); err != nil { + return errors.New("Error reading ResourceLocator identifier value: " + err.Error()) + } + rl.identifier = string(identifier) + case protocolSharedRes: + // noop for legacy relative file references + default: + return errors.New("unsupported identifier protocol: " + strconv.Itoa(int(rl.protocol))) + } return nil } diff --git a/sdk/resource_locator_test.go b/sdk/resource_locator_test.go index af0769d87..e7cc84f0b 100644 --- a/sdk/resource_locator_test.go +++ b/sdk/resource_locator_test.go @@ -1,6 +1,7 @@ package sdk import ( + "bytes" "testing" ) @@ -42,3 +43,172 @@ func TestResourceLocatorBad(t *testing.T) { t.Fatal("expecting error") } } + +func TestReadResourceLocator(t *testing.T) { + tests := []struct { + n string + protocol protocolHeader + body string + identifier string + expectError bool + }{ + {"http plain", urlProtocolHTTP, "test.com", "", false}, + {"https plain", urlProtocolHTTPS, "test.com", "", false}, + {"https id2", urlProtocolHTTPS, "test.com", "id", false}, + {"https id32", urlProtocolHTTPS, "test.com", "id1234567890123456789012345678901", false}, + {"invalid protocol", 123, "test.com", "X", true}, + {"unknown protocol id2", identifierNone, "test.com", "i0", false}, + {"unknown protocol id2", identifier2Byte, "test.com", "X", true}, + {"unknown protocol id8", identifier8Byte, "test.com", "X", true}, + {"unknown protocol id32", identifier32Byte, "test.com", "X", true}, + } + + for _, test := range tests { + t.Run(test.n, func(t *testing.T) { + rl := &ResourceLocator{ + protocol: test.protocol, + body: test.body, + identifier: test.identifier, + } + buff := bytes.Buffer{} + if err := rl.writeResourceLocator(&buff); err != nil { + t.Fatal(err) + } + err := rl.readResourceLocator(&buff) + if (err != nil) != test.expectError { + t.Fatalf("expected error: %v, got %v, error: %v", test.expectError, err != nil, err) + } + if err == nil && rl.body != test.body { + t.Fatalf("expected body: %s, got %s", test.body, rl.body) + } + if err == nil && rl.identifier != test.identifier { + t.Fatalf("expected identifier: %s, got %s", test.identifier, rl.identifier) + } + }) + } +} + +func TestGetIdentifier(t *testing.T) { + tests := []struct { + n string + protocol protocolHeader + identifier string + expected string + expectError bool + }{ + {"none", identifierNone, "testId", "", true}, + {"https lonely", urlProtocolHTTPS, "testId", "", true}, + {"no id 2b", identifier2Byte, "", "", true}, + {"no body 8b", identifier8Byte, "", "", true}, + {"no body 32b", identifier32Byte, "", "", true}, + {"ok 2b", identifier2Byte, "testId", "testId", false}, + {"ok 8b", identifier8Byte, "testId", "testId", false}, + {"ok 32b", identifier32Byte, "testId", "testId", false}, + } + + for _, test := range tests { + t.Run(test.n, func(t *testing.T) { + rl := &ResourceLocator{ + protocol: test.protocol, + identifier: test.identifier, + } + got, err := rl.GetIdentifier() + if (err != nil) != test.expectError { + t.Fatalf("expected error: %v, got %v, error: %v", test.expectError, err != nil, err) + } + if got != test.expected { + t.Fatalf("expected identifier: %s, got %s", test.expected, got) + } + }) + } +} + +func TestProtocolHeaderIdentifierLength(t *testing.T) { + tests := []struct { + n string + header protocolHeader + length int + }{ + {"none-https", urlProtocolHTTPS, identifierNoneLength}, + {"none-none", identifierNone, identifierNoneLength}, + {"2b", identifier2Byte, identifier2ByteLength}, + {"8b", identifier8Byte, identifier8ByteLength}, + {"32b", identifier32Byte, identifier32ByteLength}, + {"relative", protocolHeader(255), 0}, + } + + for _, test := range tests { + t.Run(test.n, func(t *testing.T) { + got := test.header.identifierLength() + if got != test.length { + t.Fatalf("expected length: %d, got %d", test.length, got) + } + }) + } +} + +func TestNewResourceLocatorWithIdentifierFromReader(t *testing.T) { + setupResourceLocator := func(url, identifier string) ([]byte, error) { + locator := ResourceLocator{} + if err := locator.setURLWithIdentifier(url, identifier); err != nil { + return nil, err + } + var buf bytes.Buffer + if err := locator.writeResourceLocator(&buf); err != nil { + return nil, err + } + return buf.Bytes(), nil + } + // 2 Bytes + t0Data, err := setupResourceLocator("https://example.com", "t0") + if err != nil { + t.Fatal(err) + } + // 8 Bytes + t1Data, err := setupResourceLocator("https://example.com", "t1t1t1t1") + if err != nil { + t.Fatal(err) + } + // 32 Bytes + t2Data, err := setupResourceLocator("https://example.com", "t2t2t2t2t2t2t2t2t2t2t2t2t2t2t2t2") + if err != nil { + t.Fatal(err) + } + // 0 Bytes no identifier + t3Data, err := setupResourceLocator("https://example.com", "") + if err == nil { + // must error + t.Fatal(err) + } + + tests := []struct { + n string + data []byte + expectBody string + expectIdent string + expectError bool + }{ + {"id2", t0Data, "example.com", "t0", false}, + {"id8", t1Data, "example.com", "t1t1t1t1", false}, + {"id32", t2Data, "example.com", "t2t2t2t2t2t2t2t2t2t2t2t2t2t2t2t2", false}, + {"id0", t3Data, "example.com", "", true}, + } + + for _, test := range tests { + t.Run(test.n, func(t *testing.T) { + rl, err := NewResourceLocatorFromReader(bytes.NewReader(test.data)) + if test.expectError { + if err == nil { + t.Fatalf("expected error, got %v", rl) + } + return + } + if rl.body != test.expectBody { + t.Fatalf("expected body: %s, got %s", test.expectBody, rl.body) + } + if rl.identifier != test.expectIdent { + t.Fatalf("expected identifier: %s, got %s", test.expectIdent, rl.identifier) + } + }) + } +} diff --git a/service/internal/security/errors.go b/service/internal/security/errors.go index baa06d4d4..73f429376 100644 --- a/service/internal/security/errors.go +++ b/service/internal/security/errors.go @@ -1,14 +1,17 @@ package security const ( - ErrCertNotFound = Error("not found") - ErrCertificateEncode = Error("certificate encode error") - ErrPublicKeyMarshal = Error("public key marshal error") - ErrHSMUnexpected = Error("hsm unexpected") - ErrHSMDecrypt = Error("hsm decrypt error") - ErrHSMNotFound = Error("hsm unavailable") - ErrKeyConfig = Error("key configuration error") - ErrUnknownHashFunction = Error("unknown hash function") + ErrCertNotFound = Error("not found") + ErrNoKeys = Error("keys not found") + ErrKeyPairInfoNotFound = Error("key pair info not found") + ErrKeyPairInfoMalformed = Error("key pair info malformed") + ErrCertificateEncode = Error("certificate encode error") + ErrPublicKeyMarshal = Error("public key marshal error") + ErrHSMUnexpected = Error("hsm unexpected") + ErrHSMDecrypt = Error("hsm decrypt error") + ErrHSMNotFound = Error("hsm unavailable") + ErrKeyConfig = Error("key configuration error") + ErrUnknownHashFunction = Error("unknown hash function") ) type Error string diff --git a/service/internal/security/standard_crypto.go b/service/internal/security/standard_crypto.go index 2570a98b1..a180e7751 100644 --- a/service/internal/security/standard_crypto.go +++ b/service/internal/security/standard_crypto.go @@ -359,15 +359,15 @@ func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPubl ecKeys, ok := s.keys[AlgorithmECP256R1] if !ok || len(ecKeys) == 0 { - return nil, ErrCertNotFound + return nil, ErrNoKeys } k, ok := ecKeys[kasKID] if !ok { - return nil, ErrCertNotFound + return nil, ErrKeyPairInfoNotFound } ec, ok := k.(StandardECCrypto) if !ok { - return nil, ErrCertNotFound + return nil, ErrKeyPairInfoMalformed } symmetricKey, err := ocrypto.ComputeECDHKey([]byte(ec.ecPrivateKeyPem), ephemeralECDSAPublicKeyPEM) diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index 9906953cc..f448e6f64 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -393,20 +393,25 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, body *RequestBody, entity *en } func (p *Provider) nanoTDFRewrap(ctx context.Context, body *RequestBody, entity *entityInfo) (*kaspb.RewrapResponse, error) { - // TODO Lookup KID from request content - // Should this be in the locator or somewhere else? - kid, err := p.lookupKid(ctx, security.AlgorithmECP256R1) - if err != nil { - p.Logger.WarnContext(ctx, "failure to find default kid for ec", "err", err) - return nil, err400("bad request") - } headerReader := bytes.NewReader(body.KeyAccess.Header) header, _, err := sdk.NewNanoTDFHeaderFromReader(headerReader) if err != nil { return nil, fmt.Errorf("failed to parse NanoTDF header: %w", err) } - + // Lookup KID from nano header + kid, err := header.GetKasURL().GetIdentifier() + if err != nil { + p.Logger.InfoContext(ctx, "nanoTDFRewrap GetIdentifier", "kid", kid, "err", err) + // legacy nano with KID + kid, err = p.lookupKid(ctx, security.AlgorithmECP256R1) + if err != nil { + p.Logger.ErrorContext(ctx, "failure to find default kid for ec", "err", err) + return nil, err400("bad request") + } + p.Logger.InfoContext(ctx, "nanoTDFRewrap lookupKid", "kid", kid) + } + p.Logger.InfoContext(ctx, "nanoTDFRewrap", "kid", kid) ecCurve, err := header.ECCurve() if err != nil { return nil, fmt.Errorf("ECCurve failed: %w", err) diff --git a/test/tdf-roundtrips.bats b/test/tdf-roundtrips.bats index b1605ea91..af28f751f 100755 --- a/test/tdf-roundtrips.bats +++ b/test/tdf-roundtrips.bats @@ -58,11 +58,14 @@ @test "examples: roundtrip nanoTDF" { echo "[INFO] creating nanotdf file" - go run ./examples encrypt -o sensitive.txt.ntdf --nano "Hello NanoTDF" + go run ./examples encrypt -o sensitive.txt.ntdf --nano --no-kid-in-nano "Hello NanoTDF" + go run ./examples encrypt -o sensitive-kid.txt.ntdf --nano "Hello NanoTDF KID" echo "[INFO] decrypting nanotdf..." go run ./examples decrypt sensitive.txt.ntdf go run ./examples decrypt sensitive.txt.ntdf | grep "Hello NanoTDF" + go run ./examples decrypt sensitive-kid.txt.ntdf + go run ./examples decrypt sensitive-kid.txt.ntdf | grep "Hello NanoTDF KID" } @test "examples: legacy key support Z-TDF" {