Skip to content

Commit

Permalink
Merge pull request #2487 from Indicio-tech/feat/sd-jwt-implementation
Browse files Browse the repository at this point in the history
Feat/sd jwt implementation
  • Loading branch information
dbluhm committed Sep 20, 2023
2 parents d705ca2 + d40b6e6 commit 2930ac2
Show file tree
Hide file tree
Showing 10 changed files with 1,617 additions and 161 deletions.
50 changes: 49 additions & 1 deletion aries_cloudagent/messaging/valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,26 @@ def __init__(self):
)


class NonSDList(Regexp):
"""Validate NonSD List."""

EXAMPLE = [
"name",
"address",
"address.street_address",
"nationalities[1:3]",
]
PATTERN = r"[a-z0-9:\[\]_\.@?\(\)]"

def __init__(self):
"""Initialize the instance."""

super().__init__(
NonSDList.PATTERN,
error="Value {input} is not a valid NonSDList",
)


class JSONWebToken(Regexp):
"""Validate JSON Web Token."""

Expand All @@ -208,7 +228,7 @@ class JSONWebToken(Regexp):
"eyJhIjogIjAifQ."
"dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
)
PATTERN = r"^[-_a-zA-Z0-9]*\.[-_a-zA-Z0-9]*\.[-_a-zA-Z0-9]*$"
PATTERN = r"^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$"

def __init__(self):
"""Initialize the instance."""
Expand All @@ -219,6 +239,28 @@ def __init__(self):
)


class SDJSONWebToken(Regexp):
"""Validate SD-JSON Web Token."""

EXAMPLE = (
"eyJhbGciOiJFZERTQSJ9."
"eyJhIjogIjAifQ."
"dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
"~WyJEM3BUSFdCYWNRcFdpREc2TWZKLUZnIiwgIkRFIl0"
"~WyJPMTFySVRjRTdHcXExYW9oRkd0aDh3IiwgIlNBIl0"
"~WyJkVmEzX1JlTGNsWTU0R1FHZm5oWlRnIiwgInVwZGF0ZWRfYXQiLCAxNTcwMDAwMDAwXQ"
)
PATTERN = r"^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+(?:~[a-zA-Z0-9._-]+)*~?$"

def __init__(self):
"""Initialize the instance."""

super().__init__(
SDJSONWebToken.PATTERN,
error="Value {input} is not a valid SD-JSON Web token",
)


class DIDKey(Regexp):
"""Validate value against DID key specification."""

