batch state sends
This commit is contained in:
@@ -170,6 +170,12 @@ void APIConnection::loop() {
|
||||
this->deferred_message_queue_.process_queue();
|
||||
}
|
||||
|
||||
// Process deferred state batch if scheduled
|
||||
if (this->deferred_state_batch_.batch_scheduled &&
|
||||
App.get_loop_component_start_time() - this->deferred_state_batch_.batch_start_time >= STATE_BATCH_DELAY_MS) {
|
||||
this->process_state_batch_();
|
||||
}
|
||||
|
||||
if (!this->list_entities_iterator_.completed())
|
||||
this->list_entities_iterator_.advance();
|
||||
if (!this->initial_state_iterator_.completed() && this->list_entities_iterator_.completed())
|
||||
@@ -1650,7 +1656,13 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
|
||||
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint16_t message_type) {
|
||||
// If we're in batch mode, just capture the message type and return success
|
||||
if (this->batch_mode_) {
|
||||
this->captured_message_type_ = message_type;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!this->try_to_clear_buffer(message_type != 29)) { // SubscribeLogsResponse
|
||||
return false;
|
||||
}
|
||||
@@ -1684,6 +1696,139 @@ void APIConnection::on_fatal_error() {
|
||||
this->remove_ = true;
|
||||
}
|
||||
|
||||
void APIConnection::DeferredStateBatch::add_update(void *entity, send_message_t send_func) {
|
||||
// Check if we already have an update for this entity
|
||||
for (auto &update : updates) {
|
||||
if (update.entity == entity && update.send_func == send_func) {
|
||||
// Update timestamp to latest
|
||||
update.timestamp = App.get_loop_component_start_time();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Add new update
|
||||
updates.push_back({entity, send_func, App.get_loop_component_start_time()});
|
||||
}
|
||||
|
||||
void APIConnection::schedule_state_batch_() {
|
||||
if (!this->deferred_state_batch_.batch_scheduled) {
|
||||
this->deferred_state_batch_.batch_scheduled = true;
|
||||
this->deferred_state_batch_.batch_start_time = App.get_loop_component_start_time();
|
||||
}
|
||||
}
|
||||
|
||||
void APIConnection::process_state_batch_() {
|
||||
if (this->deferred_state_batch_.empty()) {
|
||||
this->deferred_state_batch_.batch_scheduled = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to clear buffer first
|
||||
if (!this->helper_->can_write_without_blocking()) {
|
||||
// Can't write now, defer everything to the regular deferred queue
|
||||
for (const auto &update : this->deferred_state_batch_.updates) {
|
||||
this->deferred_message_queue_.defer(update.entity, update.send_func);
|
||||
}
|
||||
this->deferred_state_batch_.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// Enable batch mode to capture message types
|
||||
this->batch_mode_ = true;
|
||||
|
||||
// Track packet information (type, offset, length)
|
||||
std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> packet_info;
|
||||
size_t total_size = 0;
|
||||
size_t processed_count = 0;
|
||||
|
||||
// Create initial buffer with estimated size
|
||||
ProtoWriteBuffer batch_buffer = this->create_buffer(MAX_BATCH_SIZE_BYTES);
|
||||
|
||||
// Conservative estimate for minimum packet size: 6 byte header + 100 bytes minimum message + footer
|
||||
const uint16_t min_next_packet_size = 106 + this->helper_->frame_footer_size();
|
||||
|
||||
for (size_t i = 0; i < this->deferred_state_batch_.updates.size(); i++) {
|
||||
const auto &update = this->deferred_state_batch_.updates[i];
|
||||
|
||||
// For the first message, check if we have enough space for at least one message
|
||||
// Use conservative estimates: max header (6 bytes) + some payload + footer
|
||||
if (processed_count == 0) {
|
||||
// Always try to send at least one message
|
||||
} else if (total_size + min_next_packet_size > MAX_BATCH_SIZE_BYTES) {
|
||||
// For subsequent messages, check if we have reasonable space left
|
||||
// Probably won't fit, stop here
|
||||
break;
|
||||
}
|
||||
|
||||
// Save current buffer position before extending
|
||||
uint32_t msg_offset = 0;
|
||||
this->captured_message_type_ = 0;
|
||||
|
||||
// For messages after the first, extend the buffer with padding
|
||||
if (processed_count > 0) {
|
||||
msg_offset = static_cast<uint32_t>(this->proto_write_buffer_.size());
|
||||
batch_buffer = this->extend_buffer();
|
||||
}
|
||||
|
||||
// Try to encode the message
|
||||
if (!(this->*update.send_func)(update.entity)) {
|
||||
// Encoding failed, revert buffer to previous size
|
||||
this->proto_write_buffer_.resize(msg_offset);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the captured message type
|
||||
uint16_t message_type = this->captured_message_type_;
|
||||
|
||||
// Calculate message length
|
||||
uint16_t msg_len =
|
||||
static_cast<uint16_t>(this->proto_write_buffer_.size() - msg_offset - this->helper_->frame_header_padding());
|
||||
|
||||
// Record packet info
|
||||
packet_info.push_back(std::make_tuple(message_type, msg_offset, msg_len));
|
||||
processed_count++;
|
||||
|
||||
// Calculate actual packet size including protocol overhead
|
||||
uint16_t packet_overhead = this->helper_->calculate_packet_overhead(message_type, msg_len);
|
||||
uint16_t packet_size = msg_len + packet_overhead;
|
||||
total_size += packet_size;
|
||||
}
|
||||
|
||||
// Disable batch mode
|
||||
this->batch_mode_ = false;
|
||||
|
||||
// Send all collected packets
|
||||
if (!packet_info.empty()) {
|
||||
// Add final footer space for Noise if needed
|
||||
if (this->helper_->frame_footer_size() > 0) {
|
||||
this->proto_write_buffer_.resize(this->proto_write_buffer_.size() + this->helper_->frame_footer_size());
|
||||
}
|
||||
|
||||
APIError err = this->helper_->write_protobuf_packets(batch_buffer, packet_info);
|
||||
if (err != APIError::OK && err != APIError::WOULD_BLOCK) {
|
||||
on_fatal_error();
|
||||
if (err == APIError::SOCKET_WRITE_FAILED && errno == ECONNRESET) {
|
||||
ESP_LOGW(TAG, "%s: Connection reset during batch write", this->client_combined_info_.c_str());
|
||||
} else {
|
||||
ESP_LOGW(TAG, "%s: Batch write failed %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err),
|
||||
errno);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove processed updates from the batch
|
||||
if (processed_count < this->deferred_state_batch_.updates.size()) {
|
||||
// Some updates weren't processed, keep them for next batch
|
||||
this->deferred_state_batch_.updates.erase(this->deferred_state_batch_.updates.begin(),
|
||||
this->deferred_state_batch_.updates.begin() + processed_count);
|
||||
// Reschedule for remaining updates
|
||||
this->schedule_state_batch_();
|
||||
} else {
|
||||
// All updates processed
|
||||
this->deferred_state_batch_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif
|
||||
|
||||
@@ -418,8 +418,30 @@ class APIConnection : public APIServerConnection {
|
||||
this->proto_write_buffer_.insert(this->proto_write_buffer_.begin(), header_padding, 0);
|
||||
return {&this->proto_write_buffer_};
|
||||
}
|
||||
|
||||
// Extend buffer for batching - adds padding for next message
|
||||
ProtoWriteBuffer extend_buffer() {
|
||||
// Get current size
|
||||
size_t current_size = this->proto_write_buffer_.size();
|
||||
|
||||
// Add padding for next message
|
||||
uint8_t header_padding = this->helper_->frame_header_padding();
|
||||
uint8_t footer_size = this->helper_->frame_footer_size();
|
||||
|
||||
// Add footer space for previous message (if using Noise)
|
||||
if (footer_size > 0) {
|
||||
this->proto_write_buffer_.resize(current_size + footer_size);
|
||||
current_size += footer_size;
|
||||
}
|
||||
|
||||
// Add header padding for next message
|
||||
this->proto_write_buffer_.resize(current_size + header_padding);
|
||||
|
||||
return {&this->proto_write_buffer_};
|
||||
}
|
||||
|
||||
bool try_to_clear_buffer(bool log_out_of_space);
|
||||
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
|
||||
bool send_buffer(ProtoWriteBuffer buffer, uint16_t message_type) override;
|
||||
|
||||
std::string get_client_combined_info() const { return this->client_combined_info_; }
|
||||
|
||||
@@ -439,10 +461,9 @@ class APIConnection : public APIServerConnection {
|
||||
bool send_state_(esphome::EntityBase *entity, send_message_t try_send_func) {
|
||||
if (!this->state_subscription_)
|
||||
return false;
|
||||
if (this->try_to_clear_buffer(true) && (this->*try_send_func)(entity)) {
|
||||
return true;
|
||||
}
|
||||
this->deferred_message_queue_.defer(entity, try_send_func);
|
||||
// Add to batch instead of sending immediately
|
||||
this->deferred_state_batch_.add_update(entity, try_send_func);
|
||||
this->schedule_state_batch_();
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -470,10 +491,10 @@ class APIConnection : public APIServerConnection {
|
||||
Args... args) {
|
||||
if (!this->state_subscription_)
|
||||
return false;
|
||||
if (this->try_to_clear_buffer(true) && (this->*try_send_state_func)(entity, state, args...)) {
|
||||
return true;
|
||||
}
|
||||
this->deferred_message_queue_.defer(entity, reinterpret_cast<send_message_t>(try_send_entity_func));
|
||||
// For state updates with values, we defer using the entity-only function
|
||||
// The current state will be read when the batch is processed
|
||||
this->deferred_state_batch_.add_update(entity, reinterpret_cast<send_message_t>(try_send_entity_func));
|
||||
this->schedule_state_batch_();
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -556,6 +577,39 @@ class APIConnection : public APIServerConnection {
|
||||
InitialStateIterator initial_state_iterator_;
|
||||
ListEntitiesIterator list_entities_iterator_;
|
||||
int state_subs_at_ = -1;
|
||||
|
||||
// State batching mechanism
|
||||
struct DeferredStateBatch {
|
||||
struct StateUpdate {
|
||||
void *entity;
|
||||
send_message_t send_func;
|
||||
uint32_t timestamp; // When this update was queued
|
||||
};
|
||||
|
||||
std::vector<StateUpdate> updates;
|
||||
uint32_t batch_start_time{0};
|
||||
bool batch_scheduled{false};
|
||||
|
||||
// Add update with deduplication - newer updates replace older ones for same entity
|
||||
void add_update(void *entity, send_message_t send_func);
|
||||
void clear() {
|
||||
updates.clear();
|
||||
batch_scheduled = false;
|
||||
batch_start_time = 0;
|
||||
}
|
||||
bool empty() const { return updates.empty(); }
|
||||
};
|
||||
|
||||
DeferredStateBatch deferred_state_batch_;
|
||||
static constexpr uint32_t STATE_BATCH_DELAY_MS = 10;
|
||||
static constexpr size_t MAX_BATCH_SIZE_BYTES = 1360; // MTU - 100 bytes safety margin
|
||||
|
||||
// Batch mode state for capturing message types
|
||||
bool batch_mode_{false};
|
||||
uint16_t captured_message_type_{0};
|
||||
|
||||
void schedule_state_batch_();
|
||||
void process_state_batch_();
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
|
||||
@@ -605,9 +605,22 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_);
|
||||
|
||||
// Resize to include MAC space (required for Noise encryption)
|
||||
raw_buffer->resize(raw_buffer->size() + frame_footer_size_);
|
||||
|
||||
// Use write_protobuf_packets with a single packet
|
||||
std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> packets;
|
||||
packets.push_back(std::make_tuple(type, 0, payload_len));
|
||||
|
||||
return write_protobuf_packets(buffer, packets);
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_protobuf_packets(
|
||||
ProtoWriteBuffer buffer, const std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> &packets) {
|
||||
APIError aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
@@ -616,56 +629,66 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
// Message data starts after padding
|
||||
uint16_t payload_len = raw_buffer->size() - frame_header_padding_;
|
||||
uint16_t padding = 0;
|
||||
uint16_t msg_len = 4 + payload_len + padding;
|
||||
|
||||
// We need to resize to include MAC space, but we already reserved it in create_buffer
|
||||
raw_buffer->resize(raw_buffer->size() + frame_footer_size_);
|
||||
|
||||
// Write the noise header in the padded area
|
||||
// Buffer layout:
|
||||
// [0] - 0x01 indicator byte
|
||||
// [1-2] - Size of encrypted payload (filled after encryption)
|
||||
// [3-4] - Message type (encrypted)
|
||||
// [5-6] - Payload length (encrypted)
|
||||
// [7...] - Actual payload data (encrypted)
|
||||
uint8_t *buf_start = raw_buffer->data();
|
||||
buf_start[0] = 0x01; // indicator
|
||||
// buf_start[1], buf_start[2] to be set later after encryption
|
||||
const uint8_t msg_offset = 3;
|
||||
buf_start[msg_offset + 0] = (uint8_t) (type >> 8); // type high byte
|
||||
buf_start[msg_offset + 1] = (uint8_t) type; // type low byte
|
||||
buf_start[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len high byte
|
||||
buf_start[msg_offset + 3] = (uint8_t) payload_len; // data_len low byte
|
||||
// payload data is already in the buffer starting at position 7
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
// The capacity parameter should be msg_len + frame_footer_size_ (MAC length) to allow space for encryption
|
||||
noise_buffer_set_inout(mbuf, buf_start + msg_offset, msg_len, msg_len + frame_footer_size_);
|
||||
err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::CIPHERSTATE_ENCRYPT_FAILED;
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
uint16_t total_len = 3 + mbuf.size;
|
||||
buf_start[1] = (uint8_t) (mbuf.size >> 8);
|
||||
buf_start[2] = (uint8_t) mbuf.size;
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
std::vector<struct iovec> iovs;
|
||||
iovs.reserve(packets.size());
|
||||
|
||||
struct iovec iov;
|
||||
// Point iov_base to the beginning of the buffer (no unused padding in Noise)
|
||||
// We send the entire frame: indicator + size + encrypted(type + data_len + payload + MAC)
|
||||
iov.iov_base = buf_start;
|
||||
iov.iov_len = total_len;
|
||||
// We need to encrypt each packet in place
|
||||
for (const auto &packet : packets) {
|
||||
uint16_t type = std::get<0>(packet);
|
||||
uint32_t offset = std::get<1>(packet);
|
||||
uint16_t payload_len = std::get<2>(packet);
|
||||
uint16_t msg_len = 4 + payload_len; // type(2) + data_len(2) + payload
|
||||
|
||||
// write raw to not have two packets sent if NAGLE disabled
|
||||
return this->write_raw_(&iov, 1);
|
||||
// The buffer already has padding at offset
|
||||
uint8_t *buf_start = raw_buffer->data() + offset;
|
||||
|
||||
// Write noise header
|
||||
buf_start[0] = 0x01; // indicator
|
||||
// buf_start[1], buf_start[2] to be set after encryption
|
||||
|
||||
// Write message header (to be encrypted)
|
||||
const uint8_t msg_offset = 3;
|
||||
buf_start[msg_offset + 0] = (uint8_t) (type >> 8); // type high byte
|
||||
buf_start[msg_offset + 1] = (uint8_t) type; // type low byte
|
||||
buf_start[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len high byte
|
||||
buf_start[msg_offset + 3] = (uint8_t) payload_len; // data_len low byte
|
||||
// payload data is already in the buffer starting at offset + 7
|
||||
|
||||
// Make sure we have space for MAC
|
||||
// The buffer should already have been sized appropriately
|
||||
|
||||
// Encrypt the message in place
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, buf_start + msg_offset, msg_len, msg_len + frame_footer_size_);
|
||||
|
||||
int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::CIPHERSTATE_ENCRYPT_FAILED;
|
||||
}
|
||||
|
||||
// Fill in the encrypted size
|
||||
buf_start[1] = (uint8_t) (mbuf.size >> 8);
|
||||
buf_start[2] = (uint8_t) mbuf.size;
|
||||
|
||||
// Add iovec for this encrypted packet
|
||||
struct iovec iov;
|
||||
iov.iov_base = buf_start;
|
||||
iov.iov_len = 3 + mbuf.size; // indicator + size + encrypted data
|
||||
iovs.push_back(iov);
|
||||
}
|
||||
|
||||
// Send all encrypted packets in one writev call
|
||||
return this->write_raw_(iovs.data(), iovs.size());
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
@@ -1004,65 +1027,95 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) {
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_);
|
||||
|
||||
// Use write_protobuf_packets with a single packet
|
||||
std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> packets;
|
||||
packets.push_back(std::make_tuple(type, 0, payload_len));
|
||||
|
||||
return write_protobuf_packets(buffer, packets);
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packets(
|
||||
ProtoWriteBuffer buffer, const std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> &packets) {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
// Message data starts after padding (frame_header_padding_ = 6)
|
||||
uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_);
|
||||
|
||||
// Calculate varint sizes for header components
|
||||
uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len));
|
||||
uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(type));
|
||||
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
|
||||
|
||||
if (total_header_len > frame_header_padding_) {
|
||||
// Header is too large to fit in the padding
|
||||
return APIError::BAD_ARG;
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
// Calculate where to start writing the header
|
||||
// The header starts at the latest possible position to minimize unused padding
|
||||
//
|
||||
// Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
|
||||
// [0-2] - Unused padding
|
||||
// [3] - 0x00 indicator byte
|
||||
// [4] - Payload size varint (1 byte, for sizes 0-127)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
|
||||
// [0-1] - Unused padding
|
||||
// [2] - 0x00 indicator byte
|
||||
// [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
|
||||
// [0] - 0x00 indicator byte
|
||||
// [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151)
|
||||
// [4-5] - Message type varint (2 bytes, for types 128-32767)
|
||||
// [6...] - Actual payload data
|
||||
uint8_t *buf_start = raw_buffer->data();
|
||||
uint8_t header_offset = frame_header_padding_ - total_header_len;
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
std::vector<struct iovec> iovs;
|
||||
iovs.reserve(packets.size());
|
||||
|
||||
// Write the plaintext header
|
||||
buf_start[header_offset] = 0x00; // indicator
|
||||
for (const auto &packet : packets) {
|
||||
uint16_t type = std::get<0>(packet);
|
||||
uint32_t offset = std::get<1>(packet);
|
||||
uint16_t payload_len = std::get<2>(packet);
|
||||
|
||||
// Encode size varint directly into buffer
|
||||
ProtoVarInt(payload_len).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len);
|
||||
// Calculate varint sizes for header components
|
||||
uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len));
|
||||
uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(type));
|
||||
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
|
||||
|
||||
// Encode type varint directly into buffer
|
||||
ProtoVarInt(type).encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
|
||||
// Calculate where to start writing the header
|
||||
// The header starts at the latest possible position to minimize unused padding
|
||||
//
|
||||
// Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
|
||||
// [0-2] - Unused padding
|
||||
// [3] - 0x00 indicator byte
|
||||
// [4] - Payload size varint (1 byte, for sizes 0-127)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
|
||||
// [0-1] - Unused padding
|
||||
// [2] - 0x00 indicator byte
|
||||
// [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
|
||||
// [0] - 0x00 indicator byte
|
||||
// [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151)
|
||||
// [4-5] - Message type varint (2 bytes, for types 128-32767)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// The message starts at offset + frame_header_padding_
|
||||
// So we write the header starting at offset + frame_header_padding_ - total_header_len
|
||||
uint8_t *buf_start = raw_buffer->data() + offset;
|
||||
uint32_t header_offset = frame_header_padding_ - total_header_len;
|
||||
|
||||
struct iovec iov;
|
||||
// Point iov_base to the beginning of our header (skip unused padding)
|
||||
// This ensures we only send the actual header and payload, not the empty padding bytes
|
||||
iov.iov_base = buf_start + header_offset;
|
||||
iov.iov_len = total_header_len + payload_len;
|
||||
// Write the plaintext header
|
||||
buf_start[header_offset] = 0x00; // indicator
|
||||
|
||||
return write_raw_(&iov, 1);
|
||||
// Encode size varint directly into buffer
|
||||
ProtoVarInt(payload_len).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len);
|
||||
|
||||
// Encode type varint directly into buffer
|
||||
ProtoVarInt(type).encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
|
||||
|
||||
// Add iovec for this packet (header + payload)
|
||||
struct iovec iov;
|
||||
iov.iov_base = buf_start + header_offset;
|
||||
iov.iov_len = total_header_len + payload_len;
|
||||
iovs.push_back(iov);
|
||||
}
|
||||
|
||||
// Send all packets in one writev call
|
||||
return write_raw_(iovs.data(), iovs.size());
|
||||
}
|
||||
|
||||
uint16_t APIPlaintextFrameHelper::calculate_packet_overhead(uint16_t message_type, uint16_t payload_len) {
|
||||
// Calculate varint sizes
|
||||
uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len));
|
||||
uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(message_type));
|
||||
|
||||
// Plaintext overhead: indicator(1) + size_varint + type_varint + footer(0)
|
||||
return 1 + size_varint_len + type_varint_len + frame_footer_size_;
|
||||
}
|
||||
|
||||
#endif // USE_API_PLAINTEXT
|
||||
|
||||
@@ -87,10 +87,17 @@ class APIFrameHelper {
|
||||
// Give this helper a name for logging
|
||||
void set_log_info(std::string info) { info_ = std::move(info); }
|
||||
virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0;
|
||||
// Write multiple protobuf packets in a single operation
|
||||
// packets contains (message_type, offset, length) for each message in the buffer
|
||||
// The buffer contains all messages with appropriate padding before each
|
||||
virtual APIError write_protobuf_packets(ProtoWriteBuffer buffer,
|
||||
const std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> &packets) = 0;
|
||||
// Get the frame header padding required by this protocol
|
||||
virtual uint8_t frame_header_padding() = 0;
|
||||
// Get the frame footer size required by this protocol
|
||||
virtual uint8_t frame_footer_size() = 0;
|
||||
// Calculate the actual packet overhead (header + footer) for a given message
|
||||
virtual uint16_t calculate_packet_overhead(uint16_t message_type, uint16_t payload_len) = 0;
|
||||
// Check if socket has data ready to read
|
||||
bool is_socket_ready() const { return socket_ != nullptr && socket_->ready(); }
|
||||
|
||||
@@ -182,10 +189,17 @@ class APINoiseFrameHelper : public APIFrameHelper {
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer,
|
||||
const std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> &packets) override;
|
||||
// Get the frame header padding required by this protocol
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
// Calculate the actual packet overhead for Noise protocol
|
||||
uint16_t calculate_packet_overhead(uint16_t message_type, uint16_t payload_len) override {
|
||||
// Noise: fixed 3 byte header (indicator + 2-byte size) + 16 byte MAC
|
||||
return 3 + frame_footer_size_;
|
||||
}
|
||||
|
||||
protected:
|
||||
APIError state_action_();
|
||||
@@ -226,9 +240,13 @@ class APIPlaintextFrameHelper : public APIFrameHelper {
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer,
|
||||
const std::vector<std::tuple<uint16_t, uint32_t, uint16_t>> &packets) override;
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
// Calculate the actual packet overhead for Plaintext protocol
|
||||
uint16_t calculate_packet_overhead(uint16_t message_type, uint16_t payload_len) override; // Implemented in .cpp
|
||||
|
||||
protected:
|
||||
APIError try_read_frame_(ParsedFrame *frame);
|
||||
|
||||
@@ -360,11 +360,11 @@ class ProtoService {
|
||||
* @return A ProtoWriteBuffer object with the reserved size.
|
||||
*/
|
||||
virtual ProtoWriteBuffer create_buffer(uint32_t reserve_size) = 0;
|
||||
virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0;
|
||||
virtual bool send_buffer(ProtoWriteBuffer buffer, uint16_t message_type) = 0;
|
||||
virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0;
|
||||
|
||||
// Optimized method that pre-allocates buffer based on message size
|
||||
template<class C> bool send_message_(const C &msg, uint32_t message_type) {
|
||||
template<class C> bool send_message_(const C &msg, uint16_t message_type) {
|
||||
uint32_t msg_size = 0;
|
||||
msg.calculate_size(msg_size);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user