Bail quickly if there is no data to read

This commit is contained in:
J. Nick Koston
2025-05-15 12:47:50 -05:00
parent e8e0e34702
commit 488dc40f2e
6 changed files with 113 additions and 49 deletions

View File

@@ -298,6 +298,18 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
return APIError::BAD_ARG;
}
// Only check for available data when starting a new frame read
if (rx_header_buf_len_ == 0) {
ssize_t available = socket_->available();
if (available == 0) {
return APIError::WOULD_BLOCK;
} else if (available == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket available failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
}
// read header
if (rx_header_buf_len_ < 3) {
// no header information yet
@@ -815,64 +827,84 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
return APIError::BAD_ARG;
}
// Only check for available data when starting a new frame read
if (rx_header_buf_pos_ == 0) {
ssize_t available = socket_->available();
if (available == 0) {
return APIError::WOULD_BLOCK;
} else if (available == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read_available failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
}
// read header
while (!rx_header_parsed_) {
uint8_t data;
// Reading one byte at a time is fastest in practice for ESP32 when
// there is no data on the wire (which is the common case).
// This results in faster failure detection compared to
// attempting to read multiple bytes at once.
ssize_t received = this->socket_->read(&data, 1);
if (received == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
return APIError::WOULD_BLOCK;
}
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
} else if (received == 0) {
state_ = State::FAILED;
HELPER_LOG("Connection closed");
return APIError::CONNECTION_CLOSED;
}
// Successfully read a byte
// Process byte according to current buffer position
if (rx_header_buf_pos_ == 0) { // Case 1: First byte (indicator byte)
if (data != 0x00) {
if (rx_header_buf_pos_ == 0) {
// Try to read the first 3 bytes at once (indicator + 2 initial bytes)
// We can safely read 3 bytes because the minimum header is indicator + 2 varint bytes
ssize_t received = socket_->read(rx_header_buf_, 3);
if (received == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
return APIError::WOULD_BLOCK;
}
state_ = State::FAILED;
HELPER_LOG("Bad indicator byte %u", data);
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
} else if (received == 0) {
state_ = State::FAILED;
HELPER_LOG("Connection closed");
return APIError::CONNECTION_CLOSED;
}
// Validate indicator byte
if (rx_header_buf_[0] != 0x00) {
state_ = State::FAILED;
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
return APIError::BAD_INDICATOR;
}
// We don't store the indicator byte, just increment position
rx_header_buf_pos_ = 1; // Set to 1 directly
continue; // Need more bytes before we can parse
}
// Check buffer overflow before storing
if (rx_header_buf_pos_ == 5) { // Case 2: Buffer would overflow (5 bytes is max allowed)
state_ = State::FAILED;
HELPER_LOG("Header buffer overflow");
return APIError::BAD_DATA_PACKET;
}
// Update our position based on how many bytes we got
rx_header_buf_pos_ = received;
// Store byte in buffer (adjust index to account for skipped indicator byte)
rx_header_buf_[rx_header_buf_pos_ - 1] = data;
// If we didn't get all 3 bytes, need more
if (rx_header_buf_pos_ < 3)
continue;
} else {
// For additional bytes (beyond the first 3), read one at a time
// Check buffer overflow before reading
if (rx_header_buf_pos_ >= 6) { // 6 bytes is max allowed (indicator + 5 varint bytes)
state_ = State::FAILED;
HELPER_LOG("Header buffer overflow");
return APIError::BAD_DATA_PACKET;
}
// Increment position after storing
rx_header_buf_pos_++;
// Read one byte at a time to avoid reading into message body
ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_pos_], 1);
if (received == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
return APIError::WOULD_BLOCK;
}
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
} else if (received == 0) {
state_ = State::FAILED;
HELPER_LOG("Connection closed");
return APIError::CONNECTION_CLOSED;
}
// Case 3: If we only have one varint byte, we need more
if (rx_header_buf_pos_ == 2) { // Have read indicator + 1 byte
continue; // Need more bytes before we can parse
// Increment position
rx_header_buf_pos_++;
}
// At this point, we have at least 3 bytes total:
// - Validated indicator byte (0x00) but not stored
// - Validated indicator byte (0x00) in the first position
// - At least 2 bytes in the buffer for the varints
// Buffer layout:
// First 1-3 bytes: Message size varint (variable length)
// Byte 0: Indicator byte (0x00)
// Bytes 1-3: Message size varint (variable length)
// - 2 bytes would only allow up to 16383, which is less than noise's 65535
// - 3 bytes allows up to 2097151, ensuring we support at least as much as noise
// Remaining 1-2 bytes: Message type varint (variable length)
@@ -880,7 +912,7 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
// we'll continue reading more bytes.
uint32_t consumed = 0;
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[0], rx_header_buf_pos_ - 1, &consumed);
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[1], rx_header_buf_pos_ - 1, &consumed);
if (!msg_size_varint.has_value()) {
// not enough data there yet
continue;
@@ -888,7 +920,8 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
rx_header_parsed_len_ = msg_size_varint->as_uint32();
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[consumed], rx_header_buf_pos_ - 1 - consumed, &consumed);
auto msg_type_varint =
ProtoVarInt::parse(&rx_header_buf_[1 + consumed], rx_header_buf_pos_ - 1 - consumed, &consumed);
if (!msg_type_varint.has_value()) {
// not enough data there yet
continue;

View File

@@ -219,14 +219,14 @@ class APIPlaintextFrameHelper : public APIFrameHelper {
protected:
APIError try_read_frame_(ParsedFrame *frame);
// Fixed-size header buffer for plaintext protocol:
// We only need space for the two varints since we validate the indicator byte separately.
// We need space for the indicator byte and the two varints.
// To match noise protocol's maximum message size (65535), we need:
// 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint
// 1 byte for indicator + 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint
//
// While varints could theoretically be up to 10 bytes each for 64-bit values,
// attempting to process messages with headers that large would likely crash the
// ESP32 due to memory constraints.
uint8_t rx_header_buf_[5]; // 5 bytes for varints (3 for size + 2 for type)
uint8_t rx_header_buf_[6]; // 1 byte for indicator + 5 bytes for varints (3 for size + 2 for type)
uint8_t rx_header_buf_pos_ = 0;
bool rx_header_parsed_ = false;
uint32_t rx_header_parsed_type_ = 0;

View File

@@ -101,6 +101,13 @@ class BSDSocketImpl : public Socket {
return ::readv(fd_, iov, iovcnt);
#endif
}
ssize_t available() override {
int bytes_available = 0;
int ret = ::ioctl(fd_, FIONREAD, &bytes_available);
if (ret == -1)
return -1;
return bytes_available;
}
ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); }
ssize_t send(void *buf, size_t len, int flags) { return ::send(fd_, buf, len, flags); }
ssize_t writev(const struct iovec *iov, int iovcnt) override {

View File

@@ -380,6 +380,22 @@ class LWIPRawImpl : public Socket {
}
return ret;
}
ssize_t available() override {
if (pcb_ == nullptr) {
errno = ECONNRESET;
return -1;
}
// Check if we have data in the receive buffer
if (rx_buf_ != nullptr) {
size_t pb_len = rx_buf_->len;
size_t pb_left = pb_len - rx_buf_offset_;
return pb_left;
}
// No data in buffer
return 0;
}
ssize_t internal_write(const void *buf, size_t len) {
if (pcb_ == nullptr) {
errno = ECONNRESET;

View File

@@ -81,6 +81,13 @@ class LwIPSocketImpl : public Socket {
int listen(int backlog) override { return lwip_listen(fd_, backlog); }
ssize_t read(void *buf, size_t len) override { return lwip_read(fd_, buf, len); }
ssize_t readv(const struct iovec *iov, int iovcnt) override { return lwip_readv(fd_, iov, iovcnt); }
ssize_t available() override {
int bytes_available = 0;
int ret = lwip_ioctl(fd_, FIONREAD, &bytes_available);
if (ret == -1)
return -1;
return bytes_available;
}
ssize_t write(const void *buf, size_t len) override { return lwip_write(fd_, buf, len); }
ssize_t send(void *buf, size_t len, int flags) { return lwip_send(fd_, buf, len, flags); }
ssize_t writev(const struct iovec *iov, int iovcnt) override { return lwip_writev(fd_, iov, iovcnt); }

View File

@@ -38,6 +38,7 @@ class Socket {
virtual ssize_t recvfrom(void *buf, size_t len, sockaddr *addr, socklen_t *addr_len) = 0;
#endif
virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
virtual ssize_t available() = 0; // Returns number of bytes available to read without blocking, or -1 on error
virtual ssize_t write(const void *buf, size_t len) = 0;
virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
virtual ssize_t sendto(const void *buf, size_t len, int flags, const struct sockaddr *to, socklen_t tolen) = 0;