thread safe

This commit is contained in:
J. Nick Koston
2025-05-27 09:59:58 -05:00
parent 281194738c
commit 5249736855
7 changed files with 63 additions and 11 deletions

View File

@@ -43,7 +43,7 @@ void APIServer::setup() {
}
#endif
this->socket_ = socket::socket_ip(SOCK_STREAM, 0);
this->socket_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (this->socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket");
this->mark_failed();
@@ -118,7 +118,7 @@ void APIServer::loop() {
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = this->socket_->accept((struct sockaddr *) &source_addr, &addr_len);
auto sock = this->socket_->accept_monitored((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());

View File

@@ -26,7 +26,7 @@ void ESPHomeOTAComponent::setup() {
ota::register_ota_platform(this);
#endif
server_ = socket::socket_ip(SOCK_STREAM, 0);
server_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (server_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket");
this->mark_failed();

View File

@@ -41,9 +41,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
class BSDSocketImpl : public Socket {
public:
BSDSocketImpl(int fd) : fd_(fd) {
// Register new socket with the application for select()
if (fd_ >= 0) {
BSDSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
// Register new socket with the application for select() if monitoring requested
if (monitor_loop && fd_ >= 0) {
App.register_socket_fd(fd_);
}
}
@@ -57,7 +57,13 @@ class BSDSocketImpl : public Socket {
int fd = ::accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return make_unique<BSDSocketImpl>(fd);
return make_unique<BSDSocketImpl>(fd); // Default: not monitored
}
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = ::accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return make_unique<BSDSocketImpl>(fd, true); // Monitored for incoming data
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); }
int close() override {
@@ -151,6 +157,13 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{new BSDSocketImpl(ret)};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
int ret = ::socket(domain, type, protocol);
if (ret == -1)
return nullptr;
return std::unique_ptr<Socket>{new BSDSocketImpl(ret, true)};
}
} // namespace socket
} // namespace esphome

View File

@@ -606,6 +606,11 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{sock};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
// LWIPRawImpl doesn't use file descriptors, so monitoring is not applicable
return socket(domain, type, protocol);
}
} // namespace socket
} // namespace esphome

View File

@@ -34,9 +34,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
class LwIPSocketImpl : public Socket {
public:
LwIPSocketImpl(int fd) : fd_(fd) {
// Register new socket with the application for select()
if (fd_ >= 0) {
LwIPSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
// Register new socket with the application for select() if monitoring requested
if (monitor_loop && fd_ >= 0) {
App.register_socket_fd(fd_);
}
}
@@ -50,7 +50,13 @@ class LwIPSocketImpl : public Socket {
int fd = lwip_accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return make_unique<LwIPSocketImpl>(fd);
return make_unique<LwIPSocketImpl>(fd); // Default: not monitored
}
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = lwip_accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return make_unique<LwIPSocketImpl>(fd, true); // Monitored for incoming data
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return lwip_bind(fd_, addr, addrlen); }
int close() override {
@@ -123,6 +129,13 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
return std::unique_ptr<Socket>{new LwIPSocketImpl(ret)};
}
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
int ret = lwip_socket(domain, type, protocol);
if (ret == -1)
return nullptr;
return std::unique_ptr<Socket>{new LwIPSocketImpl(ret, true)};
}
} // namespace socket
} // namespace esphome

View File

@@ -18,6 +18,14 @@ std::unique_ptr<Socket> socket_ip(int type, int protocol) {
#endif /* USE_NETWORK_IPV6 */
}
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol) {
#if USE_NETWORK_IPV6
return socket_monitored(AF_INET6, type, protocol);
#else
return socket_monitored(AF_INET, type, protocol);
#endif /* USE_NETWORK_IPV6 */
}
socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port) {
#if USE_NETWORK_IPV6
if (ip_address.find(':') != std::string::npos) {

View File

@@ -17,6 +17,11 @@ class Socket {
Socket &operator=(const Socket &) = delete;
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
/// Accept a connection and optionally monitor it in the main loop
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
virtual std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) {
return accept(addr, addrlen); // Default implementation for backward compatibility
}
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int close() = 0;
// not supported yet:
@@ -52,9 +57,17 @@ class Socket {
/// Create a socket of the given domain, type and protocol.
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
/// Create a socket and monitor it for data in the main loop
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol);
/// Create a socket in the newest available IP domain (IPv6 or IPv4) of the given type and protocol.
std::unique_ptr<Socket> socket_ip(int type, int protocol);
/// Create a socket in the newest available IP domain and monitor it for data in the main loop
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol);
/// Set a sockaddr to the specified address and port for the IP version used by socket_ip().
socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port);