diff --git a/core/src/main/scala/epollcat/internal/ch/EpollAsyncSocketChannel.scala b/core/src/main/scala/epollcat/internal/ch/EpollAsyncSocketChannel.scala index 901ac30..31f0049 100644 --- a/core/src/main/scala/epollcat/internal/ch/EpollAsyncSocketChannel.scala +++ b/core/src/main/scala/epollcat/internal/ch/EpollAsyncSocketChannel.scala @@ -113,10 +113,11 @@ final class EpollAsyncSocketChannel private ( unit: TimeUnit, attachment: A, handler: CompletionHandler[Integer, _ >: A] - ): Unit = - if (readReady) { - val position = dst.position() - val count = dst.remaining() + ): Unit = { + val position = dst.position() + val count = dst.remaining() + if (readReady && count > 0) { + val hasArray = dst.hasArray() val buf = if (hasArray) dst.array() else new Array[Byte](count) val offset = if (hasArray) dst.arrayOffset() + position else 0 @@ -151,12 +152,15 @@ final class EpollAsyncSocketChannel private ( } go(buf.at(offset), count, 0) + } else if (count == 0) { + handler.completed(0, attachment) } else { readCallback = () => { readCallback = null read(dst, timeout, unit, attachment, handler) } } + } @stub def connect(remote: SocketAddress): Future[Void] = ??? @@ -305,56 +309,58 @@ final class EpollAsyncSocketChannel private ( unit: TimeUnit, attachment: A, handler: CompletionHandler[Integer, _ >: A] - ): Unit = if (outputShutdown) - handler.failed(new ClosedChannelException, attachment) - else if (writeReady) { + ): Unit = { val position = src.position() val count = src.remaining() + if (outputShutdown) + handler.failed(new ClosedChannelException, attachment) + else if (writeReady && count > 0) { + val hasArray = src.hasArray() + val buf = + if (hasArray) src.array() + else { + val buf = new Array[Byte](count) + src.get(buf) + buf + } + val offset = if (hasArray) src.arrayOffset() + position else 0 - val hasArray = src.hasArray() - val buf = - if (hasArray) src.array() - else { - val buf = new Array[Byte](count) - src.get(buf) - buf + def completed(total: Int): Unit = { + src.position(position + total) + handler.completed(total, attachment) } - val offset = if (hasArray) src.arrayOffset() + position else 0 - def completed(total: Int): Unit = { - src.position(position + total) - handler.completed(total, attachment) - } + @tailrec + def go(buf: Ptr[Byte], count: Int, total: Int): Unit = { + val wrote = + if (LinktimeInfo.isLinux) + posix.sys.socket.send(fd, buf, count.toULong, socket.MSG_NOSIGNAL).toInt + else + posix.unistd.write(fd, buf, count.toULong) - @tailrec - def go(buf: Ptr[Byte], count: Int, total: Int): Unit = { - val wrote = - if (LinktimeInfo.isLinux) - posix.sys.socket.send(fd, buf, count.toULong, socket.MSG_NOSIGNAL).toInt - else - posix.unistd.write(fd, buf, count.toULong) - - if (wrote == -1) { - val e = errno.errno - if (e == posix.errno.EAGAIN || e == posix.errno.EWOULDBLOCK) { - writeReady = false - completed(total) - } else - handler.failed(new RuntimeException(s"write: $e"), attachment) - } else if (wrote < count) - go(buf + wrote.toLong, count - wrote, total + wrote) - else // wrote == count - completed(total + wrote) - } + if (wrote == -1) { + val e = errno.errno + if (e == posix.errno.EAGAIN || e == posix.errno.EWOULDBLOCK) { + writeReady = false + completed(total) + } else + handler.failed(new RuntimeException(s"write: $e"), attachment) + } else if (wrote < count) + go(buf + wrote.toLong, count - wrote, total + wrote) + else // wrote == count + completed(total + wrote) + } - go(buf.at(offset), count, 0) - } else { - writeCallback = () => { - writeCallback = null - write(src, timeout, unit, attachment, handler) + go(buf.at(offset), count, 0) + } else if (count == 0) { + handler.completed(0, attachment) + } else { + writeCallback = () => { + writeCallback = null + write(src, timeout, unit, attachment, handler) + } } } - def getLocalAddress(): SocketAddress = SocketHelpers.getLocalAddress(fd) @stub diff --git a/tests/shared/src/test/scala/epollcat/TcpSuite.scala b/tests/shared/src/test/scala/epollcat/TcpSuite.scala index f14b8c4..82aeb27 100644 --- a/tests/shared/src/test/scala/epollcat/TcpSuite.scala +++ b/tests/shared/src/test/scala/epollcat/TcpSuite.scala @@ -154,6 +154,42 @@ class TcpSuite extends EpollcatSuite { } } + test("write is no-op when position == limit") { + IOServerSocketChannel + .open + .evalTap(_.bind(new InetSocketAddress("localhost", 0))) + .evalMap(_.localAddress) + .use { addr => + IOSocketChannel.open.use { ch => + for { + _ <- ch.connect(addr) + bb <- IO(ByteBuffer.allocate(1)) + _ <- IO(bb.position(1)) + wrote <- ch.write(bb) + _ <- IO(assertEquals(wrote, 0)) + } yield () + } + } + } + + test("read is no-op when position == limit") { + IOServerSocketChannel + .open + .evalTap(_.bind(new InetSocketAddress("localhost", 0))) + .evalMap(_.localAddress) + .use { addr => + IOSocketChannel.open.use { ch => + for { + _ <- ch.connect(addr) + bb <- IO(ByteBuffer.allocate(1)) + _ <- IO(bb.position(1)) + read <- ch.read(bb) + _ <- IO(assertEquals(read, 0)) + } yield () + } + } + } + test("options") { IOSocketChannel.open.use { ch => ch.setOption(StandardSocketOptions.SO_REUSEADDR, java.lang.Boolean.TRUE) *>