This commit is contained in:
J. Nick Koston
2025-05-27 10:29:23 -05:00
parent 2288cd65ad
commit 456a475cea
10 changed files with 54 additions and 45 deletions

View File

@@ -308,7 +308,7 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
}
// Check if socket has data available before attempting to read
if (!this->is_socket_ready_for_read_()) {
if (!socket_->ready()) {
return APIError::WOULD_BLOCK;
}
@@ -835,7 +835,7 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
}
// Check if socket has data available before attempting to read
if (!this->is_socket_ready_for_read_()) {
if (!socket_->ready()) {
return APIError::WOULD_BLOCK;
}

View File

@@ -93,12 +93,6 @@ class APIFrameHelper {
virtual uint8_t frame_footer_size() = 0;
protected:
// Helper to check if socket has data available
inline bool is_socket_ready_for_read_() const {
int fd = socket_->get_fd();
return fd < 0 || App.is_socket_ready(fd); // If no fd, assume ready (fallback behavior)
}
// Struct for holding parsed frame data
struct ParsedFrame {
std::vector<uint8_t> msg;

View File

@@ -43,7 +43,7 @@ void APIServer::setup() {
}
#endif
this->socket_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
this->socket_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (this->socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket");
this->mark_failed();
@@ -113,12 +113,11 @@ void APIServer::setup() {
void APIServer::loop() {
// Accept new clients only if the socket has incoming connections
int server_fd = this->socket_->get_fd();
if (server_fd >= 0 && App.is_socket_ready(server_fd)) {
if (this->socket_->ready()) {
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = this->socket_->accept_monitored((struct sockaddr *) &source_addr, &addr_len);
auto sock = this->socket_->accept_loop_monitored((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());

View File

@@ -26,7 +26,7 @@ void ESPHomeOTAComponent::setup() {
ota::register_ota_platform(this);
#endif
server_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (server_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket");
this->mark_failed();
@@ -101,8 +101,7 @@ void ESPHomeOTAComponent::handle_() {
if (client_ == nullptr) {
// Check if the server socket is ready before accepting
int server_fd = this->server_->get_fd();
if (server_fd >= 0 && App.is_socket_ready(server_fd)) {
if (this->server_->ready()) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);

View File

@@ -42,7 +42,7 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
class BSDSocketImpl : public Socket {
public:
BSDSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
monitored_ = monitor_loop;
loop_monitored_ = monitor_loop;
// Register new socket with the application for select() if monitoring requested
if (monitor_loop && fd_ >= 0) {
App.register_socket_fd(fd_);
@@ -60,7 +60,7 @@ class BSDSocketImpl : public Socket {
return {};
return make_unique<BSDSocketImpl>(fd);
}
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
std::unique_ptr<Socket> accept_loop_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = ::accept(fd_, addr, addrlen);
if (fd == -1)
return {};
@@ -70,7 +70,7 @@ class BSDSocketImpl : public Socket {
int close() override {
if (!closed_) {
// Unregister from select() before closing if monitored
if (monitored_) {
if (loop_monitored_) {
App.unregister_socket_fd(fd_);
}
int ret = ::close(fd_);
@@ -160,7 +160,7 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{new BSDSocketImpl(ret)};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
std::unique_ptr<Socket> socket_loop_monitored(int domain, int type, int protocol) {
int ret = ::socket(domain, type, protocol);
if (ret == -1)
return nullptr;

View File

@@ -606,7 +606,7 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{sock};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
std::unique_ptr<Socket> socket_loop_monitored(int domain, int type, int protocol) {
// LWIPRawImpl doesn't use file descriptors, so monitoring is not applicable
return socket(domain, type, protocol);
}

View File

@@ -35,7 +35,7 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
class LwIPSocketImpl : public Socket {
public:
LwIPSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
monitored_ = monitor_loop;
loop_monitored_ = monitor_loop;
// Register new socket with the application for select() if monitoring requested
if (monitor_loop && fd_ >= 0) {
App.register_socket_fd(fd_);
@@ -53,7 +53,7 @@ class LwIPSocketImpl : public Socket {
return {};
return make_unique<LwIPSocketImpl>(fd); // Default: not monitored
}
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
std::unique_ptr<Socket> accept_loop_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = lwip_accept(fd_, addr, addrlen);
if (fd == -1)
return {};
@@ -63,7 +63,7 @@ class LwIPSocketImpl : public Socket {
int close() override {
if (!closed_) {
// Unregister from select() before closing if monitored
if (monitored_) {
if (loop_monitored_) {
App.unregister_socket_fd(fd_);
}
int ret = lwip_close(fd_);
@@ -132,7 +132,7 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{new LwIPSocketImpl(ret)};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
std::unique_ptr<Socket> socket_loop_monitored(int domain, int type, int protocol) {
int ret = lwip_socket(domain, type, protocol);
if (ret == -1)
return nullptr;

View File

@@ -4,12 +4,29 @@
#include <cstring>
#include <string>
#include "esphome/core/log.h"
#include "esphome/core/application.h"
namespace esphome {
namespace socket {
Socket::~Socket() {}
bool Socket::ready() const {
if (!loop_monitored_) {
// Non-monitored sockets always return true (assume data may be available)
return true;
}
// For loop-monitored sockets, check with the Application's select() results
int fd = this->get_fd();
if (fd < 0) {
// No valid file descriptor, assume ready (fallback behavior)
return true;
}
return App.is_socket_ready(fd);
}
std::unique_ptr<Socket> socket_ip(int type, int protocol) {
#if USE_NETWORK_IPV6
return socket(AF_INET6, type, protocol);
@@ -18,11 +35,11 @@ std::unique_ptr<Socket> socket_ip(int type, int protocol) {
#endif /* USE_NETWORK_IPV6 */
}
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol) {
std::unique_ptr<Socket> socket_ip_loop_monitored(int type, int protocol) {
#if USE_NETWORK_IPV6
return socket_monitored(AF_INET6, type, protocol);
return socket_loop_monitored(AF_INET6, type, protocol);
#else
return socket_monitored(AF_INET, type, protocol);
return socket_loop_monitored(AF_INET, type, protocol);
#endif /* USE_NETWORK_IPV6 */
}

View File

@@ -17,9 +17,9 @@ class Socket {
Socket &operator=(const Socket &) = delete;
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
/// Accept a connection and optionally monitor it in the main loop
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
virtual std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) {
/// Accept a connection and monitor it in the main loop
/// NOTE: This function is NOT thread-safe and must only be called from the main loop
virtual std::unique_ptr<Socket> accept_loop_monitored(struct sockaddr *addr, socklen_t *addrlen) {
return accept(addr, addrlen); // Default implementation for backward compatibility
}
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
@@ -53,25 +53,26 @@ class Socket {
/// Get the underlying file descriptor (returns -1 if not supported)
virtual int get_fd() const { return -1; }
/// Check if socket has data ready to read
/// For loop-monitored sockets, checks with the Application's select() results
/// For non-monitored sockets, always returns true (assumes data may be available)
bool ready() const;
protected:
bool monitored_{false}; ///< Whether this socket is monitored by the event loop
bool loop_monitored_{false}; ///< Whether this socket is monitored by the event loop
};
/// Create a socket of the given domain, type and protocol.
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
/// Create a socket and monitor it for data in the main loop
/// WARNING: This function is NOT thread-safe. It must only be called from the main loop
/// as it registers the socket file descriptor with the global Application instance.
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol);
/// Create a socket in the newest available IP domain (IPv6 or IPv4) of the given type and protocol.
std::unique_ptr<Socket> socket_ip(int type, int protocol);
/// Create a socket in the newest available IP domain and monitor it for data in the main loop
/// WARNING: This function is NOT thread-safe. It must only be called from the main loop
/// as it registers the socket file descriptor with the global Application instance.
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol);
/// Create a socket and monitor it for data in the main loop.
/// Like socket() but also registers the socket with the Application's select() loop.
/// WARNING: These functions are NOT thread-safe. They must only be called from the main loop
/// as they register the socket file descriptor with the global Application instance.
std::unique_ptr<Socket> socket_loop_monitored(int domain, int type, int protocol);
std::unique_ptr<Socket> socket_ip_loop_monitored(int type, int protocol);
/// Set a sockaddr to the specified address and port for the IP version used by socket_ip().
socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port);

View File

@@ -474,11 +474,10 @@ class Application {
Scheduler scheduler;
/// Register a socket file descriptor to be monitored for read events
/// WARNING: This function is NOT thread-safe. It must only be called from the main loop.
/// Register/unregister a socket file descriptor to be monitored for read events.
/// These functions update the fd_set used by select() in the main loop.
/// WARNING: These functions are NOT thread-safe. They must only be called from the main loop.
void register_socket_fd(int fd);
/// Unregister a socket file descriptor
/// WARNING: This function is NOT thread-safe. It must only be called from the main loop.
void unregister_socket_fd(int fd);
/// Check if there's data available on a socket without blocking
/// This function is thread-safe for reading, but should be called after select() has run