#include "api_frame_helper.h"
#ifdef USE_API
#include "esphome/core/log.h"
#include "esphome/core/hal.h"
#include "esphome/core/helpers.h"
#include "esphome/core/application.h"
#include "proto.h"
#include "api_pb2_size.h"
#include <cstring>

namespace esphome {
namespace api {

static const char *const TAG = "api.socket";

// Logging helpers are defined above every function that consults them. write_raw_() checks
// HELPER_LOG_PACKETS, so the switch has to be (un)commented here for it to take effect there.
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->info_.c_str(), ##__VA_ARGS__)
// uncomment to log raw packets
//#define HELPER_LOG_PACKETS

/// Convert an APIError to the name of its enumerator for logging; returns "UNKNOWN" for unlisted values.
const char *api_error_to_str(APIError err) {
  // not using switch to ensure compiler doesn't try to build a big table out of it
  if (err == APIError::OK) {
    return "OK";
  } else if (err == APIError::WOULD_BLOCK) {
    return "WOULD_BLOCK";
  } else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
    return "BAD_HANDSHAKE_PACKET_LEN";
  } else if (err == APIError::BAD_INDICATOR) {
    return "BAD_INDICATOR";
  } else if (err == APIError::BAD_DATA_PACKET) {
    return "BAD_DATA_PACKET";
  } else if (err == APIError::TCP_NODELAY_FAILED) {
    return "TCP_NODELAY_FAILED";
  } else if (err == APIError::TCP_NONBLOCKING_FAILED) {
    return "TCP_NONBLOCKING_FAILED";
  } else if (err == APIError::CLOSE_FAILED) {
    return "CLOSE_FAILED";
  } else if (err == APIError::SHUTDOWN_FAILED) {
    return "SHUTDOWN_FAILED";
  } else if (err == APIError::BAD_STATE) {
    return "BAD_STATE";
  } else if (err == APIError::BAD_ARG) {
    return "BAD_ARG";
  } else if (err == APIError::SOCKET_READ_FAILED) {
    return "SOCKET_READ_FAILED";
  } else if (err == APIError::SOCKET_WRITE_FAILED) {
    return "SOCKET_WRITE_FAILED";
  } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) {
    return "HANDSHAKESTATE_READ_FAILED";
  } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) {
    return "HANDSHAKESTATE_WRITE_FAILED";
  } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) {
    return "HANDSHAKESTATE_BAD_STATE";
  } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) {
    return "CIPHERSTATE_DECRYPT_FAILED";
  } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) {
    return "CIPHERSTATE_ENCRYPT_FAILED";
  } else if (err == APIError::OUT_OF_MEMORY) {
    return "OUT_OF_MEMORY";
  } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) {
    return "HANDSHAKESTATE_SETUP_FAILED";
  } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) {
    return "HANDSHAKESTATE_SPLIT_FAILED";
  } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) {
    return "BAD_HANDSHAKE_ERROR_BYTE";
  } else if (err == APIError::CONNECTION_CLOSED) {
    return "CONNECTION_CLOSED";
  }
  return "UNKNOWN";
}

/** Common implementation for writing raw data to the socket.
 *
 * Previously queued data (tx_buf_) is flushed first so bytes stay in order. If anything remains
 * queued afterwards, or the socket would block, or it accepts only part of the data, whatever was not
 * written is copied into a new SendBuffer at the back of tx_buf_.
 *
 * @param iov Segments to send, in order.
 * @param iovcnt Number of entries in iov; 0 is a no-op.
 * @return APIError::OK if the data was sent or buffered.
 * @return APIError::SOCKET_WRITE_FAILED on a real socket error (state_ is set to FAILED).
 */
APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) {
  if (iovcnt == 0)
    return APIError::OK;  // Nothing to do, success

  size_t total_write_len = 0;
  for (int i = 0; i < iovcnt; i++) {
#ifdef HELPER_LOG_PACKETS
    ESP_LOGVV(TAG, "Sending raw: %s",
              format_hex_pretty(reinterpret_cast<const uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str());
#endif
    total_write_len += iov[i].iov_len;
  }

  // Queue everything in iov after its first `skip` bytes (those already on the wire) as one SendBuffer.
  auto buffer_remaining = [&](size_t skip) {
    SendBuffer buffer;
    buffer.data.reserve(total_write_len - skip);
    for (int i = 0; i < iovcnt; i++) {
      if (skip >= iov[i].iov_len) {
        // This segment was fully sent (or is empty)
        skip -= iov[i].iov_len;
        continue;
      }
      // This segment was partially sent or not sent at all
      const uint8_t *data = reinterpret_cast<const uint8_t *>(iov[i].iov_base) + skip;
      buffer.data.insert(buffer.data.end(), data, data + (iov[i].iov_len - skip));
      skip = 0;
    }
    this->tx_buf_.push_back(std::move(buffer));
  };

  // Try to send any existing buffered data first
  while (!this->tx_buf_.empty()) {
    SendBuffer &front_buffer = this->tx_buf_.front();
    ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining());

    if (sent == -1) {
      if (errno != EWOULDBLOCK && errno != EAGAIN) {
        // Real socket error (not just would block)
        ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno);
        this->state_ = State::FAILED;
        return APIError::SOCKET_WRITE_FAILED;
      }
      // Socket would block; fall through and append the new data to the queue
      break;
    } else if (sent == 0) {
      // Nothing sent but not an error
      break;
    } else if (static_cast<size_t>(sent) < front_buffer.remaining()) {
      // Partially sent, remember how far we got and stop
      front_buffer.offset += sent;
      break;
    } else {
      // Buffer completely sent, continue with the next one
      this->tx_buf_.pop_front();
    }
  }

  // Older data is still pending: the new data must queue behind it to preserve ordering
  if (!this->tx_buf_.empty()) {
    buffer_remaining(0);
    return APIError::OK;  // Success, data buffered
  }

  // Try to send directly if no buffered data
  ssize_t sent = this->socket_->writev(iov, iovcnt);
  if (sent == -1) {
    if (errno == EWOULDBLOCK || errno == EAGAIN) {
      buffer_remaining(0);
      return APIError::OK;  // Success, data buffered
    }
    ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno);
    this->state_ = State::FAILED;
    return APIError::SOCKET_WRITE_FAILED;
  }
  if (static_cast<size_t>(sent) < total_write_len) {
    // Partially sent, buffer the remaining data
    buffer_remaining(static_cast<size_t>(sent));
  }
  return APIError::OK;  // Success, data sent or remainder buffered
}

/** Flush as much of tx_buf_ as the socket accepts without blocking.
 *
 * @return APIError::OK when the queue is drained, the socket would block, or state_ is CLOSED.
 * @return APIError::SOCKET_WRITE_FAILED on a real socket error (state_ is set to FAILED).
 */
APIError APIFrameHelper::try_send_tx_buf_() {
  while (state_ != State::CLOSED && !this->tx_buf_.empty()) {
    SendBuffer &front_buffer = this->tx_buf_.front();
    ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining());

    if (sent == -1) {
      if (errno == EWOULDBLOCK || errno == EAGAIN) {
        // Socket would block, try again later
        break;
      }
      // Socket error
      state_ = State::FAILED;
      ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno);
      return APIError::SOCKET_WRITE_FAILED;
    } else if (sent == 0) {
      // No data sent but not an error, try again later
      break;
    } else if (static_cast<size_t>(sent) < front_buffer.remaining()) {
      // Partially sent, update offset and stop
      front_buffer.offset += sent;
      break;
    } else {
      // Buffer completely sent, continue with the next one
      this->tx_buf_.pop_front();
    }
  }

  return APIError::OK;
}

#ifdef USE_API_NOISE
static const char *const PROLOGUE_INIT = "NoiseAPIInit";

/// Convert a noise error code to a readable error; unknown codes are rendered as their number.
std::string noise_err_to_str(int err) {
  if (err == NOISE_ERROR_NO_MEMORY)
    return "NO_MEMORY";
  if (err == NOISE_ERROR_UNKNOWN_ID)
    return "UNKNOWN_ID";
  if (err == NOISE_ERROR_UNKNOWN_NAME)
    return "UNKNOWN_NAME";
  if (err == NOISE_ERROR_MAC_FAILURE)
    return "MAC_FAILURE";
  if (err == NOISE_ERROR_NOT_APPLICABLE)
    return "NOT_APPLICABLE";
  if (err == NOISE_ERROR_SYSTEM)
    return "SYSTEM";
  if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
    return "REMOTE_KEY_REQUIRED";
  if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
    return "LOCAL_KEY_REQUIRED";
  if (err == NOISE_ERROR_PSK_REQUIRED)
    return "PSK_REQUIRED";
  if (err == NOISE_ERROR_INVALID_LENGTH)
    return "INVALID_LENGTH";
  if (err == NOISE_ERROR_INVALID_PARAM)
    return "INVALID_PARAM";
  if (err == NOISE_ERROR_INVALID_STATE)
    return "INVALID_STATE";
  if (err == NOISE_ERROR_INVALID_NONCE)
    return "INVALID_NONCE";
  if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
    return "INVALID_PRIVATE_KEY";
  if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
    return "INVALID_PUBLIC_KEY";
  if (err == NOISE_ERROR_INVALID_FORMAT)
    return "INVALID_FORMAT";
  if (err == NOISE_ERROR_INVALID_SIGNATURE)
    return "INVALID_SIGNATURE";
  return to_string(err);
}