Expand Down Expand Up @@ -800,9 +842,15 @@ def __init__(
JWS_HEADER_KID_VALIDATE = JWSHeaderKid()
JWS_HEADER_KID_EXAMPLE = JWSHeaderKid.EXAMPLE

NON_SD_LIST_VALIDATE = NonSDList()
NON_SD_LIST_EXAMPLE = NonSDList().EXAMPLE

JWT_VALIDATE = JSONWebToken()
JWT_EXAMPLE = JSONWebToken.EXAMPLE

SD_JWT_VALIDATE = SDJSONWebToken()
SD_JWT_EXAMPLE = SDJSONWebToken.EXAMPLE

DID_KEY_VALIDATE = DIDKey()
DID_KEY_EXAMPLE = DIDKey.EXAMPLE

Expand Down
53 changes: 44 additions & 9 deletions aries_cloudagent/wallet/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import json
import logging
from typing import Any, Mapping, NamedTuple, Optional
from typing import Any, Mapping, Optional

from marshmallow import fields
from pydid import DIDUrl, Resource, VerificationMethod

from ..core.profile import Profile
from ..messaging.jsonld.error import BadJWSHeaderError, InvalidVerificationMethod
from ..messaging.jsonld.routes import SUPPORTED_VERIFICATION_METHOD_TYPES
from ..messaging.models.base import BaseModel, BaseModelSchema
from ..resolver.did_resolver import DIDResolver
from .default_verification_key_strategy import BaseVerificationKeyStrategy
from .base import BaseWallet
Expand Down Expand Up @@ -67,10 +69,11 @@ async def jwt_sign(
if not did:
raise ValueError("DID URL must be absolute")

if not headers.get("typ", None):
headers["typ"] = "JWT"
headers = {
**headers,
"alg": "EdDSA",
"typ": "JWT",
"kid": verification_method,
}
encoded_headers = dict_to_b64(headers)
Expand All @@ -88,13 +91,45 @@ async def jwt_sign(
return f"{encoded_headers}.{encoded_payload}.{sig}"


class JWTVerifyResult(NamedTuple):
class JWTVerifyResult(BaseModel):
"""Result from verify."""

headers: Mapping[str, Any]
payload: Mapping[str, Any]
valid: bool
kid: str
class Meta:
"""JWTVerifyResult metadata."""

schema_class = "JWTVerifyResultSchema"

def __init__(
self,
headers: Mapping[str, Any],
payload: Mapping[str, Any],
valid: bool,
kid: str,
):
"""Initialize a JWTVerifyResult instance."""
self.headers = headers
self.payload = payload
self.valid = valid
self.kid = kid


class JWTVerifyResultSchema(BaseModelSchema):
"""JWTVerifyResult schema."""

class Meta:
"""JWTVerifyResultSchema metadata."""

model_class = JWTVerifyResult

headers = fields.Dict(
required=True, metadata={"description": "Headers from verified JWT."}
)
payload = fields.Dict(
required=True, metadata={"description": "Payload from verified JWT"}
)
valid = fields.Bool(required=True)
kid = fields.Str(required=True, metadata={"description": "kid of signer"})
error = fields.Str(required=False, metadata={"description": "Error text"})


async def resolve_public_key_by_kid_for_verify(profile: Profile, kid: str) -> str:
Expand All @@ -120,7 +155,7 @@ async def resolve_public_key_by_kid_for_verify(profile: Profile, kid: str) -> st

async def jwt_verify(profile: Profile, jwt: str) -> JWTVerifyResult:
"""Verify a JWT and return the headers and payload."""
encoded_headers, encoded_payload, encoded_signiture = jwt.split(".", 3)
encoded_headers, encoded_payload, encoded_signature = jwt.split(".", 3)
headers = b64_to_dict(encoded_headers)
if "alg" not in headers or headers["alg"] != "EdDSA" or "kid" not in headers:
raise BadJWSHeaderError(
Expand All @@ -129,7 +164,7 @@ async def jwt_verify(profile: Profile, jwt: str) -> JWTVerifyResult:

payload = b64_to_dict(encoded_payload)
verification_method = headers["kid"]
decoded_signature = b64_to_bytes(encoded_signiture, urlsafe=True)
decoded_signature = b64_to_bytes(encoded_signature, urlsafe=True)

async with profile.session() as session:
verkey = await resolve_public_key_by_kid_for_verify(
Expand Down
111 changes: 110 additions & 1 deletion aries_cloudagent/wallet/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
INDY_RAW_PUBLIC_KEY_VALIDATE,
JWT_EXAMPLE,
JWT_VALIDATE,
SD_JWT_EXAMPLE,
SD_JWT_VALIDATE,
NON_SD_LIST_EXAMPLE,
NON_SD_LIST_VALIDATE,
IndyDID,
StrOrDictField,
Uri,
)
from ..protocols.coordinate_mediation.v1_0.route_manager import RouteManager
Expand All @@ -50,6 +55,7 @@
from ..resolver.base import ResolverError
from ..storage.error import StorageError, StorageNotFoundError
from ..wallet.jwt import jwt_sign, jwt_verify
from ..wallet.sd_jwt import sd_jwt_sign, sd_jwt_verify
from .base import BaseWallet
from .did_info import DIDInfo
from .did_method import KEY, SOV, DIDMethod, DIDMethods, HolderDefinedDid
Expand Down Expand Up @@ -171,14 +177,32 @@ class JWSCreateSchema(OpenAPISchema):
)


class SDJWSCreateSchema(JWSCreateSchema):
"""Request schema to create an sd-jws with a particular DID."""

non_sd_list = fields.List(
fields.Str(
required=False,
validate=NON_SD_LIST_VALIDATE,
metadata={"example": NON_SD_LIST_EXAMPLE},
)
)


class JWSVerifySchema(OpenAPISchema):
"""Request schema to verify a jws created from a DID."""

jwt = fields.Str(validate=JWT_VALIDATE, metadata={"example": JWT_EXAMPLE})


class SDJWSVerifySchema(OpenAPISchema):
"""Request schema to verify an sd-jws created from a DID."""

sd_jwt = fields.Str(validate=SD_JWT_VALIDATE, metadata={"example": SD_JWT_EXAMPLE})


class JWSVerifyResponseSchema(OpenAPISchema):
"""Response schema for verification result."""
"""Response schema for JWT verification result."""

valid = fields.Bool(required=True)
error = fields.Str(required=False, metadata={"description": "Error text"})
Expand All @@ -191,6 +215,25 @@ class JWSVerifyResponseSchema(OpenAPISchema):
)


class SDJWSVerifyResponseSchema(JWSVerifyResponseSchema):
"""Response schema for SD-JWT verification result."""

disclosures = fields.List(
fields.List(StrOrDictField()),
metadata={
"description": "Disclosure arrays associated with the SD-JWT",
"example": [
["fx1iT_mETjGiC-JzRARnVg", "name", "Alice"],
[
"n4-t3mlh8jSS6yMIT7QHnA",
"street_address",
{"_sd": ["kLZrLK7enwfqeOzJ9-Ss88YS3mhjOAEk9lr_ix2Heng"]},
],
],
},
)


class DIDEndpointSchema(OpenAPISchema):
"""Request schema to set DID endpoint; response schema to get DID endpoint."""

Expand Down Expand Up @@ -941,6 +984,44 @@ async def wallet_jwt_sign(request: web.BaseRequest):
return web.json_response(jws)


@docs(
tags=["wallet"], summary="Create a EdDSA sd-jws using did keys with a given payload"
)
@request_schema(SDJWSCreateSchema)
@response_schema(WalletModuleResponseSchema(), description="")
async def wallet_sd_jwt_sign(request: web.BaseRequest):
"""Request handler for sd-jws creation using did.
Args:
"headers": { ... },
"payload": { ... },
"did": "did:example:123",
"verificationMethod": "did:example:123#keys-1"
with did and verification being mutually exclusive.
"non_sd_list": []
"""
context: AdminRequestContext = request["context"]
body = await request.json()
did = body.get("did")
verification_method = body.get("verificationMethod")
headers = body.get("headers", {})
payload = body.get("payload", {})
non_sd_list = body.get("non_sd_list", [])

try:
sd_jws = await sd_jwt_sign(
context.profile, headers, payload, non_sd_list, did, verification_method
)
except ValueError as err:
raise web.HTTPBadRequest(reason="Bad did or verification method") from err
except WalletNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
except WalletError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response(sd_jws)


@docs(tags=["wallet"], summary="Verify a EdDSA jws using did keys with a given JWS")
@request_schema(JWSVerifySchema())
@response_schema(JWSVerifyResponseSchema(), 200, description="")
Expand Down Expand Up @@ -970,6 +1051,32 @@ async def wallet_jwt_verify(request: web.BaseRequest):
)


@docs(
tags=["wallet"],
summary="Verify a EdDSA sd-jws using did keys with a given SD-JWS with "
"optional key binding",
)
@request_schema(SDJWSVerifySchema())
@response_schema(SDJWSVerifyResponseSchema(), 200, description="")
async def wallet_sd_jwt_verify(request: web.BaseRequest):
"""Request handler for sd-jws validation using did.
Args:
"sd-jwt": { ... }
"""
context: AdminRequestContext = request["context"]
body = await request.json()
sd_jwt = body["sd_jwt"]
try:
result = await sd_jwt_verify(context.profile, sd_jwt)
except (BadJWSHeaderError, InvalidVerificationMethod) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err
except ResolverError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err

return web.json_response(result.serialize())


@docs(tags=["wallet"], summary="Query DID endpoint in wallet")
@querystring_schema(DIDQueryStringSchema())
@response_schema(DIDEndpointSchema, 200, description="")
Expand Down Expand Up @@ -1125,6 +1232,8 @@ async def register(app: web.Application):
web.post("/wallet/set-did-endpoint", wallet_set_did_endpoint),
web.post("/wallet/jwt/sign", wallet_jwt_sign),
web.post("/wallet/jwt/verify", wallet_jwt_verify),
web.post("/wallet/sd-jwt/sign", wallet_sd_jwt_sign),
web.post("/wallet/sd-jwt/verify", wallet_sd_jwt_verify),
web.get(
"/wallet/get-did-endpoint", wallet_get_did_endpoint, allow_head=False
),
Expand Down
Loading

0 comments on commit 2930ac2

Please sign in to comment.