Skip to content

Commit

Permalink
Fix streamWriteOffset's in QuicStreamAsyncTransport
Browse files Browse the repository at this point in the history
Summary: Now we track the write offset from QSAT's PoV, rather than querying the QuicSocket for QUIC's perspective.  Previously, the write callbacks were firing too early, leading to problems.

Reviewed By: mjoras

Differential Revision: D60305967

fbshipit-source-id: ea0470e1d2654848164f4edcfbd5a72a8f33d064
  • Loading branch information
afrind authored and facebook-github-bot committed Aug 1, 2024
1 parent b70f3d2 commit e1675e2
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 94 deletions.
83 changes: 27 additions & 56 deletions quic/api/QuicStreamAsyncTransport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ void QuicStreamAsyncTransport::setStreamId(quic::StreamId id) {
for (auto& p : writeCallbacks_) {
p.first += *streamWriteOffset;
}
streamWriteOffset_ += *streamWriteOffset;
sock_->notifyPendingWriteOnStream(*id_, this);
}
}
Expand Down Expand Up @@ -95,24 +96,14 @@ folly::AsyncTransport::ReadCallback* QuicStreamAsyncTransport::getReadCallback()
}

void QuicStreamAsyncTransport::addWriteCallback(
AsyncTransport::WriteCallback* callback,
size_t offset) {
AsyncTransport::WriteCallback* callback) {
size_t size = writeBuf_.chainLength();
writeCallbacks_.emplace_back(offset + size, callback);
writeCallbacks_.emplace_back(streamWriteOffset_ + size, callback);
if (id_) {
sock_->notifyPendingWriteOnStream(*id_, this);
}
}

void QuicStreamAsyncTransport::handleWriteOffsetError(
AsyncTransport::WriteCallback* callback,
LocalErrorCode error) {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic write error: ", toString(error)));
callback->writeErr(0, ex);
}

