Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat - rewrite WS connection #3

Merged
merged 5 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,53 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
```


### Connecting

Connect to Pusher:

```rust
client.connect().await?;
```

### Subscribing to Channels

Subscribe to a public channel:

```rust
client.subscribe("my-channel").await?;
```

Subscribe to a private channel:

```rust
client.subscribe("private-my-channel").await?;
```

Subscribe to a presence channel:

```rust
client.subscribe("presence-my-channel").await?;
```

### Unsubscribing from Channels

```rust
client.unsubscribe("my-channel").await?;
```

### Binding to Events

Bind to a specific event on a channel:

```rust
use pusher_rs::Event;

client.bind("my-event", |event: Event| {
println!("Received event: {:?}", event);
}).await?;
```

### Subscribing to a channel

```rust
Expand Down Expand Up @@ -133,6 +180,15 @@ The library supports four types of channels:

Each channel type has specific features and authentication requirements.

### Handling Connection State

Get the current connection state:

```rust
let state = client.get_connection_state().await;
println!("Current connection state: {:?}", state);
```

## Error Handling

The library uses a custom `PusherError` type for error handling. You can match on different error variants to handle specific error cases:
Expand All @@ -148,6 +204,14 @@ match client.connect().await {
}
```

### Disconnecting

When you're done, disconnect from Pusher:

```rust
client.disconnect().await?;
```

## Advanced Usage

### Custom Configuration
Expand Down Expand Up @@ -186,6 +250,25 @@ if let Some(channel) = channel_list.get("my-channel") {
}
```

### Presence Channels

When subscribing to a presence channel, you can provide user information:

```rust
use serde_json::json;

let channel = "presence-my-channel";
let socket_id = client.get_socket_id().await?;
let user_id = "user_123";
let user_info = json!({
"name": "John Doe",
"email": "john@example.com"
});

let auth = client.authenticate_presence_channel(&socket_id, channel, user_id, Some(&user_info))?;
client.subscribe_with_auth(channel, &auth).await?;
```

### Tests

Integration tests live under `tests/integration_tests`
Expand Down
20 changes: 20 additions & 0 deletions src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde_json::Value;
pub struct Event {
pub event: String,
pub channel: Option<String>,
#[serde(with = "json_string")]
pub data: Value,
}

Expand Down Expand Up @@ -113,6 +114,25 @@ impl SystemEvent {
}
}

