Skip to content

Commit

Permalink
Modify ClientSession and ServerSession to own the configs instead of
Browse files Browse the repository at this point in the history
holding references to it.

This simplifies integration on the caller side since the caller no
longer needs to ensure the lifetime of the configs outside of the
sessions. It should be the sessions responsibility to own the config and
destroy it whenever the session is destroyed.

Change-Id: I6ce161a9b5f58b3c38a19204bc5f61116ee78661
  • Loading branch information
rakshita-tandon authored and k-naliuka committed Aug 12, 2024
1 parent 37c0407 commit c81f2b9
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 83 deletions.
24 changes: 12 additions & 12 deletions oak_session/src/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,23 @@ pub trait AttestationProvider {
/// Client-side Attestation Provider that initiates remote attestation with the
/// server.
#[allow(dead_code)]
pub struct ClientAttestationProvider<'a> {
config: AttestationProviderConfig<'a>,
pub struct ClientAttestationProvider {
config: AttestationProviderConfig,
}

impl<'a> ClientAttestationProvider<'a> {
pub fn new(config: AttestationProviderConfig<'a>) -> Self {
impl ClientAttestationProvider {
pub fn new(config: AttestationProviderConfig) -> Self {
Self { config }
}
}

impl<'a> AttestationProvider for ClientAttestationProvider<'a> {
impl AttestationProvider for ClientAttestationProvider {
fn get_attestation_results(self) -> Option<AttestationResults> {
core::unimplemented!();
}
}

impl<'a> ProtocolEngine<AttestResponse, AttestRequest> for ClientAttestationProvider<'a> {
impl ProtocolEngine<AttestResponse, AttestRequest> for ClientAttestationProvider {
fn get_outgoing_message(&mut self) -> anyhow::Result<Option<AttestRequest>> {
core::unimplemented!();
}
Expand All @@ -91,23 +91,23 @@ impl<'a> ProtocolEngine<AttestResponse, AttestRequest> for ClientAttestationProv
/// Server-side Attestation Provider that responds to the remote attestation
/// request from the client.
#[allow(dead_code)]
pub struct ServerAttestationProvider<'a> {
config: AttestationProviderConfig<'a>,
pub struct ServerAttestationProvider {
config: AttestationProviderConfig,
}

impl<'a> ServerAttestationProvider<'a> {
pub fn new(config: AttestationProviderConfig<'a>) -> Self {
impl ServerAttestationProvider {
pub fn new(config: AttestationProviderConfig) -> Self {
Self { config }
}
}

impl<'a> AttestationProvider for ServerAttestationProvider<'a> {
impl AttestationProvider for ServerAttestationProvider {
fn get_attestation_results(self) -> Option<AttestationResults> {
core::unimplemented!();
}
}

impl<'a> ProtocolEngine<AttestRequest, AttestResponse> for ServerAttestationProvider<'a> {
impl ProtocolEngine<AttestRequest, AttestResponse> for ServerAttestationProvider {
fn get_outgoing_message(&mut self) -> anyhow::Result<Option<AttestResponse>> {
core::unimplemented!();
}
Expand Down
46 changes: 23 additions & 23 deletions oak_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ use crate::{
};

#[allow(dead_code)]
pub struct SessionConfig<'a> {
pub attestation_provider_config: AttestationProviderConfig<'a>,
pub handshaker_config: HandshakerConfig<'a>,
pub encryptor_config: EncryptorConfig<'a>,
pub struct SessionConfig {
pub attestation_provider_config: AttestationProviderConfig,
pub handshaker_config: HandshakerConfig,
pub encryptor_config: EncryptorConfig,
}

impl<'a> SessionConfig<'a> {
impl SessionConfig {
pub fn builder(
attestation_type: AttestationType,
handshake_type: HandshakeType,
) -> SessionConfigBuilder<'a> {
) -> SessionConfigBuilder {
SessionConfigBuilder::new(attestation_type, handshake_type)
}
}

pub struct SessionConfigBuilder<'a> {
config: SessionConfig<'a>,
pub struct SessionConfigBuilder {
config: SessionConfig,
}

impl<'a> SessionConfigBuilder<'a> {
impl SessionConfigBuilder {
fn new(attestation_type: AttestationType, handshake_type: HandshakeType) -> Self {
let attestation_provider_config = AttestationProviderConfig {
attestation_type,
Expand All @@ -62,28 +62,28 @@ impl<'a> SessionConfigBuilder<'a> {
};

let encryptor_config = EncryptorConfig {
encryptor_provider: &|sk| {
encryptor_provider: Box::new(|sk| {
<SessionKeys as TryInto<OrderedChannelEncryptor>>::try_into(sk)
.map(|v| Box::new(v) as Box<dyn Encryptor>)
},
}),
};

let config =
SessionConfig { attestation_provider_config, handshaker_config, encryptor_config };
Self { config }
}

pub fn add_self_attester(mut self, attester: &'a dyn Attester) -> Self {
pub fn add_self_attester(mut self, attester: Box<dyn Attester>) -> Self {
self.config.attestation_provider_config.self_attesters.push(attester);
self
}

pub fn add_peer_verifier(mut self, verifier: &'a dyn AttestationVerifier) -> Self {
pub fn add_peer_verifier(mut self, verifier: Box<dyn AttestationVerifier>) -> Self {
self.config.attestation_provider_config.peer_verifiers.push(verifier);
self
}

pub fn set_self_private_key(mut self, private_key: &'a dyn IdentityKeyHandle) -> Self {
pub fn set_self_private_key(mut self, private_key: Box<dyn IdentityKeyHandle>) -> Self {
if self.config.handshaker_config.self_static_private_key.is_none() {
self.config.handshaker_config.self_static_private_key = Some(private_key);
} else {
Expand All @@ -103,36 +103,36 @@ impl<'a> SessionConfigBuilder<'a> {

pub fn set_encryption_provider(
mut self,
encryptor_provider: &'a dyn Fn(SessionKeys) -> Result<Box<dyn Encryptor>, Error>,
encryptor_provider: Box<dyn Fn(SessionKeys) -> Result<Box<dyn Encryptor>, Error>>,
) -> Self {
self.config.encryptor_config.encryptor_provider = encryptor_provider;
self
}

pub fn build(self) -> SessionConfig<'a> {
pub fn build(self) -> SessionConfig {
self.config
}
}

#[allow(dead_code)]
pub struct AttestationProviderConfig<'a> {
pub struct AttestationProviderConfig {
pub attestation_type: AttestationType,
pub self_attesters: Vec<&'a dyn Attester>,
pub peer_verifiers: Vec<&'a dyn AttestationVerifier>,
pub self_attesters: Vec<Box<dyn Attester>>,
pub peer_verifiers: Vec<Box<dyn AttestationVerifier>>,
}

#[allow(dead_code)]
pub struct HandshakerConfig<'a> {
pub struct HandshakerConfig {
pub handshake_type: HandshakeType,
// Used for authentication schemes where a static public key is pre-shared with the responder.
pub self_static_private_key: Option<&'a dyn IdentityKeyHandle>,
pub self_static_private_key: Option<Box<dyn IdentityKeyHandle>>,
// Used for authentication schemes where a responder's static public key is pre-shared with
// the initiator.
pub peer_static_public_key: Option<Vec<u8>>,
// Public key that can be used to bind the attestation obtained from the peer to the handshake.
pub peer_attestation_binding_public_key: Option<Vec<u8>>,
}

pub struct EncryptorConfig<'a> {
pub encryptor_provider: &'a dyn Fn(SessionKeys) -> Result<Box<dyn Encryptor>, Error>,
pub struct EncryptorConfig {
pub encryptor_provider: Box<dyn Fn(SessionKeys) -> Result<Box<dyn Encryptor>, Error>>,
}
19 changes: 11 additions & 8 deletions oak_session/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! This module provides an implementation of the Handshaker, which
//! handles cryptographic handshake and secure session creation.

use alloc::boxed::Box;
use core::convert::TryInto;

use anyhow::{anyhow, Context};
Expand Down Expand Up @@ -58,7 +59,7 @@ pub struct ClientHandshaker {
}

impl ClientHandshaker {
pub fn create(handshaker_config: &HandshakerConfig) -> anyhow::Result<Self> {
pub fn create(handshaker_config: HandshakerConfig) -> anyhow::Result<Self> {
let handshake_type = handshaker_config.handshake_type;
let peer_static_public_key = handshaker_config.peer_static_public_key.clone();
Ok(Self {
Expand Down Expand Up @@ -132,15 +133,15 @@ impl ProtocolEngine<HandshakeResponse, HandshakeRequest> for ClientHandshaker {
/// Server-side Handshaker that responds to the crypto handshake request from
/// the client.
#[allow(dead_code)]
pub struct ServerHandshaker<'a> {
pub struct ServerHandshaker {
handshake_type: HandshakeType,
self_identity_key: Option<&'a dyn IdentityKeyHandle>,
self_identity_key: Option<Box<dyn IdentityKeyHandle>>,
handshake_response: Option<HandshakeResponse>,
handshake_result: Option<SessionKeys>,
}

impl<'a> ServerHandshaker<'a> {
pub fn new(handshaker_config: &HandshakerConfig<'a>) -> Self {
impl ServerHandshaker {
pub fn new(handshaker_config: HandshakerConfig) -> Self {
Self {
handshake_type: handshaker_config.handshake_type,
self_identity_key: handshaker_config.self_static_private_key,
Expand All @@ -150,13 +151,13 @@ impl<'a> ServerHandshaker<'a> {
}
}

impl<'a> Handshaker for ServerHandshaker<'a> {
impl Handshaker for ServerHandshaker {
fn derive_session_keys(&mut self) -> Option<SessionKeys> {
self.handshake_result.take()
}
}

impl<'a> ProtocolEngine<HandshakeRequest, HandshakeResponse> for ServerHandshaker<'a> {
impl ProtocolEngine<HandshakeRequest, HandshakeResponse> for ServerHandshaker {
fn get_outgoing_message(&mut self) -> anyhow::Result<Option<HandshakeResponse>> {
Ok(self.handshake_response.take())
}
Expand All @@ -178,7 +179,9 @@ impl<'a> ProtocolEngine<HandshakeRequest, HandshakeResponse> for ServerHandshake
HandshakeType::NoiseKN => core::unimplemented!(),
HandshakeType::NoiseNK => respond_nk(
self.self_identity_key
.context("handshaker_config missing the self private key")?,
.as_ref()
.context("handshaker_config missing the self private key")?
.as_ref(),
&in_data,
)
.map_err(|e| anyhow!("handshake response failed: {e:?}"))?,
Expand Down
40 changes: 20 additions & 20 deletions oak_session/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use oak_proto_rust::oak::session::v1::{
};

use crate::{
config::SessionConfig,
config::{EncryptorConfig, SessionConfig},
handshake::{ClientHandshaker, Handshaker, ServerHandshaker},
ProtocolEngine,
};
Expand Down Expand Up @@ -62,29 +62,29 @@ pub trait Session {
}

// Client-side secure attested session entrypoint.
pub struct ClientSession<'a> {
config: &'a SessionConfig<'a>,
pub struct ClientSession {
handshaker: ClientHandshaker,
// encryptor is initialized once the handshake is completed and the session becomes open
encryptor_config: EncryptorConfig,
encryptor: Option<Box<dyn Encryptor>>,
outgoing_requests: VecDeque<SessionRequest>,
incoming_responses: VecDeque<SessionResponse>,
}

impl<'a> ClientSession<'a> {
pub fn create(config: &'a SessionConfig<'a>) -> Result<Self, Error> {
impl ClientSession {
pub fn create(config: SessionConfig) -> Result<Self, Error> {
Ok(Self {
config,
handshaker: ClientHandshaker::create(&config.handshaker_config)
handshaker: ClientHandshaker::create(config.handshaker_config)
.context("couldn't create the client handshaker")?,
encryptor_config: config.encryptor_config,
encryptor: None,
outgoing_requests: VecDeque::new(),
incoming_responses: VecDeque::new(),
})
}
}

impl<'a> Session for ClientSession<'a> {
impl Session for ClientSession {
fn is_open(&self) -> bool {
self.encryptor.is_some()
}
Expand Down Expand Up @@ -128,7 +128,7 @@ impl<'a> Session for ClientSession<'a> {
}
}

impl<'a> ProtocolEngine<SessionResponse, SessionRequest> for ClientSession<'a> {
impl ProtocolEngine<SessionResponse, SessionRequest> for ClientSession {
fn get_outgoing_message(&mut self) -> anyhow::Result<Option<SessionRequest>> {
if self.is_open() {
return Ok(self.outgoing_requests.pop_front());
Expand Down Expand Up @@ -162,7 +162,7 @@ impl<'a> ProtocolEngine<SessionResponse, SessionRequest> for ClientSession<'a> {
))?;
if let Some(session_keys) = self.handshaker.derive_session_keys() {
self.encryptor = Some(
(self.config.encryptor_config.encryptor_provider)(session_keys)
(self.encryptor_config.encryptor_provider)(session_keys)
.context("couldn't create an encryptor from the session key")?,
)
}
Expand All @@ -174,28 +174,28 @@ impl<'a> ProtocolEngine<SessionResponse, SessionRequest> for ClientSession<'a> {
}

// Server-side secure attested session entrypoint.
pub struct ServerSession<'a> {
config: &'a SessionConfig<'a>,
handshaker: ServerHandshaker<'a>,
pub struct ServerSession {
handshaker: ServerHandshaker,
// encryptor is initialized once the handshake is completed and the session becomes open
encryptor_config: EncryptorConfig,
encryptor: Option<Box<dyn Encryptor>>,
outgoing_responses: VecDeque<SessionResponse>,
incoming_requests: VecDeque<SessionRequest>,
}

impl<'a> ServerSession<'a> {
pub fn new(config: &'a SessionConfig<'a>) -> Self {
impl ServerSession {
pub fn new(config: SessionConfig) -> Self {
Self {
config,
handshaker: ServerHandshaker::new(&config.handshaker_config),
handshaker: ServerHandshaker::new(config.handshaker_config),
encryptor_config: config.encryptor_config,
encryptor: None,
outgoing_responses: VecDeque::new(),
incoming_requests: VecDeque::new(),
}
}
}

impl<'a> Session for ServerSession<'a> {
impl Session for ServerSession {
fn is_open(&self) -> bool {
self.encryptor.is_some()
}
Expand Down Expand Up @@ -239,7 +239,7 @@ impl<'a> Session for ServerSession<'a> {
}
}

impl<'a> ProtocolEngine<SessionRequest, SessionResponse> for ServerSession<'a> {
impl ProtocolEngine<SessionRequest, SessionResponse> for ServerSession {
fn get_outgoing_message(&mut self) -> anyhow::Result<Option<SessionResponse>> {
Ok(self.outgoing_responses.pop_front())
}
Expand Down Expand Up @@ -272,7 +272,7 @@ impl<'a> ProtocolEngine<SessionRequest, SessionResponse> for ServerSession<'a> {
}
if let Some(session_keys) = self.handshaker.derive_session_keys() {
self.encryptor = Some(
(self.config.encryptor_config.encryptor_provider)(session_keys)
(self.encryptor_config.encryptor_provider)(session_keys)
.context("couldn't create an encryptor from the session key")?,
)
}
Expand Down
Loading

0 comments on commit c81f2b9

Please sign in to comment.