bool QuicStreamAsyncTransport::handleWriteStateError(
AsyncTransport::WriteCallback* callback) {
if (writeEOF_ != EOFState::NOT_SEEN) {
Expand All @@ -134,14 +125,6 @@ bool QuicStreamAsyncTransport::handleWriteStateError(
}
}

folly::Expected<size_t, LocalErrorCode>
QuicStreamAsyncTransport::getStreamWriteOffset() const {
if (!id_) {
return 0;
}
return sock_->getStreamWriteOffset(*id_);
}

void QuicStreamAsyncTransport::write(
AsyncTransport::WriteCallback* callback,
const void* buf,
Expand All @@ -150,13 +133,8 @@ void QuicStreamAsyncTransport::write(
if (handleWriteStateError(callback)) {
return;
}
auto streamWriteOffset = getStreamWriteOffset();
if (streamWriteOffset.hasError()) {
handleWriteOffsetError(callback, streamWriteOffset.error());
return;
}
writeBuf_.append(folly::IOBuf::wrapBuffer(buf, bytes));
addWriteCallback(callback, *streamWriteOffset);
addWriteCallback(callback);
}

void QuicStreamAsyncTransport::writev(
Expand All @@ -167,15 +145,10 @@ void QuicStreamAsyncTransport::writev(
if (handleWriteStateError(callback)) {
return;
}
auto streamWriteOffset = getStreamWriteOffset();
if (streamWriteOffset.hasError()) {
handleWriteOffsetError(callback, streamWriteOffset.error());
return;
}
for (size_t i = 0; i < count; i++) {
writeBuf_.append(folly::IOBuf::wrapBuffer(vec[i].iov_base, vec[i].iov_len));
}
addWriteCallback(callback, *streamWriteOffset);
addWriteCallback(callback);
}

void QuicStreamAsyncTransport::writeChain(
Expand All @@ -185,13 +158,8 @@ void QuicStreamAsyncTransport::writeChain(
if (handleWriteStateError(callback)) {
return;
}
auto streamWriteOffset = getStreamWriteOffset();
if (streamWriteOffset.hasError()) {
handleWriteOffsetError(callback, streamWriteOffset.error());
return;
}
writeBuf_.append(std::move(buf));
addWriteCallback(callback, *streamWriteOffset);
addWriteCallback(callback);
}

void QuicStreamAsyncTransport::close() {
Expand Down Expand Up @@ -323,12 +291,11 @@ bool QuicStreamAsyncTransport::isEorTrackingEnabled() const {
void QuicStreamAsyncTransport::setEorTracking(bool /*track*/) {}

size_t QuicStreamAsyncTransport::getAppBytesWritten() const {
auto res = getStreamWriteOffset();
// TODO: track written bytes to have it available after QUIC stream closure
return res.hasError() ? 0 : res.value();
return streamWriteOffset_ + writeBuf_.chainLength();
}

size_t QuicStreamAsyncTransport::getRawBytesWritten() const {
// TOOD: should this include QUIC framing overhead?
return getAppBytesWritten();
}

Expand Down Expand Up @@ -438,23 +405,14 @@ void QuicStreamAsyncTransport::handleRead() {
}

void QuicStreamAsyncTransport::send(uint64_t maxToSend) {
VLOG(4) << __func__ << " " << maxToSend;
CHECK(id_);
// overkill until there are delivery cbs
folly::DelayedDestruction::DestructorGuard dg(this);
uint64_t toSend =
std::min(maxToSend, folly::to<uint64_t>(writeBuf_.chainLength()));
auto streamWriteOffset = sock_->getStreamWriteOffset(*id_);
if (streamWriteOffset.hasError()) {
// handle error
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>(
"Quic write error: ", toString(streamWriteOffset.error())));
failWrites(ex);
return;
}

uint64_t sentOffset = *streamWriteOffset + toSend;
uint64_t sentOffset = streamWriteOffset_ + toSend;
bool writeEOF =
(writeEOF_ == EOFState::QUEUED && writeBuf_.chainLength() == toSend);
auto res = sock_->writeChain(
Expand All @@ -472,16 +430,27 @@ void QuicStreamAsyncTransport::send(uint64_t maxToSend) {
if (writeEOF) {
writeEOF_ = EOFState::DELIVERED;
} else if (writeBuf_.chainLength()) {
sock_->notifyPendingWriteOnStream(*id_, this);
VLOG(4) << __func__ << " buffered data, requesting callback";
auto res2 = sock_->notifyPendingWriteOnStream(*id_, this);
if (!res2) {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic write error: ", toString(res2.error())));
failWrites(ex);
return;
}
}
// not actually sent. Mirrors AsyncSocket and invokes when data is in
// transport buffers
invokeWriteCallbacks(sentOffset);
streamWriteOffset_ = sentOffset;
invokeWriteCallbacks();
}

void QuicStreamAsyncTransport::invokeWriteCallbacks(size_t sentOffset) {
void QuicStreamAsyncTransport::invokeWriteCallbacks() {
VLOG(4) << __func__ << " " << streamWriteOffset_;
while (!writeCallbacks_.empty() &&
writeCallbacks_.front().first <= sentOffset) {
writeCallbacks_.front().first <= streamWriteOffset_) {
VLOG(4) << __func__ << " " << writeCallbacks_.front().first;
auto wcb = writeCallbacks_.front().second;
writeCallbacks_.pop_front();
wcb->writeSuccess();
Expand All @@ -493,13 +462,15 @@ void QuicStreamAsyncTransport::invokeWriteCallbacks(size_t sentOffset) {

void QuicStreamAsyncTransport::failWrites(
const folly::AsyncSocketException& ex) {
VLOG(4) << __func__;
while (!writeCallbacks_.empty()) {
auto& front = writeCallbacks_.front();
auto wcb = front.second;
writeCallbacks_.pop_front();
// TODO: track bytesWritten, when buffer was split it may not be 0
wcb->writeErr(0, ex);
}
writeEOF_ = EOFState::ERROR;
}

void QuicStreamAsyncTransport::onStreamWriteReady(
Expand Down
11 changes: 4 additions & 7 deletions quic/api/QuicStreamAsyncTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,20 @@ class QuicStreamAsyncTransport : public folly::AsyncTransport,
void runLoopCallback() noexcept override;

// Utils
void addWriteCallback(AsyncTransport::WriteCallback* callback, size_t offset);
void handleWriteOffsetError(
AsyncTransport::WriteCallback* callback,
LocalErrorCode error);
void addWriteCallback(AsyncTransport::WriteCallback* callback);
bool handleWriteStateError(AsyncTransport::WriteCallback* callback);
void handleRead();
void send(uint64_t maxToSend);
folly::Expected<size_t, LocalErrorCode> getStreamWriteOffset() const;
void invokeWriteCallbacks(size_t sentOffset);
void invokeWriteCallbacks();
void failWrites(const folly::AsyncSocketException& ex);
void closeNowImpl(folly::AsyncSocketException&& ex);

enum class CloseState { OPEN, CLOSING, CLOSED };
CloseState state_{CloseState::OPEN};
std::shared_ptr<quic::QuicSocket> sock_;
Optional<quic::StreamId> id_;
enum class EOFState { NOT_SEEN, QUEUED, DELIVERED };
uint64_t streamWriteOffset_{0};
enum class EOFState { NOT_SEEN, QUEUED, DELIVERED, ERROR };
EOFState readEOF_{EOFState::NOT_SEEN};
EOFState writeEOF_{EOFState::NOT_SEEN};
AsyncTransport::ReadCallback* readCb_{nullptr};
Expand Down
95 changes: 64 additions & 31 deletions quic/api/test/QuicStreamAsyncTransportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class QuicStreamAsyncTransportTest : public Test {
folly::test::MockReadCallback readCb;
QuicStreamAsyncTransport::UniquePtr transport;
std::array<uint8_t, 1024> buf;
bool echoFirstReadOnly{false};
uint8_t serverDone{2}; // need to finish reads & writes
};

Expand Down Expand Up @@ -83,52 +84,52 @@ class QuicStreamAsyncTransportTest : public Test {
serverAddr_ = server_->getAddress();
}

void serverExpectNewBidiStreamFromClient() {
void serverExpectNewBidiStreamFromClient(bool echoFirstReadOnly = true) {
EXPECT_CALL(serverConnectionCB_, onNewBidirectionalStream(_))
.WillOnce(Invoke([this](StreamId id) {
.WillOnce(Invoke([this, echoFirstReadOnly](StreamId id) {
auto stream = std::make_unique<Stream>();
stream->transport =
QuicStreamAsyncTransport::createWithExistingStream(
serverSocket_, id);
stream->echoFirstReadOnly = echoFirstReadOnly;

auto& transport = stream->transport;
auto& readCb = stream->readCb;
auto& writeCb = stream->writeCb;
auto& streamBuf = stream->buf;
auto& serverDone = stream->serverDone;
streams_[id] = std::move(stream);

EXPECT_CALL(readCb, readEOF_())
.WillOnce(Invoke([&transport, &serverDone] {
if (--serverDone == 0) {
transport->close();
.WillOnce(Invoke([stream = stream.get()] {
stream->transport->shutdownWrite();
if (--stream->serverDone == 0) {
stream->transport->close();
}
}));
EXPECT_CALL(readCb, isBufferMovable_()).WillRepeatedly(Return(false));
EXPECT_CALL(readCb, getReadBuffer(_, _))
.WillRepeatedly(Invoke([&streamBuf](void** buf, size_t* len) {
*buf = streamBuf.data();
*len = streamBuf.size();
}));
.WillRepeatedly(
Invoke([stream = stream.get()](void** buf, size_t* len) {
*buf = stream->buf.data();
*len = stream->buf.size();
}));
EXPECT_CALL(readCb, readDataAvailable_(_))
.WillRepeatedly(Invoke(
[&streamBuf, &serverDone, &writeCb, &transport](auto len) {
auto echoData = folly::IOBuf::copyBuffer("echo ");
echoData->appendChain(
folly::IOBuf::wrapBuffer(streamBuf.data(), len));
EXPECT_CALL(writeCb, writeSuccess_())
.WillOnce(Return())
.RetiresOnSaturation();
if (transport->good()) {
// Echo the first readDataAvailable_ only
transport->writeChain(&writeCb, std::move(echoData));
transport->shutdownWrite();
if (--serverDone == 0) {
transport->close();
}
.WillRepeatedly(Invoke([stream = stream.get()](auto len) {
auto echoData = folly::IOBuf::copyBuffer("echo ");
echoData->appendChain(
folly::IOBuf::wrapBuffer(stream->buf.data(), len));
EXPECT_CALL(stream->writeCb, writeSuccess_())
.WillOnce(Return())
.RetiresOnSaturation();
if (stream->transport->good()) {
stream->transport->writeChain(
&stream->writeCb, std::move(echoData));
if (stream->echoFirstReadOnly) {
stream->transport->shutdownWrite();
if (--stream->serverDone == 0) {
stream->transport->close();
}
}));
transport->setReadCB(&readCb);
}
}
}));
stream->transport->setReadCB(&readCb);
streams_[id] = std::move(stream);
}))
.RetiresOnSaturation();
}
Expand Down Expand Up @@ -337,4 +338,36 @@ TEST_F(QuicStreamAsyncTransportTest, closeNow) {
clientEvb_.loopOnce();
}

// Test to ensure that write callbacks are correctly scheduled even when
// write invoked from writeSuccess
TEST_F(QuicStreamAsyncTransportTest, WriteFromWriteSuccess) {
serverExpectNewBidiStreamFromClient(false);
auto clientStream = createClient();
folly::test::MockWriteCallback writeCb1, writeCb2;
bool wcb2Fire = false;
EXPECT_CALL(writeCb1, writeSuccess_()).WillOnce(Invoke([&] {
// write from writeSuccess, should get correct offset
clientStream->transport->writeChain(&writeCb2, buildRandomInputData(1000));
}));
EXPECT_CALL(writeCb2, writeSuccess_()).WillOnce(Invoke([&] {
wcb2Fire = true;
clientStream->transport->shutdownWrite();
}));
// fill fc window exactly,
clientStream->transport->writeChain(&writeCb1, buildRandomInputData(66560));
clientEvb_.loopOnce();
EXPECT_FALSE(wcb2Fire);
EXPECT_EQ(clientStream->transport->getAppBytesWritten(), 67560);

EXPECT_CALL(clientStream->readCb, readDataAvailable_(_))
.WillRepeatedly(Return());
bool done = false;
EXPECT_CALL(clientStream->readCb, readEOF_()).WillOnce(Assign(&done, true));
// eventually all gets flushed
while (!done) {
clientEvb_.loopOnce();
}
EXPECT_TRUE(wcb2Fire);
}

} // namespace quic::test

0 comments on commit e1675e2

Please sign in to comment.