mod json_string {
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;

pub fn serialize<S>(value: &Value, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
value.to_string().serialize(serializer)
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Value, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(D::Error::custom)
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down
109 changes: 65 additions & 44 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use cbc::{Decryptor, Encryptor};
use hmac::{Hmac, Mac};
use log::info;
use rand::Rng;
use serde_json::json;
use serde_json::{json, Value};
use sha2::Sha256;
use std::collections::HashMap;
use std::sync::Arc;
Expand All @@ -28,14 +28,15 @@ pub use config::PusherConfig;
pub use error::{PusherError, PusherResult};
pub use events::{Event, SystemEvent};

use websocket::WebSocketClient;
use websocket::{WebSocketClient, WebSocketCommand};

/// This struct provides methods for connecting to Pusher, subscribing to channels,
/// triggering events, and handling incoming events.
pub struct PusherClient {
config: PusherConfig,
auth: PusherAuth,
websocket: Option<WebSocketClient>,
// websocket: Option<WebSocketClient>,
websocket_command_tx: Option<mpsc::Sender<WebSocketCommand>>,
channels: Arc<RwLock<HashMap<String, Channel>>>,
event_handlers: Arc<RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>>,
state: Arc<RwLock<ConnectionState>>,
Expand Down Expand Up @@ -73,29 +74,44 @@ impl PusherClient {
let auth = PusherAuth::new(&config.app_key, &config.app_secret);
let (event_tx, event_rx) = mpsc::channel(100);
let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
let event_handlers = Arc::new(RwLock::new(HashMap::new()));
let encrypted_channels = Arc::new(RwLock::new(HashMap::new()));
let event_handlers = Arc::new(RwLock::new(std::collections::HashMap::new()));
let encrypted_channels = Arc::new(RwLock::new(std::collections::HashMap::new()));

let client = Self {
config,
auth,
websocket: None,
channels: Arc::new(RwLock::new(HashMap::new())),
websocket_command_tx: None,
channels: Arc::new(RwLock::new(std::collections::HashMap::new())),
event_handlers: event_handlers.clone(),
state: state.clone(),
event_tx,
encrypted_channels,
};

// Spawn the event handling task
tokio::spawn(Self::handle_events(event_rx, event_handlers));

Ok(client)
}

async fn send(&self, message: String) -> PusherResult<()> {
if let Some(tx) = &self.websocket_command_tx {
tx.send(WebSocketCommand::Send(message))
.await
.map_err(|e| {
PusherError::WebSocketError(format!("Failed to send command: {}", e))
})?;
Ok(())
} else {
Err(PusherError::ConnectionError("Not connected".into()))
}
}

async fn handle_events(
mut event_rx: mpsc::Receiver<Event>,
event_handlers: Arc<
RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>,
RwLock<
std::collections::HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>,
>,
>,
) {
while let Some(event) = event_rx.recv().await {
Expand All @@ -115,18 +131,24 @@ impl PusherClient {
/// A `PusherResult` indicating success or failure.
pub async fn connect(&mut self) -> PusherResult<()> {
let url = self.get_websocket_url()?;
let mut websocket =
WebSocketClient::new(url.clone(), Arc::clone(&self.state), self.event_tx.clone());
let (command_tx, command_rx) = mpsc::channel(100);

let mut websocket = WebSocketClient::new(
url.clone(),
Arc::clone(&self.state),
self.event_tx.clone(),
command_rx,
);

log::info!("Connecting to Pusher using URL: {}", url);
websocket.connect().await?;
self.websocket = Some(websocket);

// Start the WebSocket event loop
let mut ws = self.websocket.take().unwrap();
tokio::spawn(async move {
ws.run().await;
websocket.run().await;
});

self.websocket_command_tx = Some(command_tx);

Ok(())
}

Expand All @@ -136,11 +158,12 @@ impl PusherClient {
///
/// A `PusherResult` indicating success or failure.
pub async fn disconnect(&mut self) -> PusherResult<()> {
if let Some(websocket) = &self.websocket {
websocket.close().await?;
if let Some(tx) = self.websocket_command_tx.take() {
tx.send(WebSocketCommand::Close).await.map_err(|e| {
PusherError::WebSocketError(format!("Failed to send close command: {}", e))
})?;
}
*self.state.write().await = ConnectionState::Disconnected;
self.websocket = None;
Ok(())
}

Expand All @@ -158,21 +181,17 @@ impl PusherClient {
let mut channels = self.channels.write().await;
channels.insert(channel_name.to_string(), channel);

if let Some(websocket) = &self.websocket {
let data = json!({
"event": "pusher:subscribe",
"data": {
"channel": channel_name
}
});
websocket.send(serde_json::to_string(&data)?).await?;
} else {
return Err(PusherError::ConnectionError("Not connected".into()));
}
let data = json!({
"event": "pusher:subscribe",
"data": {
"channel": channel_name
}
});

Ok(())
self.send(serde_json::to_string(&data)?).await
}


/// Subscribes to an encrypted channel.
///
/// # Arguments
Expand Down Expand Up @@ -208,6 +227,7 @@ impl PusherClient {
/// # Returns
///
/// A `PusherResult` indicating success or failure.
///
pub async fn unsubscribe(&mut self, channel_name: &str) -> PusherResult<()> {
{
let mut channels = self.channels.write().await;
Expand All @@ -219,19 +239,14 @@ impl PusherClient {
encrypted_channels.remove(channel_name);
}

if let Some(websocket) = &self.websocket {
let data = json!({
"event": "pusher:unsubscribe",
"data": {
"channel": channel_name
}
});
websocket.send(serde_json::to_string(&data)?).await?;
} else {
return Err(PusherError::ConnectionError("Not connected".into()));
}
let data = json!({
"event": "pusher:unsubscribe",
"data": {
"channel": channel_name
}
});

Ok(())
self.send(serde_json::to_string(&data)?).await
}

/// Triggers an event on a channel.
Expand All @@ -251,10 +266,14 @@ impl PusherClient {
self.config.cluster, self.config.app_id
);

// Validate that the data is valid JSON, but keep it as a string
serde_json::from_str::<serde_json::Value>(data)
.map_err(|e| PusherError::JsonError(e))?;

let body = json!({
"name": event,
"channel": channel,
"data": data
"data": data, // Keep data as a string
});
let path = format!("/apps/{}/events", self.config.app_id);
let auth_params = self.auth.authenticate_request("POST", &path, &body)?;
Expand Down Expand Up @@ -371,6 +390,7 @@ impl PusherClient {
/// # Returns
///
/// A `PusherResult` indicating success or failure.
///
pub async fn bind<F>(&self, event_name: &str, callback: F) -> PusherResult<()>
where
F: Fn(Event) + Send + Sync + 'static,
Expand Down Expand Up @@ -535,7 +555,8 @@ mod tests {

#[tokio::test]
async fn test_trigger_batch() {
let config = PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
let config =
PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
let client = PusherClient::new(config).unwrap();

let batch_events = vec![
Expand Down
Loading