/// Initialize the frame helper, returns OK if successful.
APIError APINoiseFrameHelper::init() { if (state_ != State::INITIALIZE || this->socket_ == nullptr) { HELPER_LOG("Bad state for init %d", (int) state_); return APIError::BAD_STATE; } int err = this->socket_->setblocking(false); if (err != 0) { state_ = State::FAILED; HELPER_LOG("Setting nonblocking failed with errno %d", errno); return APIError::TCP_NONBLOCKING_FAILED; } int enable = 1; err = this->socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); if (err != 0) { state_ = State::FAILED; HELPER_LOG("Setting nodelay failed with errno %d", errno); return APIError::TCP_NODELAY_FAILED; } // init prologue prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT)); state_ = State::CLIENT_HELLO; return APIError::OK; } /// Run through handshake messages (if in that phase) APIError APINoiseFrameHelper::loop() { APIError err = state_action_(); if (err == APIError::WOULD_BLOCK) return APIError::OK; if (err != APIError::OK) return err; if (!this->tx_buf_.empty()) { err = try_send_tx_buf_(); if (err != APIError::OK) { return err; } } return APIError::OK; } /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter * * @param frame: The struct to hold the frame information in. * msg_start: points to the start of the payload - this pointer is only valid until the next * try_receive_raw_ call * * @return 0 if a full packet is in rx_buf_ * @return -1 if error, check errno. * * errno EWOULDBLOCK: Packet could not be read without blocking. Try again later. * errno ENOMEM: Not enough memory for reading packet. * errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. * errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase. 
*/ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { if (frame == nullptr) { HELPER_LOG("Bad argument for try_read_frame_"); return APIError::BAD_ARG; } // read header if (rx_header_buf_len_ < 3) { // no header information yet size_t to_read = 3 - rx_header_buf_len_; ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; } state_ = State::FAILED; HELPER_LOG("Socket read failed with errno %d", errno); return APIError::SOCKET_READ_FAILED; } else if (received == 0) { state_ = State::FAILED; HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } rx_header_buf_len_ += received; if ((size_t) received != to_read) { // not a full read return APIError::WOULD_BLOCK; } // header reading done } // read body uint8_t indicator = rx_header_buf_[0]; if (indicator != 0x01) { state_ = State::FAILED; HELPER_LOG("Bad indicator byte %u", indicator); return APIError::BAD_INDICATOR; } uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2]; if (state_ != State::DATA && msg_size > 128) { // for handshake message only permit up to 128 bytes state_ = State::FAILED; HELPER_LOG("Bad packet len for handshake: %d", msg_size); return APIError::BAD_HANDSHAKE_PACKET_LEN; } // reserve space for body if (rx_buf_.size() != msg_size) { rx_buf_.resize(msg_size); } if (rx_buf_len_ < msg_size) { // more data to read size_t to_read = msg_size - rx_buf_len_; ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; } state_ = State::FAILED; HELPER_LOG("Socket read failed with errno %d", errno); return APIError::SOCKET_READ_FAILED; } else if (received == 0) { state_ = State::FAILED; HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } rx_buf_len_ += received; if ((size_t) received != to_read) { 
// not all read return APIError::WOULD_BLOCK; } } // uncomment for even more debugging #ifdef HELPER_LOG_PACKETS ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str()); #endif frame->msg = std::move(rx_buf_); // consume msg rx_buf_ = {}; rx_buf_len_ = 0; rx_header_buf_len_ = 0; return APIError::OK; } /** To be called from read/write methods. * * This method runs through the internal handshake methods, if in that state. * * If the handshake is still active when this method returns and a read/write can't take place at * the moment, returns WOULD_BLOCK. * If an error occurred, returns that error. Only returns OK if the transport is ready for data * traffic. */ APIError APINoiseFrameHelper::state_action_() { int err; APIError aerr; if (state_ == State::INITIALIZE) { HELPER_LOG("Bad state for method: %d", (int) state_); return APIError::BAD_STATE; } if (state_ == State::CLIENT_HELLO) { // waiting for client hello ParsedFrame frame; aerr = try_read_frame_(&frame); if (aerr == APIError::BAD_INDICATOR) { send_explicit_handshake_reject_("Bad indicator byte"); return aerr; } if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { send_explicit_handshake_reject_("Bad handshake packet len"); return aerr; } if (aerr != APIError::OK) return aerr; // ignore contents, may be used in future for flags prologue_.push_back((uint8_t) (frame.msg.size() >> 8)); prologue_.push_back((uint8_t) frame.msg.size()); prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end()); state_ = State::SERVER_HELLO; } if (state_ == State::SERVER_HELLO) { // send server hello std::vector msg; // chosen proto msg.push_back(0x01); // node name, terminated by null byte const std::string &name = App.get_name(); const uint8_t *name_ptr = reinterpret_cast(name.c_str()); msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1); // node mac, terminated by null byte const std::string &mac = get_mac_address(); const uint8_t *mac_ptr = reinterpret_cast(mac.c_str()); msg.insert(msg.end(), 
mac_ptr, mac_ptr + mac.size() + 1); aerr = write_frame_(msg.data(), msg.size()); if (aerr != APIError::OK) return aerr; // start handshake aerr = init_handshake_(); if (aerr != APIError::OK) return aerr; state_ = State::HANDSHAKE; } if (state_ == State::HANDSHAKE) { int action = noise_handshakestate_get_action(handshake_); if (action == NOISE_ACTION_READ_MESSAGE) { // waiting for handshake msg ParsedFrame frame; aerr = try_read_frame_(&frame); if (aerr == APIError::BAD_INDICATOR) { send_explicit_handshake_reject_("Bad indicator byte"); return aerr; } if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { send_explicit_handshake_reject_("Bad handshake packet len"); return aerr; } if (aerr != APIError::OK) return aerr; if (frame.msg.empty()) { send_explicit_handshake_reject_("Empty handshake message"); return APIError::BAD_HANDSHAKE_ERROR_BYTE; } else if (frame.msg[0] != 0x00) { HELPER_LOG("Bad handshake error byte: %u", frame.msg[0]); send_explicit_handshake_reject_("Bad handshake error byte"); return APIError::BAD_HANDSHAKE_ERROR_BYTE; } NoiseBuffer mbuf; noise_buffer_init(mbuf); noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1); err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str()); if (err == NOISE_ERROR_MAC_FAILURE) { send_explicit_handshake_reject_("Handshake MAC failure"); } else { send_explicit_handshake_reject_("Handshake error"); } return APIError::HANDSHAKESTATE_READ_FAILED; } aerr = check_handshake_finished_(); if (aerr != APIError::OK) return aerr; } else if (action == NOISE_ACTION_WRITE_MESSAGE) { uint8_t buffer[65]; NoiseBuffer mbuf; noise_buffer_init(mbuf); noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_write_message failed: %s", 
noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_WRITE_FAILED; } buffer[0] = 0x00; // success aerr = write_frame_(buffer, mbuf.size + 1); if (aerr != APIError::OK) return aerr; aerr = check_handshake_finished_(); if (aerr != APIError::OK) return aerr; } else { // bad state for action state_ = State::FAILED; HELPER_LOG("Bad action for handshake: %d", action); return APIError::HANDSHAKESTATE_BAD_STATE; } } if (state_ == State::CLOSED || state_ == State::FAILED) { return APIError::BAD_STATE; } return APIError::OK; } void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { std::vector data; data.resize(reason.length() + 1); data[0] = 0x01; // failure for (size_t i = 0; i < reason.length(); i++) { data[i + 1] = (uint8_t) reason[i]; } // temporarily remove failed state auto orig_state = state_; state_ = State::EXPLICIT_REJECT; write_frame_(data.data(), data.size()); state_ = orig_state; } APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { int err; APIError aerr; aerr = state_action_(); if (aerr != APIError::OK) { return aerr; } if (state_ != State::DATA) { return APIError::WOULD_BLOCK; } ParsedFrame frame; aerr = try_read_frame_(&frame); if (aerr != APIError::OK) return aerr; NoiseBuffer mbuf; noise_buffer_init(mbuf); noise_buffer_set_inout(mbuf, frame.msg.data(), frame.msg.size(), frame.msg.size()); err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_cipherstate_decrypt failed: %s", noise_err_to_str(err).c_str()); return APIError::CIPHERSTATE_DECRYPT_FAILED; } size_t msg_size = mbuf.size; uint8_t *msg_data = frame.msg.data(); if (msg_size < 4) { state_ = State::FAILED; HELPER_LOG("Bad data packet: size %d too short", msg_size); return APIError::BAD_DATA_PACKET; } // uint16_t type; // uint16_t data_len; // uint8_t *data; // uint8_t *padding; zero or more bytes to fill up the rest of the packet uint16_t type = (((uint16_t) msg_data[0]) << 8) | 
msg_data[1]; uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3]; if (data_len > msg_size - 4) { state_ = State::FAILED; HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size); return APIError::BAD_DATA_PACKET; } buffer->container = std::move(frame.msg); buffer->data_offset = 4; buffer->data_len = data_len; buffer->type = type; return APIError::OK; } APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { int err; APIError aerr; aerr = state_action_(); if (aerr != APIError::OK) { return aerr; } if (state_ != State::DATA) { return APIError::WOULD_BLOCK; } size_t padding = 0; size_t msg_len = 4 + payload_len + padding; size_t frame_len = 3 + msg_len + noise_cipherstate_get_mac_length(send_cipher_); auto tmpbuf = std::unique_ptr{new (std::nothrow) uint8_t[frame_len]}; if (tmpbuf == nullptr) { HELPER_LOG("Could not allocate for writing packet"); return APIError::OUT_OF_MEMORY; } tmpbuf[0] = 0x01; // indicator // tmpbuf[1], tmpbuf[2] to be set later const uint8_t msg_offset = 3; const uint8_t payload_offset = msg_offset + 4; tmpbuf[msg_offset + 0] = (uint8_t) (type >> 8); // type tmpbuf[msg_offset + 1] = (uint8_t) type; tmpbuf[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len tmpbuf[msg_offset + 3] = (uint8_t) payload_len; // copy data std::copy(payload, payload + payload_len, &tmpbuf[payload_offset]); // fill padding with zeros std::fill(&tmpbuf[payload_offset + payload_len], &tmpbuf[frame_len], 0); NoiseBuffer mbuf; noise_buffer_init(mbuf); noise_buffer_set_inout(mbuf, &tmpbuf[msg_offset], msg_len, frame_len - msg_offset); err = noise_cipherstate_encrypt(send_cipher_, &mbuf); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str()); return APIError::CIPHERSTATE_ENCRYPT_FAILED; } size_t total_len = 3 + mbuf.size; tmpbuf[1] = (uint8_t) (mbuf.size >> 8); tmpbuf[2] = (uint8_t) mbuf.size; struct iovec iov; 
iov.iov_base = &tmpbuf[0]; iov.iov_len = total_len; // write raw to not have two packets sent if NAGLE disabled return APIFrameHelper::write_raw_(&iov, 1); } APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { uint8_t header[3]; header[0] = 0x01; // indicator header[1] = (uint8_t) (len >> 8); header[2] = (uint8_t) len; struct iovec iov[2]; iov[0].iov_base = header; iov[0].iov_len = 3; if (len == 0) { return APIFrameHelper::write_raw_(iov, 1); } iov[1].iov_base = const_cast(data); iov[1].iov_len = len; return APIFrameHelper::write_raw_(iov, 2); } /** Initiate the data structures for the handshake. * * @return 0 on success, -1 on error (check errno) */ APIError APINoiseFrameHelper::init_handshake_() { int err; memset(&nid_, 0, sizeof(nid_)); // const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; // err = noise_protocol_name_to_id(&nid_, proto, strlen(proto)); nid_.pattern_id = NOISE_PATTERN_NN; nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY; nid_.dh_id = NOISE_DH_CURVE25519; nid_.prefix_id = NOISE_PREFIX_STANDARD; nid_.hybrid_id = NOISE_DH_NONE; nid_.hash_id = NOISE_HASH_SHA256; nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_new_by_id failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_SETUP_FAILED; } const auto &psk = ctx_->get_psk(); err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_set_pre_shared_key failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_SETUP_FAILED; } err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_set_prologue failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_SETUP_FAILED; } // 
set_prologue copies it into handshakestate, so we can get rid of it now prologue_ = {}; err = noise_handshakestate_start(handshake_); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_start failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_SETUP_FAILED; } return APIError::OK; } APIError APINoiseFrameHelper::check_handshake_finished_() { assert(state_ == State::HANDSHAKE); int action = noise_handshakestate_get_action(handshake_); if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE) return APIError::OK; if (action != NOISE_ACTION_SPLIT) { state_ = State::FAILED; HELPER_LOG("Bad action for handshake: %d", action); return APIError::HANDSHAKESTATE_BAD_STATE; } int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); if (err != 0) { state_ = State::FAILED; HELPER_LOG("noise_handshakestate_split failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_SPLIT_FAILED; } HELPER_LOG("Handshake complete!"); noise_handshakestate_free(handshake_); handshake_ = nullptr; state_ = State::DATA; return APIError::OK; } APINoiseFrameHelper::~APINoiseFrameHelper() { if (handshake_ != nullptr) { noise_handshakestate_free(handshake_); handshake_ = nullptr; } if (send_cipher_ != nullptr) { noise_cipherstate_free(send_cipher_); send_cipher_ = nullptr; } if (recv_cipher_ != nullptr) { noise_cipherstate_free(recv_cipher_); recv_cipher_ = nullptr; } } extern "C" { // declare how noise generates random bytes (here with a good HWRNG based on the RF system) void noise_rand_bytes(void *output, size_t len) { if (!esphome::random_bytes(reinterpret_cast(output), len)) { ESP_LOGE(TAG, "Failed to acquire random bytes, rebooting!"); arch_restart(); } } } #endif // USE_API_NOISE #ifdef USE_API_PLAINTEXT /// Initialize the frame helper, returns OK if successful. 
APIError APIPlaintextFrameHelper::init() { if (state_ != State::INITIALIZE || this->socket_ == nullptr) { HELPER_LOG("Bad state for init %d", (int) state_); return APIError::BAD_STATE; } int err = this->socket_->setblocking(false); if (err != 0) { state_ = State::FAILED; HELPER_LOG("Setting nonblocking failed with errno %d", errno); return APIError::TCP_NONBLOCKING_FAILED; } int enable = 1; err = this->socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); if (err != 0) { state_ = State::FAILED; HELPER_LOG("Setting nodelay failed with errno %d", errno); return APIError::TCP_NODELAY_FAILED; } state_ = State::DATA; return APIError::OK; } /// Not used for plaintext APIError APIPlaintextFrameHelper::loop() { if (state_ != State::DATA) { return APIError::BAD_STATE; } // try send pending TX data if (!this->tx_buf_.empty()) { APIError err = try_send_tx_buf_(); if (err != APIError::OK) { return err; } } return APIError::OK; } /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter * * @param frame: The struct to hold the frame information in. * msg: store the parsed frame in that struct * * @return See APIError * * error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. */ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { if (frame == nullptr) { HELPER_LOG("Bad argument for try_read_frame_"); return APIError::BAD_ARG; } // read header while (!rx_header_parsed_) { uint8_t data; // Reading one byte at a time is fastest in practice for ESP32 when // there is no data on the wire (which is the common case). // This results in faster failure detection compared to // attempting to read multiple bytes at once. 
ssize_t received = this->socket_->read(&data, 1); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; } state_ = State::FAILED; HELPER_LOG("Socket read failed with errno %d", errno); return APIError::SOCKET_READ_FAILED; } else if (received == 0) { state_ = State::FAILED; HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } // Successfully read a byte // Process byte according to current buffer position if (rx_header_buf_pos_ == 0) { // Case 1: First byte (indicator byte) if (data != 0x00) { state_ = State::FAILED; HELPER_LOG("Bad indicator byte %u", data); return APIError::BAD_INDICATOR; } // We don't store the indicator byte, just increment position rx_header_buf_pos_ = 1; // Set to 1 directly continue; // Need more bytes before we can parse } // Check buffer overflow before storing if (rx_header_buf_pos_ == 5) { // Case 2: Buffer would overflow (5 bytes is max allowed) state_ = State::FAILED; HELPER_LOG("Header buffer overflow"); return APIError::BAD_DATA_PACKET; } // Store byte in buffer (adjust index to account for skipped indicator byte) rx_header_buf_[rx_header_buf_pos_ - 1] = data; // Increment position after storing rx_header_buf_pos_++; // Case 3: If we only have one varint byte, we need more if (rx_header_buf_pos_ == 2) { // Have read indicator + 1 byte continue; // Need more bytes before we can parse } // At this point, we have at least 3 bytes total: // - Validated indicator byte (0x00) but not stored // - At least 2 bytes in the buffer for the varints // Buffer layout: // First 1-3 bytes: Message size varint (variable length) // - 2 bytes would only allow up to 16383, which is less than noise's 65535 // - 3 bytes allows up to 2097151, ensuring we support at least as much as noise // Remaining 1-2 bytes: Message type varint (variable length) // We now attempt to parse both varints. If either is incomplete, // we'll continue reading more bytes. 
uint32_t consumed = 0; auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[0], rx_header_buf_pos_ - 1, &consumed); if (!msg_size_varint.has_value()) { // not enough data there yet continue; } rx_header_parsed_len_ = msg_size_varint->as_uint32(); auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[consumed], rx_header_buf_pos_ - 1 - consumed, &consumed); if (!msg_type_varint.has_value()) { // not enough data there yet continue; } rx_header_parsed_type_ = msg_type_varint->as_uint32(); rx_header_parsed_ = true; } // header reading done // reserve space for body if (rx_buf_.size() != rx_header_parsed_len_) { rx_buf_.resize(rx_header_parsed_len_); } if (rx_buf_len_ < rx_header_parsed_len_) { // more data to read size_t to_read = rx_header_parsed_len_ - rx_buf_len_; ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; } state_ = State::FAILED; HELPER_LOG("Socket read failed with errno %d", errno); return APIError::SOCKET_READ_FAILED; } else if (received == 0) { state_ = State::FAILED; HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } rx_buf_len_ += received; if ((size_t) received != to_read) { // not all read return APIError::WOULD_BLOCK; } } // uncomment for even more debugging #ifdef HELPER_LOG_PACKETS ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str()); #endif frame->msg = std::move(rx_buf_); // consume msg rx_buf_ = {}; rx_buf_len_ = 0; rx_header_buf_pos_ = 0; rx_header_parsed_ = false; return APIError::OK; } APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { APIError aerr; if (state_ != State::DATA) { return APIError::WOULD_BLOCK; } ParsedFrame frame; aerr = try_read_frame_(&frame); if (aerr != APIError::OK) { if (aerr == APIError::BAD_INDICATOR) { // Make sure to tell the remote that we don't // understand the indicator byte so it knows // we do not support it. 
struct iovec iov[1]; // The \x00 first byte is the marker for plaintext. // // The remote will know how to handle the indicator byte, // but it likely won't understand the rest of the message. // // We must send at least 3 bytes to be read, so we add // a message after the indicator byte to ensures its long // enough and can aid in debugging. const char msg[] = "\x00" "Bad indicator byte"; iov[0].iov_base = (void *) msg; iov[0].iov_len = 19; APIFrameHelper::write_raw_(iov, 1); } return aerr; } buffer->container = std::move(frame.msg); buffer->data_offset = 0; buffer->data_len = rx_header_parsed_len_; buffer->type = rx_header_parsed_type_; return APIError::OK; } APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { if (state_ != State::DATA) { return APIError::BAD_STATE; } std::vector header; header.reserve(1 + api::ProtoSize::varint(static_cast(payload_len)) + api::ProtoSize::varint(static_cast(type))); header.push_back(0x00); ProtoVarInt(payload_len).encode(header); ProtoVarInt(type).encode(header); struct iovec iov[2]; iov[0].iov_base = &header[0]; iov[0].iov_len = header.size(); if (payload_len == 0) { return APIFrameHelper::write_raw_(iov, 1); } iov[1].iov_base = const_cast(payload); iov[1].iov_len = payload_len; return APIFrameHelper::write_raw_(iov, 2); } #endif // USE_API_PLAINTEXT } // namespace api } // namespace esphome #endif