Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ecdsa dependency #403

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions okta/jwt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from Cryptodome.PublicKey import RSA
from ast import literal_eval
import jose.jwk as jwk
import jose.jwt as jwt
import os
import time
import uuid
import os

from ast import literal_eval
from Cryptodome.PublicKey import RSA
from jwcrypto.jwk import JWK, InvalidJWKType
from jwt import encode as jwt_encode


class JWT():
Expand Down Expand Up @@ -63,32 +64,36 @@ def get_PEM_JWK(private_key):
# if string repr, convert to dict object
if isinstance(private_key, str):
private_key = literal_eval(private_key)
# Create JWK using dict obj
my_jwk = jwk.construct(private_key, JWT.HASH_ALGORITHM)
# remove whitespace from key vaules
private_key = {k: ''.join(private_key[k].split()) for k in private_key}
# ensure private_key is JSON formatted
try:
json.loads(private_key)
except TypeError:
private_key = json.dumps(private_key)
try:
my_jwk = JWK.from_json(private_key)
except InvalidJWKType:
raise ValueError(
"JWK given is of the wrong type")
else: # it's a PEM
# check for filepath or explicit private key
if isinstance(private_key, (str, bytes, os.PathLike)) and os.path.exists(private_key):
# open file if exists and import key
# open file if exists and read
pem_file = open(private_key, 'r')
my_pem = RSA.import_key(pem_file.read())
private_key = pem_file.read()
pem_file.close()
else:
# convert given string to bytes and import key
private_key_bytes = bytes(private_key, 'ascii')
my_pem = RSA.import_key(private_key_bytes)

if not my_pem:
# return error if import failed
return (None, ValueError(
"RSA Private Key given is of the wrong type"))

if my_jwk: # was JWK provided
# get PEM using JWK
pem_bytes = my_jwk.to_pem(JWT.PEM_FORMAT)
my_pem = RSA.import_key(pem_bytes)
else: # was pem provided
# get JWK using PEM
my_jwk = jwk.construct(my_pem.export_key(), JWT.HASH_ALGORITHM)
# remove leading whitespaces from each line
my_pem = '\n'.join([line.strip() for line in private_key.splitlines()])
my_pem = bytes(my_pem, 'ascii')
try:
my_jwk = JWK.from_pem(my_pem)
except ValueError:
raise ValueError(
"RSA Private Key given is of the wrong type")

my_pem = my_jwk.export_to_pem(private_key=True, password=None)
my_pem = RSA.import_key(my_pem)

return (my_pem, my_jwk)

Expand All @@ -108,7 +113,7 @@ def create_token(org_url, client_id, private_key, kid=None):
str: Generated JWT
"""
# Generate PEM and JWK
my_pem, my_jwk = JWT.get_PEM_JWK(private_key)
my_pem, _ = JWT.get_PEM_JWK(private_key)
# Get current time and expiry time for token
issued_time = int(time.time())
expiry_time = issued_time + JWT.ONE_HOUR
Expand Down Expand Up @@ -142,5 +147,5 @@ def create_token(org_url, client_id, private_key, kid=None):
if "kid" in headers:
del headers["kid"]

token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM, headers=headers)
token = jwt_encode(claims, my_pem.export_key(), JWT.HASH_ALGORITHM, headers)
return token
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ pyyaml
xmltodict
yarl
pycryptodomex
python-jose[cryptography]
jwcrypto
pyjwt
aenum
pydash
flake8
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def get_version():
"xmltodict",
"yarl",
"pycryptodomex",
"python-jose",
"jwcrypto",
"pyjwt",
"aenum==3.1.11",
"pydash"
]
Expand Down
3 changes: 3 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,6 @@ def mock_next_link(self_url: URL):
KLElmMvzocvFaWKvup_a3vPaBi6y4K5kBiq60o-IDMGQ''',
"kid": "5ashWt3LP1zkYwMGbfMsVizRfx52QTyky4GTHd9MykE"
}

SAMPLE_INVALID_JWK = {'foo':'bar'}
SAMPLE_INVALID_RSA = 'foobar'
16 changes: 8 additions & 8 deletions tests/unit/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@


def test_private_key_with_kid_in_private_key(mocker):
mocked_encode = mocker.patch('jose.jwt.encode')
mocked_encode = mocker.patch('okta.jwt.jwt_encode')
JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK_WITH_KID)
expected_kid = mocks.SAMPLE_JWK_WITH_KID["kid"]
_, kwargs = mocked_encode.call_args
args = mocked_encode.call_args.args
mocked_encode.assert_called_once()
assert "kid" in kwargs["headers"]
assert kwargs["headers"]["kid"] == expected_kid
assert "kid" in args[-1]
assert args[-1]["kid"] == expected_kid


def test_private_key_with_kid_in_config(mocker):
mocked_encode = mocker.patch('jose.jwt.encode')
mocked_encode = mocker.patch('okta.jwt.jwt_encode')
expected_kid = "test-kid"
JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK, kid=expected_kid)
_, kwargs = mocked_encode.call_args
args = mocked_encode.call_args.args
mocked_encode.assert_called_once()
assert "kid" in kwargs["headers"]
assert kwargs["headers"]["kid"] == expected_kid
assert "kid" in args[-1]
assert args[-1]["kid"] == expected_kid
13 changes: 10 additions & 3 deletions tests/unit/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_private_key_PEM_JWK_dict(jwk_input):
generated_pem, generated_jwk = JWT.get_PEM_JWK(jwk_input)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


def test_private_key_PEM_JWK_file(fs):
Expand All @@ -24,11 +24,18 @@ def test_private_key_PEM_JWK_file(fs):
generated_pem, generated_jwk = JWT.get_PEM_JWK(file_path)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


def test_private_key_PEM_JWK_explicit_string():
generated_pem, generated_jwk = JWT.get_PEM_JWK(mocks.SAMPLE_RSA)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


@pytest.mark.parametrize("private_key",
[mocks.SAMPLE_INVALID_JWK, str(mocks.SAMPLE_INVALID_JWK), mocks.SAMPLE_INVALID_RSA])
def test_invalid_private_key_PEM_JWK(private_key):
with pytest.raises(ValueError):
generated_pem, generated_jwk = JWT.get_PEM_JWK(private_key)
Loading