diff --git a/lib/inferno/dsl/auth_info.rb b/lib/inferno/dsl/auth_info.rb index 1f89f3385..08f061c45 100644 --- a/lib/inferno/dsl/auth_info.rb +++ b/lib/inferno/dsl/auth_info.rb @@ -1,4 +1,5 @@ require_relative '../entities/attributes' +require_relative 'jwks' module Inferno module DSL @@ -168,6 +169,141 @@ def add_to_client(client) client.set_bearer_token(access_token) end + + # @private + def need_to_refresh? + return false if access_token.blank? || (!backend_services? && refresh_token.blank?) + + return true if expires_in.blank? + + issue_time.to_i + expires_in.to_i - DateTime.now.to_i < 60 + end + + # @private + def able_to_refresh? + token_url.present? && (backend_services? || refresh_token.present?) + end + + # @private + def backend_services? + auth_type == 'backend_services' + end + + # @private + def oauth2_refresh_params + case auth_type + when 'public' + public_auth_refresh_params + when 'symmetric' + symmetric_auth_refresh_params + when 'asymmetric' + asymmetric_auth_refresh_params + when 'backend_services' + backend_services_auth_refresh_params + end + end + + # @private + def symmetric_auth_refresh_params + { + 'grant_type' => 'refresh_token', + 'refresh_token' => refresh_token + } + end + + # @private + def public_auth_refresh_params + symmetric_auth_refresh_params.merge('client_id' => client_id) + end + + # @private + def asymmetric_auth_refresh_params + symmetric_auth_refresh_params.merge( + 'client_assertion_type' => 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', + 'client_assertion' => client_assertion + ) + end + + # @private + def backend_services_auth_refresh_params + { + 'grant_type' => 'client_credentials', + 'scope' => requested_scopes, + 'client_assertion_type' => 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', + 'client_assertion' => client_assertion + } + end + + # @private + def oauth2_refresh_headers + base_headers = { 'Content-Type' => 'application/x-www-form-urlencoded' } + + return base_headers unless auth_type == 'symmetric' + + credentials = "#{client_id}:#{client_secret}" + + base_headers.merge( + 'Authorization' => "Basic #{Base64.strict_encode64(credentials)}" + ) + end + + # @private + def private_key + @private_key ||= JWKS.jwks(user_jwks: jwks) + .select { |key| key[:key_ops]&.include?('sign') } + .select { |key| key[:alg] == encryption_algorithm } + .find { |key| !kid || key[:kid] == kid } + end + + # @private + def signing_key + if private_key.nil? + raise Inferno::Exceptions::AssertionException, + "No signing key found for inputs: encryption method = '#{encryption_algorithm}' and kid = '#{kid}'" + end + + @private_key.signing_key + end + + # @private + def auth_jwt_header + { + 'alg' => encryption_algorithm, + 'kid' => private_key['kid'], + 'typ' => 'JWT' + } + end + + # @private + def auth_jwt_claims + { + 'iss' => client_id, + 'sub' => client_id, + 'aud' => token_url, + 'exp' => 5.minutes.from_now.to_i, + 'jti' => SecureRandom.hex(32) + } + end + + # @private + def client_assertion + JWT.encode auth_jwt_claims, signing_key, encryption_algorithm, auth_jwt_header + end + + # @private + def update_from_response_body(request) + token_response_body = JSON.parse(request.response_body) + + expires_in = token_response_body['expires_in'].is_a?(Numeric) ? token_response_body['expires_in'] : nil + + self.access_token = token_response_body['access_token'] + self.refresh_token = token_response_body['refresh_token'] if token_response_body['refresh_token'].present? + self.expires_in = expires_in + self.issue_time = DateTime.now + + add_to_client(client) + self + end end end end diff --git a/lib/inferno/dsl/fhir_client.rb b/lib/inferno/dsl/fhir_client.rb index 5d6ffec2d..2e79d3376 100644 --- a/lib/inferno/dsl/fhir_client.rb +++ b/lib/inferno/dsl/fhir_client.rb @@ -347,7 +347,7 @@ def store_request_and_refresh_token(client, name, tags, &block) # @private def perform_refresh(client) - credentials = client.oauth_credentials + credentials = client.auth_info || client.oauth_credentials post( credentials.token_url, @@ -363,7 +363,7 @@ def perform_refresh(client) Inferno::Repositories::SessionData.new.save( name: credentials.name, value: credentials, - type: 'oauth_credentials', + type: credentials.is_a?(Inferno::DSL::AuthInfo) ? 'auth_info' : 'oauth_credentials', test_session_id: ) end diff --git a/lib/inferno/dsl/jwks.json b/lib/inferno/dsl/jwks.json new file mode 100644 index 000000000..eb46118d1 --- /dev/null +++ b/lib/inferno/dsl/jwks.json @@ -0,0 +1,59 @@ +{ + "keys": [ + { + "kty": "EC", + "crv": "P-384", + "x": "JQKTsV6PT5Szf4QtDA1qrs0EJ1pbimQmM2SKvzOlIAqlph3h1OHmZ2i7MXahIF2C", + "y": "bRWWQRJBgDa6CTgwofYrHjVGcO-A7WNEnu4oJA5OUJPPPpczgx1g2NsfinK-D2Rw", + "use": "sig", + "key_ops": [ + "verify" + ], + "ext": true, + "kid": "4b49a739d1eb115b3225f4cf9beb6d1b", + "alg": "ES384" + }, + { + "kty": "EC", + "crv": "P-384", + "d": "kDkn55p7gryKk2tj6z2ij7ExUnhi0ngxXosvqa73y7epwgthFqaJwApmiXXU2yhK", + "x": "JQKTsV6PT5Szf4QtDA1qrs0EJ1pbimQmM2SKvzOlIAqlph3h1OHmZ2i7MXahIF2C", + "y": "bRWWQRJBgDa6CTgwofYrHjVGcO-A7WNEnu4oJA5OUJPPPpczgx1g2NsfinK-D2Rw", + "key_ops": [ + "sign" + ], + "ext": true, + "kid": "4b49a739d1eb115b3225f4cf9beb6d1b", + "alg": "ES384" + }, + { + "kty": "RSA", + "alg": "RS384", + "n": "vjbIzTqiY8K8zApeNng5ekNNIxJfXAue9BjoMrZ9Qy9m7yIA-tf6muEupEXWhq70tC7vIGLqJJ4O8m7yiH8H2qklX2mCAMg3xG3nbykY2X7JXtW9P8VIdG0sAMt5aZQnUGCgSS3n0qaooGn2LUlTGIR88Qi-4Nrao9_3Ki3UCiICeCiAE224jGCg0OlQU6qj2gEB3o-DWJFlG_dz1y-Mxo5ivaeM0vWuodjDrp-aiabJcSF_dx26sdC9dZdBKXFDq0t19I9S9AyGpGDJwzGRtWHY6LsskNHLvo8Zb5AsJ9eRZKpnh30SYBZI9WHtzU85M9WQqdScR69Vyp-6Uhfbvw", + "e": "AQAB", + "use": "sig", + "key_ops": [ + "verify" + ], + "ext": true, + "kid": "b41528b6f37a9500edb8a905a595bdd7" + }, + { + "kty": "RSA", + "alg": "RS384", + "n": "vjbIzTqiY8K8zApeNng5ekNNIxJfXAue9BjoMrZ9Qy9m7yIA-tf6muEupEXWhq70tC7vIGLqJJ4O8m7yiH8H2qklX2mCAMg3xG3nbykY2X7JXtW9P8VIdG0sAMt5aZQnUGCgSS3n0qaooGn2LUlTGIR88Qi-4Nrao9_3Ki3UCiICeCiAE224jGCg0OlQU6qj2gEB3o-DWJFlG_dz1y-Mxo5ivaeM0vWuodjDrp-aiabJcSF_dx26sdC9dZdBKXFDq0t19I9S9AyGpGDJwzGRtWHY6LsskNHLvo8Zb5AsJ9eRZKpnh30SYBZI9WHtzU85M9WQqdScR69Vyp-6Uhfbvw", + "e": "AQAB", + "d": "rriV9GYimi5by7TOW4xNh6_gYBHVRDBsft2OFF8qapdVHt2GNuRDDxc_B6ga6TY2Enh2MLKLTr1dD3W4FIdTCJiMerrorp07FJS7nJEMgWQDxrfgkX4_EqrhW42L5d4vypYnRXEEW6u4gzkx5uFOkdvJBIK7CsIdSaBFYhochnynNgvbKWasi4rl2hayEH8tdf3B7Z6OIH9alspBTaq3j_zJt_KkrpYEzIUb4UgALB5NTWn5YKr0Avk_asOg8YfjViQwO9ASGaWjQeJ2Rx8OEQwBMQHSDMCSWNiWmYOu9PcwSZFc1vLxqzyIM8QrQSJHCCMo_wGYgke_r0CLeONHEQ", + "p": "5hH_QApWGeobRi1n7XbMfJYohB8K3JDPa0MspfplHpJ-17JiGG2sNoBdBcpaPRf9OX48P8VqO0qrSSRAk-I-uO6OO9BHbIukXJILqnY2JmurYzbcYbt5FVbknlHRJojkF6-7sFBazpueUlOnXCw7X7Z_SkfNE4QX5Ejm2Zm5mek", + "q": "06bZz7c7K9s1-aEZsxYnLJ9eTpKlt1tIBDA_LwIh5W3w259pes2kUtimbnkyOf-V2ZIERsFCh5s-S9IOEMvAIa6M5j9GW1ILNT7AcHIUfcyFcH-FF8BU_KJdRP5PXnIXFdYcylvsdoIdchy1AaUIzyiKRCU3HBYI75hez0l_F2c", + "dp": "h_sVIXW6hCCRND48EedIX06k7conMkxIu_39ErDXOWWeoMAnKIcR5TijQnviL__QxD1vQMXezuKIMHfDz2RGbClbWdD1lhtG7wvG515tDPJQXxia0wzqOQmdoFF9S8hXAAT26vPjaAAkaEZXQaxG_4Au5elgNWu6b0wDXZN1Vpk", + "dq": "GqS0YpuUTU8JGmWXUJ4HTGy7eHSpe8134V8ZdRd1oOYYHe2RX64nc25mdR24nuh3uq3Q7_9AGsYGL5E_yAl-JD9O6WUpvDE1y_wcSYty3Os0GRdUb8r8Z9kgmKDS6Pa_xTXw5eBwgfKbNlQ6zPwzgbB-x1lP-K8lbNPni3ybDR0", + "qi": "cqQfoi0sM5Su8ZOhznmdWrDIQB28H6fBKiabgaIKkbWZV4e0nwFvLquHjPOvv4Ao8iEGU5dyhvg0n5BKYPi-4mp6M6OA1sy0NrTr7EsKSYGyu2pBq9rw4oAYTM2LXKg6K-awgUUlkc451IwxHBAe15PWCBM3kvLQeijNid0Vz5I", + "key_ops": [ + "sign" + ], + "ext": true, + "kid": "b41528b6f37a9500edb8a905a595bdd7" + } + ] +} diff --git a/lib/inferno/dsl/jwks.rb b/lib/inferno/dsl/jwks.rb new file mode 100644 index 000000000..1180d82fa --- /dev/null +++ b/lib/inferno/dsl/jwks.rb @@ -0,0 +1,79 @@ +module Inferno + module DSL + # The JWKS class provides methods to handle JSON Web Key Sets (JWKS) + # within Inferno. + # + # This class allows users to fetch, parse, and manage JWKS, ensuring + # that the necessary keys for verifying tokens are available. + class JWKS + class << self + # Returns a formatted JSON string of the JWKS public keys that are used for verification. + # This method filters out keys that do not have the 'verify' operation. + # + # @return [String] The formatted JSON string of the JWKS public keys. + # + # @example + # jwks_json = Inferno::JWKS.jwks_json + # puts jwks_json + def jwks_json + @jwks_json ||= + JSON.pretty_generate( + { keys: jwks.export[:keys].select { |key| key[:key_ops]&.include?('verify') } } + ) + end + + # Provides the default file path to the JWKS file. + # This method is primarily used internally to locate the default JWKS file. + # + # @return [String] The default JWKS file path. + # + # @private + def default_jwks_path + @default_jwks_path ||= File.join(__dir__, 'jwks.json') + end + + # Fetches the JWKS file path from the environment variable `INFERNO_JWKS_PATH`. + # If the environment variable is not set, it falls back to the default path + # provided by `.default_jwks_path`. + # + # @return [String] The JWKS file path. + # + # @private + def jwks_path + @jwks_path ||= + ENV.fetch('INFERNO_JWKS_PATH', default_jwks_path) + end + + # Reads the JWKS content from the file located at the JWKS path. + # + # @return [String] The json content of the JWKS file. + # + # @private + def default_jwks_json + @default_jwks_json ||= File.read(jwks_path) + end + + # Parses and returns a `JWT::JWK::Set` object from the provided JWKS string + # or from the file located at the JWKS path. If a user-provided JWKS string + # is not available, it reads the JWKS from the file. + # + # @param user_jwks [String, nil] An optional json containing the JWKS. + # If not provided, the method reads from the file. + # @return [JWT::JWK::Set] The parsed JWKS set. + # + # @example + # # Using a user-provided JWKS string + # user_jwks = '{"keys":[...]}' + # jwks_set = Inferno::JWKS.jwks(user_jwks: user_jwks) + # + # # Using the default JWKS file + # jwks_set = Inferno::JWKS.jwks + def jwks(user_jwks: nil) + JWT::JWK::Set.new(JSON.parse(user_jwks.presence || default_jwks_json)) + end + end + end + end + + JWKS = DSL::JWKS +end diff --git a/lib/inferno/ext/fhir_client.rb b/lib/inferno/ext/fhir_client.rb index 1963eecbc..62a33f226 100644 --- a/lib/inferno/ext/fhir_client.rb +++ b/lib/inferno/ext/fhir_client.rb @@ -3,11 +3,11 @@ class Client attr_accessor :oauth_credentials, :auth_info def need_to_refresh? - oauth_credentials&.need_to_refresh? + !!(auth_info&.need_to_refresh? || oauth_credentials&.need_to_refresh?) end def able_to_refresh? - oauth_credentials&.able_to_refresh? + !!(auth_info&.able_to_refresh? || oauth_credentials&.able_to_refresh?) end end end diff --git a/spec/fixtures/auth_info_constants.rb b/spec/fixtures/auth_info_constants.rb new file mode 100644 index 000000000..a65ba96d4 --- /dev/null +++ b/spec/fixtures/auth_info_constants.rb @@ -0,0 +1,71 @@ +module AuthInfoConstants + AUTH_URL = 'http://example.com/authorization'.freeze + TOKEN_URL = 'http://example.com/token'.freeze + REQUESTED_SCOPES = 'launch/patient openid fhirUser patient/*.*'.freeze + ENCRYPTION_ALGORITHM = 'ES384'.freeze + KID = '4b49a739d1eb115b3225f4cf9beb6d1b'.freeze + JWKS = File.read(File.join('lib', 'inferno', 'dsl', 'jwks.json')).freeze + class << self + def token_info + { + access_token: 'SAMPLE_TOKEN', + refresh_token: 'SAMPLE_REFRESH_TOKEN', + expires_in: '3600', + issue_time: Time.now.iso8601 + } + end + + def public_access_default + { + auth_type: 'public', + token_url: TOKEN_URL, + client_id: 'SAMPLE_PUBLIC_CLIENT_ID', + requested_scopes: REQUESTED_SCOPES, + pkce_support: 'enabled', + pkce_code_challenge_method: 'S256', + auth_request_method: 'GET' + }.merge(token_info) + end + + def symmetric_confidential_access_default + { + auth_type: 'symmetric', + token_url: TOKEN_URL, + client_id: 'SAMPLE_CONFIDENTIAL_CLIENT_ID', + client_secret: 'SAMPLE_CONFIDENTIAL_CLIENT_SECRET', + auth_url: AUTH_URL, + requested_scopes: REQUESTED_SCOPES, + pkce_support: 'enabled', + pkce_code_challenge_method: 'S256', + auth_request_method: 'POST', + use_discovery: 'false' + }.merge(token_info) + end + + def asymmetric_confidential_access_default + { + auth_type: 'asymmetric', + token_url: TOKEN_URL, + client_id: 'SAMPLE_CONFIDENTIAL_CLIENT_ID', + requested_scopes: REQUESTED_SCOPES, + pkce_support: 'disabled', + auth_request_method: 'POST', + encryption_algorithm: ENCRYPTION_ALGORITHM, + jwks: JWKS, + kid: KID + }.merge(token_info) + end + + def backend_services_access_default + { + auth_type: 'backend_services', + token_url: TOKEN_URL, + client_id: 'SAMPLE_CONFIDENTIAL_CLIENT_ID', + requested_scopes: REQUESTED_SCOPES, + encryption_algorithm: ENCRYPTION_ALGORITHM, + jwks: JWKS, + kid: KID + }.merge(token_info) + end + end +end diff --git a/spec/inferno/dsl/auth_info_spec.rb b/spec/inferno/dsl/auth_info_spec.rb index cf3399e4c..01f6977f5 100644 --- a/spec/inferno/dsl/auth_info_spec.rb +++ b/spec/inferno/dsl/auth_info_spec.rb @@ -17,7 +17,15 @@ name: 'NAME' } end + let(:public_access_default) { AuthInfoConstants.public_access_default } + let(:symmetric_confidential_access_default) { AuthInfoConstants.symmetric_confidential_access_default } + let(:asymmetric_confidential_access_default) { AuthInfoConstants.asymmetric_confidential_access_default } + let(:backend_services_access_default) { AuthInfoConstants.backend_services_access_default } let(:auth_info) { described_class.new(full_params) } + let(:public_auth_info) { described_class.new(public_access_default) } + let(:symmetric_auth_info) { described_class.new(symmetric_confidential_access_default) } + let(:asymmetric_auth_info) { described_class.new(asymmetric_confidential_access_default) } + let(:backend_services_auth_info) { described_class.new(backend_services_access_default) } let(:client) { FHIR::Client.new('http://example.com') } describe '.new' do @@ -71,4 +79,216 @@ expect(JSON.parse(described_class.new(hash).to_s)).to include(hash.stringify_keys) end end + + describe '#need_to_refresh?' do + it 'returns false if there is no access token' do + auth_info.access_token = nil + expect(auth_info.need_to_refresh?).to be(false) + end + + it 'returns true if there is no expires_in' do + auth_info.expires_in = nil + expect(auth_info.need_to_refresh?).to be(true) + end + + it 'returns true if the token has will expire in under a minute' do + auth_info.expires_in = 59 + expect(auth_info.need_to_refresh?).to be(true) + end + + it 'returns true if the token has expired' do + auth_info.issue_time = 2.hours.ago + expect(auth_info.need_to_refresh?).to be(true) + end + + it 'returns false if the token is valid for over a minute' do + auth_info.expires_in = 61 + expect(auth_info.need_to_refresh?).to be(false) + end + + context 'when public, symmetric, or asymmetric auth' do + it 'returns false if there is no refresh token' do + [public_auth_info, symmetric_auth_info, asymmetric_auth_info].each do |credentials| + credentials.refresh_token = nil + expect(credentials.need_to_refresh?).to be(false) + end + end + end + + context 'when backend services auth' do + it 'returns true if no refresh token and access token expired' do + backend_services_auth_info.refresh_token = nil + backend_services_auth_info.issue_time = 2.hours.ago + expect(backend_services_auth_info.need_to_refresh?).to be(true) + end + end + end + + describe '#able_to_refresh?' do + context 'when public, symmetric, or asymmetric auth' do + it 'returns true if a refresh token and token url are present' do + [public_auth_info, symmetric_auth_info, asymmetric_auth_info].each do |credentials| + expect(credentials.able_to_refresh?).to be(true) + end + end + + it 'returns false if the refresh token or token url are missing' do + [:refresh_token, :token_url].each do |field| + [public_access_default, symmetric_confidential_access_default, + asymmetric_confidential_access_default].each do |params| + expect(described_class.new(params.merge("#{field}": nil)).able_to_refresh?).to be(false) + end + end + end + end + + context 'when backend services auth' do + it 'returns true if token url is present' do + expect(backend_services_auth_info.able_to_refresh?).to be(true) + end + + it 'returns false if the token url is missing' do + backend_services_auth_info.token_url = nil + expect(backend_services_auth_info.able_to_refresh?).to be(false) + end + end + end + + describe '#backend_services?' do + it 'returns true if auth type is backend services' do + expect(backend_services_auth_info.backend_services?).to be(true) + end + + it 'returns false if auth type is not backend services' do + expect(public_auth_info.backend_services?).to be(false) + end + end + + describe '#oauth2_refresh_params' do + context 'when public auth' do + it 'returns a hash with grant_type, refresh_token, and client_id params' do + params = public_auth_info.oauth2_refresh_params + expect(params).to include('grant_type', 'refresh_token', 'client_id') + expect(params['grant_type']).to eq('refresh_token') + expect(params['refresh_token']).to eq(public_auth_info.refresh_token) + expect(params['client_id']).to eq(public_auth_info.client_id) + end + end + + context 'when symmetric auth' do + it 'returns a hash with grant_type and refresh_token params' do + params = symmetric_auth_info.oauth2_refresh_params + expect(params).to include('grant_type', 'refresh_token') + expect(params['grant_type']).to eq('refresh_token') + expect(params['refresh_token']).to eq(symmetric_auth_info.refresh_token) + end + end + + context 'when asymmetric auth' do + it 'returns a hash with grant_type, refresh_token, client_assertion_type, and client_assertion params' do + params = asymmetric_auth_info.oauth2_refresh_params + expect(params).to include('grant_type', 'refresh_token', 'client_assertion_type', 'client_assertion') + expect(params['grant_type']).to eq('refresh_token') + expect(params['refresh_token']).to eq(asymmetric_auth_info.refresh_token) + expect(params['client_assertion_type']).to eq('urn:ietf:params:oauth:client-assertion-type:jwt-bearer') + end + end + + context 'when backend services auth' do + it 'returns a hash with grant_type, scope, client_assertion_type, and client_assertion params' do + params = backend_services_auth_info.oauth2_refresh_params + expect(params).to include('grant_type', 'scope', 'client_assertion_type', 'client_assertion') + expect(params['grant_type']).to eq('client_credentials') + expect(params['scope']).to eq(backend_services_auth_info.requested_scopes) + expect(params['client_assertion_type']).to eq('urn:ietf:params:oauth:client-assertion-type:jwt-bearer') + end + end + end + + describe '#oauth2_refresh_headers' do + context 'when symmetric auth' do + it 'returns a hash with Content-Type and Authorization headers' do + expect(symmetric_auth_info.oauth2_refresh_headers).to include('Authorization', 'Content-Type') + end + end + + context 'when public, asymmetric, or backend services auth' do + it 'returns a hash with a Content-Type header' do + [public_auth_info, asymmetric_auth_info, backend_services_auth_info].each do |credentials| + expect(credentials.oauth2_refresh_headers).to eq('Content-Type' => 'application/x-www-form-urlencoded') + end + end + end + end + + describe '#client_assertion' do + context 'when kid is present' do + it 'returns valid JWT signed with keys having the correct algorithm and kid' do + jwt = asymmetric_auth_info.client_assertion + claims, header = JWT.decode(jwt, nil, false) + + expect(header['alg']).to eq(asymmetric_auth_info.encryption_algorithm) + expect(header['typ']).to eq('JWT') + expect(header['kid']).to eq(asymmetric_auth_info.kid) + expect(claims['iss']).to eq(asymmetric_auth_info.client_id) + expect(claims['aud']).to eq(asymmetric_auth_info.token_url) + expect(claims['sub']).to eq(asymmetric_auth_info.client_id) + expect(claims['exp']).to be_present + expect(claims['jti']).to be_present + end + end + + context 'when kid is missing' do + it 'returns valid JWT igned with keys having the correct algorithm' do + asymmetric_auth_info.kid = nil + jwt = asymmetric_auth_info.client_assertion + claims, header = JWT.decode(jwt, nil, false) + + expect(header['alg']).to eq(asymmetric_auth_info.encryption_algorithm) + expect(header['typ']).to eq('JWT') + expect(header['kid']).to be_present + expect(claims['iss']).to eq(asymmetric_auth_info.client_id) + expect(claims['aud']).to eq(asymmetric_auth_info.token_url) + expect(claims['sub']).to eq(asymmetric_auth_info.client_id) + expect(claims['exp']).to be_present + expect(claims['jti']).to be_present + end + end + + it 'throws exception when kid not found for the given algorithm' do + asymmetric_auth_info.kid = 'random' + expect do + asymmetric_auth_info.client_assertion + end.to raise_error(Inferno::Exceptions::AssertionException) + end + end + + describe '#update_from_response_body' do + before { auth_info.add_to_client(client) } + + it 'updates the refresh token if a new one is received' do + response_body = { + access_token: 'NEW_ACCESS_TOKEN', + refresh_token: 'NEW_REFRESH_TOKEN', + expires_in: 3600 + } + request = OpenStruct.new(response_body: response_body.to_json) + + auth_info.update_from_response_body(request) + + expect(auth_info.refresh_token).to eq('NEW_REFRESH_TOKEN') + end + + it 'does not update the refresh token if none is received' do + response_body = { + access_token: 'NEW_ACCESS_TOKEN', + expires_in: 3600 + } + request = OpenStruct.new(response_body: response_body.to_json) + + auth_info.update_from_response_body(request) + + expect(auth_info.refresh_token).to eq('REFRESH_TOKEN') + end + end end diff --git a/spec/inferno/dsl/fhir_client_spec.rb b/spec/inferno/dsl/fhir_client_spec.rb index 68861003d..38c355357 100644 --- a/spec/inferno/dsl/fhir_client_spec.rb +++ b/spec/inferno/dsl/fhir_client_spec.rb @@ -29,6 +29,13 @@ def test_session_id ) end + fhir_client :client_with_auth_info do + url 'http://www.example.com/fhir' + auth_info( + Inferno::DSL::AuthInfo.new(AuthInfoConstants.public_access_default) + ) + end + fhir_client :client_with_trailing_slash do url 'http://www.example.com/fhir/' end @@ -364,6 +371,20 @@ def test_session_id end end + context 'with auth info' do + it 'performs a refresh if the token is about to expire' do + client = group.fhir_client(:client_with_auth_info) + allow(client).to receive(:need_to_refresh?).and_return(true) + allow(client).to receive(:able_to_refresh?).and_return(true) + allow(group).to receive(:perform_refresh).with(client) + + group.fhir_operation(path, client: :client_with_auth_info) + + expect(stub_operation_request).to have_been_made.once + expect(group).to have_received(:perform_refresh) + end + end + context 'with a base url that causes a TCP error' do before do allow_any_instance_of(FHIR::Client) @@ -534,6 +555,20 @@ def test_session_id end end + context 'with auth info' do + it 'performs a refresh if the token is about to expire' do + client = group.fhir_client(:client_with_auth_info) + allow(client).to receive(:need_to_refresh?).and_return(true) + allow(client).to receive(:able_to_refresh?).and_return(true) + allow(group).to receive(:perform_refresh).with(client) + + group.fhir_read(resource.resourceType, resource_id, client: :client_with_auth_info) + + expect(stub_read_request).to have_been_made.once + expect(group).to have_received(:perform_refresh) + end + end + context 'with a base url that causes a TCP error' do before do allow_any_instance_of(FHIR::Client) @@ -1055,6 +1090,21 @@ def test_session_id end end + context 'with auth info' do + it 'performs a refresh if the token is about to expire' do + stub_get_search_request + client = group.fhir_client(:client_with_auth_info) + allow(client).to receive(:need_to_refresh?).and_return(true) + allow(client).to receive(:able_to_refresh?).and_return(true) + allow(group).to receive(:perform_refresh).with(client) + + group.fhir_search(resource.resourceType, params: { patient: 123 }, client: :client_with_auth_info) + + expect(stub_get_search_request).to have_been_made.once + expect(group).to have_received(:perform_refresh) + end + end + context 'with a base url that causes a TCP error' do before do allow_any_instance_of(FHIR::Client) @@ -1263,65 +1313,216 @@ def test_session_id end end - describe '#perform_refresh' do - let(:client) { group.fhir_client(:client_with_oauth_credentials) } - let(:credentials) { client.oauth_credentials } - let(:token_url) { client.token_url } + describe '#auth_info' do + let(:client) { group.fhir_client(:client_with_auth_info) } - context 'when the refresh is unsuccessful' do - it 'does not update credentials' do - original_credentials = credentials.to_hash - token_request = - stub_request(:post, credentials.token_url) - .to_return(status: 400) + it 'uses the given bearer token in the security header' do + token = AuthInfoConstants.public_access_default[:access_token] + expect(client.security_headers).to eq({ 'Authorization' => "Bearer #{token}" }) + end - group.perform_refresh(client) + it 'has the auth flags set correctly' do + expect(client.use_basic_auth).to be_truthy + expect(client.use_oauth2_auth).to be_falsey + end - expect(client.oauth_credentials.to_hash).to eq(original_credentials) - expect(token_request).to have_been_made + it 'stores the credentials on the client' do + expect(client.instance_variable_get(:@auth_info)).to be_a(Inferno::DSL::AuthInfo) + end + end + + describe '#need_to_refresh?' do + context 'with oauth credentials' do + let(:client) { group.fhir_client(:client_with_oauth_credentials) } + + it 'returns true if @oauth_credentials&.need_to_refresh? is true' do + client.oauth_credentials.expires_in = nil + expect(client.need_to_refresh?).to be(true) + end + + it 'returns a falsey if @oauth_credentials&.need_to_refresh? is false' do + client.oauth_credentials.access_token = nil + expect(client).to_not be_need_to_refresh end end - context 'when the refresh is successful' do - let(:token_response_body) do - { - access_token: 'NEW_ACCESS_TOKEN', - token_type: 'bearer', - expires_in: 5000, - scope: 'NEW_SCOPES', - refresh_token: 'NEW_REFRESH_TOKEN' - } + context 'with aut info' do + let(:client) { group.fhir_client(:client_with_auth_info) } + + it 'returns true if @auth_info&.need_to_refresh? is true' do + client.auth_info.expires_in = nil + expect(client.need_to_refresh?).to be(true) end - it 'updates the credentials on the client' do - token_request = - stub_request(:post, credentials.token_url) - .to_return(status: 200, body: JSON.generate(token_response_body)) + it 'returns a falsey if @auth_info&.need_to_refresh? is false' do + client.auth_info.access_token = nil + expect(client).to_not be_need_to_refresh + end + end + + it 'returns a falsey if @auth_info and @oauth_credentials are missing' do + client = group.fhir_client + expect(client).to_not be_need_to_refresh + end + end + + describe '#able_to_refresh?' do + context 'with oauth credentials' do + let(:client) { group.fhir_client(:client_with_oauth_credentials) } + + it 'returns true if @oauth_credentials&.able_to_refresh? is true' do + expect(client.able_to_refresh?).to be(true) + end + + it 'returns a falsey if @oauth_credentials&.able_to_refresh? is false' do + client.oauth_credentials.token_url = nil + expect(client).to_not be_able_to_refresh + end + end - group.perform_refresh(client) + context 'with aut info' do + let(:client) { group.fhir_client(:client_with_auth_info) } - expect(token_request).to have_been_made - expect(credentials.access_token).to eq(token_response_body[:access_token]) - expect(credentials.refresh_token).to eq(token_response_body[:refresh_token]) + it 'returns true if @auth_info&.able_to_refresh? is true' do + expect(client.able_to_refresh?).to be(true) end - it 'updates the credentials in the database' do - session = repo_create(:test_session, test_suite_id: 'demo') - allow(group).to receive(:test_session_id).and_return(session.id) + it 'returns a falsey if @auth_info&.able_to_refresh? is false' do + client.auth_info.token_url = nil + expect(client).to_not be_able_to_refresh + end + end - credentials.name = 'name' - token_request = - stub_request(:post, credentials.token_url) - .to_return(status: 200, body: JSON.generate(token_response_body)) + it 'returns a falsey if @auth_info and @oauth_credentials are missing' do + client = group.fhir_client + expect(client).to_not be_able_to_refresh + end + end + + describe '#perform_refresh' do + context 'with oauth credentials' do + let(:client) { group.fhir_client(:client_with_oauth_credentials) } + let(:credentials) { client.oauth_credentials } + let(:token_url) { client.token_url } + + context 'when the refresh is unsuccessful' do + it 'does not update credentials' do + original_credentials = credentials.to_hash + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 400) + + group.perform_refresh(client) + + expect(client.oauth_credentials.to_hash).to eq(original_credentials) + expect(token_request).to have_been_made + end + end + + context 'when the refresh is successful' do + let(:token_response_body) do + { + access_token: 'NEW_ACCESS_TOKEN', + token_type: 'bearer', + expires_in: 5000, + scope: 'NEW_SCOPES', + refresh_token: 'NEW_REFRESH_TOKEN' + } + end + + it 'updates the credentials on the client' do + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 200, body: JSON.generate(token_response_body)) + + group.perform_refresh(client) + + expect(token_request).to have_been_made + expect(credentials.access_token).to eq(token_response_body[:access_token]) + expect(credentials.refresh_token).to eq(token_response_body[:refresh_token]) + end - group.perform_refresh(client) + it 'updates the credentials in the database' do + session = repo_create(:test_session, test_suite_id: 'demo') + allow(group).to receive(:test_session_id).and_return(session.id) - expect(token_request).to have_been_made + credentials.name = 'name' + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 200, body: JSON.generate(token_response_body)) - persisted_credentials = - session_data_repo.load(test_session_id: group.test_session_id, name: 'name', type: 'oauth_credentials') + group.perform_refresh(client) - expect(persisted_credentials.to_hash).to eq(credentials.to_hash) + expect(token_request).to have_been_made + + persisted_credentials = + session_data_repo.load(test_session_id: group.test_session_id, name: 'name', type: 'oauth_credentials') + + expect(persisted_credentials.to_hash).to eq(credentials.to_hash) + end + end + end + + context 'with auth info' do + let(:client) { group.fhir_client(:client_with_auth_info) } + let(:credentials) { client.auth_info } + let(:token_url) { client.token_url } + + context 'when the refresh is unsuccessful' do + it 'does not update credentials' do + original_credentials = credentials.to_hash + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 400) + + group.perform_refresh(client) + + expect(client.auth_info.to_hash).to eq(original_credentials) + expect(token_request).to have_been_made + end + end + + context 'when the refresh is successful' do + let(:token_response_body) do + { + access_token: 'NEW_ACCESS_TOKEN', + token_type: 'bearer', + expires_in: 5000, + scope: 'NEW_SCOPES', + refresh_token: 'NEW_REFRESH_TOKEN' + } + end + + it 'updates the credentials on the client' do + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 200, body: JSON.generate(token_response_body)) + + group.perform_refresh(client) + + expect(token_request).to have_been_made + expect(credentials.access_token).to eq(token_response_body[:access_token]) + expect(credentials.refresh_token).to eq(token_response_body[:refresh_token]) + end + + it 'updates the credentials in the database' do + session = repo_create(:test_session, test_suite_id: 'demo') + allow(group).to receive(:test_session_id).and_return(session.id) + + credentials.name = 'name' + token_request = + stub_request(:post, credentials.token_url) + .to_return(status: 200, body: JSON.generate(token_response_body)) + + group.perform_refresh(client) + + expect(token_request).to have_been_made + + persisted_credentials = + session_data_repo.load(test_session_id: group.test_session_id, name: 'name', type: 'auth_info') + + expect(persisted_credentials.to_hash).to eq(credentials.to_hash) + end end end end