Skip to content

Commit

Permalink
Use OakRequestWithSessionIdentifier in wasm client
Browse files Browse the repository at this point in the history
Change-Id: I81bedc3b806560d1172c818158a717b77f62a8d1
  • Loading branch information
jul-sh committed Sep 13, 2024
1 parent 03fbd19 commit cf34895
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 23 deletions.
2 changes: 1 addition & 1 deletion oak_proto_rust/generated/oak.session.v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub mod oak_request {
}
}
/// Wrapper around OakRequest that is used in cases where it is necessary to
/// identify a contigious session across several invocations/streams.
/// identify a contiguous session across multiple invocations/streams.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost_derive::Message)]
pub struct OakRequestWithSessionIdentifier {
Expand Down
68 changes: 47 additions & 21 deletions oak_session_wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//

use std::collections::VecDeque;

use oak_proto_rust::oak::session::v1::SessionResponse;
use js_sys::Math;
use oak_proto_rust::oak::session::v1::{
oak_request, oak_response, OakRequest, OakRequestWithSessionIdentifier, OakResponse,
PlaintextMessage,
};
use oak_session::{
attestation::AttestationType, config::SessionConfig, handshake::HandshakeType, ClientSession,
ProtocolEngine, Session,
Expand All @@ -27,6 +30,8 @@ use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct WasmClientSession {
inner: ClientSession,
// Used by the server to identify the client session
session_id: Vec<u8>,
}

#[wasm_bindgen]
Expand All @@ -42,7 +47,11 @@ impl WasmClientSession {
let config =
SessionConfig::builder(AttestationType::Unattested, HandshakeType::NoiseNN).build();
let inner = ClientSession::create(config).map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(WasmClientSession { inner })

// Generate a random 16-byte session ID
let session_id: Vec<u8> = (0..16).map(|_| (Math::random() * 256.0) as u8).collect();

Ok(WasmClientSession { inner, session_id })
}

/// Checks whether session is ready to send and receive encrypted messages.
Expand All @@ -60,7 +69,8 @@ impl WasmClientSession {
/// in multiple outgoing protocol messages being created.
#[wasm_bindgen]
pub fn write(&mut self, plaintext: &[u8]) -> Result<(), JsValue> {
self.inner.write(plaintext).map_err(|e| JsValue::from_str(&e.to_string()))
let plaintext_message = PlaintextMessage { plaintext: plaintext.to_vec() };
self.inner.write(&plaintext_message).map_err(|e| JsValue::from_str(&e.to_string()))
}

/// Attempts to find a message containing ciphertext in the queue of
Expand All @@ -70,43 +80,59 @@ impl WasmClientSession {
/// This function can be called multiple times in a row.
#[wasm_bindgen]
pub fn read(&mut self) -> Result<Option<Vec<u8>>, JsValue> {
self.inner.read().map_err(|e| JsValue::from_str(&e.to_string()))
self.inner
.read()
.map(|opt_msg: Option<PlaintextMessage>| opt_msg.map(|msg| msg.plaintext))
.map_err(|e| JsValue::from_str(&e.to_string()))
}

/// Puts a message received from the peer into the state-machine.
///
/// The message is a byte-encoded protobuf message of the type
/// `type.googleapis.com/oak.session.v1`. It may contain ciphertext,
/// attestation evidence, or a handshake step.
/// `type.googleapis.com/oak.session.v1.OakResponse`.
#[wasm_bindgen]
pub fn put_incoming_message(
&mut self,
incoming_message: &[u8],
) -> Result<PutIncomingMessageResult, JsValue> {
self.inner
.put_incoming_message(
&SessionResponse::decode(incoming_message)
.map_err(|e| JsValue::from_str(&e.to_string()))?,
)
.map(|opt_result| {
opt_result
.map(|_| PutIncomingMessageResult::Success)
.unwrap_or(PutIncomingMessageResult::NoIncomingMessageExpected)
})
.map_err(|e| JsValue::from_str(&e.to_string()))
let oak_response =
OakResponse::decode(incoming_message).map_err(|e| JsValue::from_str(&e.to_string()))?;

if let Some(oak_response::Request::SessionResponse(session_response)) = oak_response.request
{
self.inner
.put_incoming_message(&session_response)
.map(|opt_result| {
opt_result
.map(|_| PutIncomingMessageResult::Success)
.unwrap_or(PutIncomingMessageResult::NoIncomingMessageExpected)
})
.map_err(|e| JsValue::from_str(&e.to_string()))
} else {
Err(JsValue::from_str("Unexpected OakResponse type"))
}
}

/// Gets the next message that needs to be sent to the peer
/// from the state-machine.
///
/// The message is a byte-encoded protobuf message of the type
/// `type.googleapis.com/oak.session.v1`. It may contain ciphertext,
/// attestation evidence, or a handshake step.
/// `type.googleapis.com/oak.session.v1.OakRequestWithSessionIdentifier`.
#[wasm_bindgen]
pub fn get_outgoing_message(&mut self) -> Result<Option<Vec<u8>>, JsValue> {
self.inner
.get_outgoing_message()
.map(|opt_msg| opt_msg.map(|msg| msg.encode_to_vec()))
.map(|opt_msg| {
opt_msg.map(|session_request| {
let oak_request_with_id = OakRequestWithSessionIdentifier {
session_id: self.session_id.clone(),
request: Some(OakRequest {
request: Some(oak_request::Request::SessionRequest(session_request)),
}),
};
oak_request_with_id.encode_to_vec()
})
})
.map_err(|e| JsValue::from_str(&e.to_string()))
}
}
Expand Down
2 changes: 1 addition & 1 deletion proto/session/protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ message OakRequest {
}

// Wrapper around OakRequest that is used in cases where it is necessary to
// identify a contigious session across several invocations/streams.
// identify a contiguous session across multiple invocations/streams.
message OakRequestWithSessionIdentifier {
// Unique string to identify the session. This should be at least 128 bits of
// unique information.
Expand Down

0 comments on commit cf34895

Please sign in to comment.