-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaxum.rs
151 lines (136 loc) · 3.92 KB
/
axum.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use axum::{
extract::{
ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
Json, Query, State,
},
response::{
sse::{Event as SseEvent, Sse},
Html, IntoResponse,
},
routing::{get, post},
Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::net::SocketAddr;
use tagged_channels::TaggedChannels;
/// Tag attached to a subscriber channel when it is created.
///
/// `send` uses these tags with `TaggedChannels::send_by_tag` to address a
/// subset of the open connections.
#[derive(Clone, PartialEq, Eq, Hash)]
enum ChannelTag {
    /// The connection was opened with this `user_id`.
    UserId(i32),
    /// The connection was opened with `is_admin` set.
    IsAdmin,
}
/// Event accepted by `POST /send` and pushed to subscribers as JSON.
///
/// Internally tagged: the JSON object carries a `"_type"` field whose value
/// is the variant name (`"User"`, `"Admin"` or `"Broadcast"`), next to the
/// fields of the wrapped struct.
#[derive(Serialize, Deserialize)]
#[serde(tag = "_type")]
enum EventMessage {
    /// Routed to channels tagged `ChannelTag::UserId(user_id)`.
    User(UserMessage),
    /// Routed to channels tagged `ChannelTag::IsAdmin`.
    Admin(SimpleMessage),
    /// Passed to `TaggedChannels::broadcast`.
    Broadcast(SimpleMessage),
}

/// Body of the `EventMessage::User` variant.
#[derive(Serialize, Deserialize)]
struct UserMessage {
    /// Id of the user whose channels should receive this message.
    user_id: i32,
    /// Message text.
    message: String,
}

/// Body of the `EventMessage::Admin` and `EventMessage::Broadcast` variants.
#[derive(Serialize, Deserialize)]
struct SimpleMessage {
    /// Message text.
    message: String,
}
#[derive(Deserialize)]
struct ConnectionParams {
user_id: Option<i32>,
is_admin: bool,
}
/// Entry point: serves the demo on `0.0.0.0:3000`.
///
/// One `TaggedChannels` value is installed as the router state. The publishing
/// endpoint and both subscription endpoints receive it through `State`.
#[tokio::main]
async fn main() {
    let channels = TaggedChannels::new();

    let app = Router::new()
        // HTML pages.
        .route("/", get(index))
        .route("/sse", get(sse_ui))
        .route("/ws", get(ws_ui))
        // Publishing endpoint.
        .route("/send", post(send))
        // Subscription endpoints.
        .route("/sse-events", get(events))
        .route("/ws-events", get(ws_events))
        .with_state(channels);

    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    let service = app.into_make_service_with_connect_info::<SocketAddr>();
    axum::Server::bind(&addr).serve(service).await.unwrap();
}
/// `GET /`: a bare list of links, one per example UI.
async fn index() -> Html<String> {
    const EXAMPLES: [(&str, &str); 2] = [("WebSocket", "/ws"), ("SSE", "/sse")];

    let mut page = String::new();
    for (name, url) in EXAMPLES {
        page += &format!(r#"<li><a href="{url}">{name} example</a></li>"#);
    }
    Html(page)
}
/// Returns the embedded `ui.html` template with every `{{example}}`
/// placeholder replaced by `example`.
fn render_ui(example: &str) -> Html<String> {
    let template: &str = include_str!("ui.html");
    Html(template.replace("{{example}}", example))
}

/// `GET /sse`: the UI page with `{{example}}` set to `"sse"`.
async fn sse_ui() -> Html<String> {
    render_ui("sse")
}

/// `GET /ws`: the UI page with `{{example}}` set to `"ws"`.
async fn ws_ui() -> Html<String> {
    render_ui("ws")
}
async fn send(
State(channels): State<TaggedChannels<EventMessage, ChannelTag>>,
Json(message): Json<EventMessage>,
) {
use EventMessage::*;
match message {
User(data) => {
let tag = ChannelTag::UserId(data.user_id);
channels.send_by_tag(&tag, User(data)).await
}
Admin(data) => {
let tag = ChannelTag::IsAdmin;
channels.send_by_tag(&tag, Admin(data)).await
}
Broadcast(data) => channels.broadcast(Broadcast(data)).await,
}
}
/// `GET /sse-events`: streams events to a browser as Server-Sent Events.
///
/// The query string selects the tags the new channel is created with. The
/// channel itself is created when the stream is first polled. Each received
/// event is sent as the `data` of one SSE event holding its JSON encoding;
/// events that fail to serialize are skipped.
async fn events(
    Query(params): Query<ConnectionParams>,
    State(mut channels): State<TaggedChannels<EventMessage, ChannelTag>>,
) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
    let tags = params.as_tags();
    Sse::new(async_stream::stream! {
        let mut receiver = channels.create_channel(tags);
        loop {
            let Some(event) = receiver.recv().await else { break };
            if let Ok(payload) = serde_json::to_string(&event) {
                yield Ok(SseEvent::default().data(payload));
            }
        }
    })
}
/// `GET /ws-events`: upgrades the request to a WebSocket and hands the socket
/// to `handle_socket`, subscribed with the tags from the query string.
async fn ws_events(
    ws: WebSocketUpgrade,
    Query(params): Query<ConnectionParams>,
    State(channels): State<TaggedChannels<EventMessage, ChannelTag>>,
) -> impl IntoResponse {
    let tags = params.as_tags();
    ws.on_upgrade(move |socket| handle_socket(socket, channels, tags))
}
async fn handle_socket(
mut socket: WebSocket,
mut channels: TaggedChannels<EventMessage, ChannelTag>,
tags: Vec<ChannelTag>,
) {
let mut rx = channels.create_channel(tags);
while let Some(msg) = rx.recv().await {
let Ok(json) = serde_json::to_string(&msg) else { continue };
if socket.send(WsMessage::Text(json)).await.is_err() {
break;
}
}
}
impl ConnectionParams {
    /// Converts the query parameters into the tags for a new channel.
    ///
    /// Yields `UserId(id)` when `user_id` is present, then `IsAdmin` when
    /// `is_admin` is set, in that order. Neither one gives an empty list.
    fn as_tags(&self) -> Vec<ChannelTag> {
        let user_tag = self.user_id.map(ChannelTag::UserId);
        let admin_tag = self.is_admin.then_some(ChannelTag::IsAdmin);
        user_tag.into_iter().chain(admin_tag).collect()
    }
}