diff --git a/esphome/components/api/api_server.cpp b/esphome/components/api/api_server.cpp index 4498f47214..addb2f4619 100644 --- a/esphome/components/api/api_server.cpp +++ b/esphome/components/api/api_server.cpp @@ -43,7 +43,7 @@ void APIServer::setup() { } #endif - this->socket_ = socket::socket_ip(SOCK_STREAM, 0); + this->socket_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections if (this->socket_ == nullptr) { ESP_LOGW(TAG, "Could not create socket"); this->mark_failed(); @@ -118,7 +118,7 @@ void APIServer::loop() { while (true) { struct sockaddr_storage source_addr; socklen_t addr_len = sizeof(source_addr); - auto sock = this->socket_->accept((struct sockaddr *) &source_addr, &addr_len); + auto sock = this->socket_->accept_monitored((struct sockaddr *) &source_addr, &addr_len); if (!sock) break; ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str()); diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index b949f4de81..4af6a44a6e 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -26,7 +26,7 @@ void ESPHomeOTAComponent::setup() { ota::register_ota_platform(this); #endif - server_ = socket::socket_ip(SOCK_STREAM, 0); + server_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections if (server_ == nullptr) { ESP_LOGW(TAG, "Could not create socket"); this->mark_failed(); diff --git a/esphome/components/socket/bsd_sockets_impl.cpp b/esphome/components/socket/bsd_sockets_impl.cpp index 77358f462c..51b75a1057 100644 --- a/esphome/components/socket/bsd_sockets_impl.cpp +++ b/esphome/components/socket/bsd_sockets_impl.cpp @@ -41,9 +41,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) { class BSDSocketImpl : public Socket { public: - BSDSocketImpl(int fd) : fd_(fd) { - // Register new socket with the application for select() - if (fd_ >= 0) { + BSDSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) { + // Register new socket with the application for select() if monitoring requested + if (monitor_loop && fd_ >= 0) { App.register_socket_fd(fd_); } } @@ -57,7 +57,13 @@ class BSDSocketImpl : public Socket { int fd = ::accept(fd_, addr, addrlen); if (fd == -1) return {}; - return make_unique(fd); + return make_unique(fd); // Default: not monitored + } + std::unique_ptr accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override { + int fd = ::accept(fd_, addr, addrlen); + if (fd == -1) + return {}; + return make_unique(fd, true); // Monitored for incoming data } int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); } int close() override { @@ -151,6 +157,13 @@ std::unique_ptr socket(int domain, int type, int protocol) { return std::unique_ptr{new BSDSocketImpl(ret)}; } +std::unique_ptr socket_monitored(int domain, int type, int protocol) { + int ret = ::socket(domain, type, protocol); + if (ret == -1) + return nullptr; + return std::unique_ptr{new BSDSocketImpl(ret, true)}; +} + } // namespace socket } // namespace esphome diff --git a/esphome/components/socket/lwip_raw_tcp_impl.cpp b/esphome/components/socket/lwip_raw_tcp_impl.cpp index 1d998902ff..deb5eb1870 100644 --- a/esphome/components/socket/lwip_raw_tcp_impl.cpp +++ b/esphome/components/socket/lwip_raw_tcp_impl.cpp @@ -606,6 +606,11 @@ std::unique_ptr socket(int domain, int type, int protocol) { return std::unique_ptr{sock}; } +std::unique_ptr socket_monitored(int domain, int type, int protocol) { + // LWIPRawImpl doesn't use file descriptors, so monitoring is not applicable + return socket(domain, type, protocol); +} + } // namespace socket } // namespace esphome diff --git a/esphome/components/socket/lwip_sockets_impl.cpp b/esphome/components/socket/lwip_sockets_impl.cpp index 1ed4706cdf..7b946a1dc2 100644 --- a/esphome/components/socket/lwip_sockets_impl.cpp +++ b/esphome/components/socket/lwip_sockets_impl.cpp @@ -34,9 +34,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) { class LwIPSocketImpl : public Socket { public: - LwIPSocketImpl(int fd) : fd_(fd) { - // Register new socket with the application for select() - if (fd_ >= 0) { + LwIPSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) { + // Register new socket with the application for select() if monitoring requested + if (monitor_loop && fd_ >= 0) { App.register_socket_fd(fd_); } } @@ -50,7 +50,13 @@ class LwIPSocketImpl : public Socket { int fd = lwip_accept(fd_, addr, addrlen); if (fd == -1) return {}; - return make_unique(fd); + return make_unique(fd); // Default: not monitored + } + std::unique_ptr accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override { + int fd = lwip_accept(fd_, addr, addrlen); + if (fd == -1) + return {}; + return make_unique(fd, true); // Monitored for incoming data } int bind(const struct sockaddr *addr, socklen_t addrlen) override { return lwip_bind(fd_, addr, addrlen); } int close() override { @@ -123,6 +129,13 @@ std::unique_ptr socket(int domain, int type, int protocol) { return std::unique_ptr{new LwIPSocketImpl(ret)}; } +std::unique_ptr socket_monitored(int domain, int type, int protocol) { + int ret = lwip_socket(domain, type, protocol); + if (ret == -1) + return nullptr; + return std::unique_ptr{new LwIPSocketImpl(ret, true)}; +} + } // namespace socket } // namespace esphome diff --git a/esphome/components/socket/socket.cpp b/esphome/components/socket/socket.cpp index e260fce05e..3554b926fd 100644 --- a/esphome/components/socket/socket.cpp +++ b/esphome/components/socket/socket.cpp @@ -18,6 +18,14 @@ std::unique_ptr socket_ip(int type, int protocol) { #endif /* USE_NETWORK_IPV6 */ } +std::unique_ptr socket_ip_monitored(int type, int protocol) { +#if USE_NETWORK_IPV6 + return socket_monitored(AF_INET6, type, protocol); +#else + return socket_monitored(AF_INET, type, protocol); +#endif /* USE_NETWORK_IPV6 */ +} + socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port) { #if USE_NETWORK_IPV6 if (ip_address.find(':') != std::string::npos) { diff --git a/esphome/components/socket/socket.h b/esphome/components/socket/socket.h index adc1e6b9c1..0b8d2b3f81 100644 --- a/esphome/components/socket/socket.h +++ b/esphome/components/socket/socket.h @@ -17,6 +17,11 @@ class Socket { Socket &operator=(const Socket &) = delete; virtual std::unique_ptr accept(struct sockaddr *addr, socklen_t *addrlen) = 0; + /// Accept a connection and optionally monitor it in the main loop + /// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop + virtual std::unique_ptr accept_monitored(struct sockaddr *addr, socklen_t *addrlen) { + return accept(addr, addrlen); // Default implementation for backward compatibility + } virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0; virtual int close() = 0; // not supported yet: @@ -52,9 +57,17 @@ class Socket { /// Create a socket of the given domain, type and protocol. std::unique_ptr socket(int domain, int type, int protocol); +/// Create a socket and monitor it for data in the main loop +/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop +std::unique_ptr socket_monitored(int domain, int type, int protocol); + /// Create a socket in the newest available IP domain (IPv6 or IPv4) of the given type and protocol. std::unique_ptr socket_ip(int type, int protocol); +/// Create a socket in the newest available IP domain and monitor it for data in the main loop +/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop +std::unique_ptr socket_ip_monitored(int type, int protocol); + /// Set a sockaddr to the specified address and port for the IP version used by socket_ip(). socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port);