diff --git a/zio-http/src/main/scala/zhttp/socket/WebSocketFrame.scala b/zio-http/src/main/scala/zhttp/socket/WebSocketFrame.scala index 04c30ad3cc..505d002240 100644 --- a/zio-http/src/main/scala/zhttp/socket/WebSocketFrame.scala +++ b/zio-http/src/main/scala/zhttp/socket/WebSocketFrame.scala @@ -1,7 +1,8 @@ package zhttp.socket -import io.netty.buffer.ByteBuf +import io.netty.buffer.{ByteBuf, ByteBufUtil, Unpooled} import io.netty.handler.codec.http.websocketx.{WebSocketFrame => JWebSocketFrame, _} +import zio.Chunk sealed trait WebSocketFrame extends Product with Serializable { self => final def toWebSocketFrame: JWebSocketFrame = WebSocketFrame.toJFrame(self) @@ -10,13 +11,13 @@ sealed trait WebSocketFrame extends Product with Serializable { self => object WebSocketFrame { - case class Binary(buffer: ByteBuf) extends WebSocketFrame { override val isFinal: Boolean = true } + case class Binary(bytes: Chunk[Byte]) extends WebSocketFrame { override val isFinal: Boolean = true } object Binary { - def apply(buffer: ByteBuf, isFinal: Boolean): Binary = { + def apply(bytes: Chunk[Byte], isFinal: Boolean): Binary = { val arg = isFinal - new Binary(buffer) { override val isFinal: Boolean = arg } + new Binary(bytes) { override val isFinal: Boolean = arg } } - def unapply(frame: WebSocketFrame.Binary): Option[ByteBuf] = Some(frame.buffer) + def unapply(frame: WebSocketFrame.Binary): Option[Chunk[Byte]] = Some(frame.bytes) } case class Text(text: String) extends WebSocketFrame { override val isFinal: Boolean = true } @@ -48,7 +49,7 @@ object WebSocketFrame { def close(status: Int, reason: Option[String] = None): WebSocketFrame = WebSocketFrame.Close(status, reason) - def binary(chunks: ByteBuf): WebSocketFrame = WebSocketFrame.Binary(chunks) + def binary(bytes: Chunk[Byte]): WebSocketFrame = WebSocketFrame.Binary(bytes) def ping: WebSocketFrame = WebSocketFrame.Ping @@ -63,7 +64,7 @@ object WebSocketFrame { case _: PongWebSocketFrame => Option(Pong) case m: BinaryWebSocketFrame => - Option(Binary((m.content()), m.isFinalFragment)) + Option(Binary(Chunk.fromArray(ByteBufUtil.getBytes(m.content())), m.isFinalFragment)) case m: TextWebSocketFrame => Option(Text(m.text(), m.isFinalFragment)) case m: CloseWebSocketFrame => @@ -77,7 +78,7 @@ object WebSocketFrame { def toJFrame(frame: WebSocketFrame): JWebSocketFrame = frame match { case b: Binary => - new BinaryWebSocketFrame(b.isFinal, 0, b.buffer) + new BinaryWebSocketFrame(b.isFinal, 0, Unpooled.wrappedBuffer(b.bytes.toArray)) case t: Text => new TextWebSocketFrame(t.isFinal, 0, t.text) case Close(status, Some(text)) => diff --git a/zio-http/src/test/scala/zhttp/service/WebSocketServerSpec.scala b/zio-http/src/test/scala/zhttp/service/WebSocketServerSpec.scala index 057908b450..dd56aae7ed 100644 --- a/zio-http/src/test/scala/zhttp/service/WebSocketServerSpec.scala +++ b/zio-http/src/test/scala/zhttp/service/WebSocketServerSpec.scala @@ -5,6 +5,7 @@ import zhttp.internal.{DynamicServer, HttpRunnableSpec} import zhttp.service.server._ import zhttp.socket.{Socket, WebSocketFrame} import zio._ +import zio.stream.ZStream import zio.test.Assertion.equalTo import zio.test.TestAspect.timeout import zio.test._ @@ -16,10 +17,10 @@ object WebSocketServerSpec extends HttpRunnableSpec { private val app = serve { DynamicServer.app } override def spec = suite("Server") { - app.as(List(websocketSpec)).useNow + app.as(List(websocketServerSpec, websocketFrameSpec)).useNow }.provideCustomLayerShared(env) @@ timeout(10 seconds) - def websocketSpec = suite("WebSocket Server") { + def websocketServerSpec = suite("WebSocketServer") { suite("connections") { test("Multiple websocket upgrades") { val app = Socket.succeed(WebSocketFrame.text("BAR")).toHttp.deployWS @@ -31,4 +32,17 @@ object WebSocketServerSpec extends HttpRunnableSpec { } } } + + def websocketFrameSpec = suite("WebSocketFrameSpec") { + test("binary") { + val socket = Socket.collect[WebSocketFrame] { case WebSocketFrame.Binary(buffer) => + ZStream.succeed(WebSocketFrame.Binary(buffer)) + } + + val app = socket.toHttp.deployWS + val open = Socket.succeed(WebSocketFrame.binary(Chunk.fromArray("Hello, World".getBytes))) + + assertM(app(socket.toSocketApp.onOpen(open)).map(_.status))(equalTo(Status.SWITCHING_PROTOCOLS)) + } + } }