thread safe
This commit is contained in:
@@ -43,7 +43,7 @@ void APIServer::setup() {
|
||||
}
|
||||
#endif
|
||||
|
||||
this->socket_ = socket::socket_ip(SOCK_STREAM, 0);
|
||||
this->socket_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
|
||||
if (this->socket_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket");
|
||||
this->mark_failed();
|
||||
@@ -118,7 +118,7 @@ void APIServer::loop() {
|
||||
while (true) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
auto sock = this->socket_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
auto sock = this->socket_->accept_monitored((struct sockaddr *) &source_addr, &addr_len);
|
||||
if (!sock)
|
||||
break;
|
||||
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
|
||||
|
||||
@@ -26,7 +26,7 @@ void ESPHomeOTAComponent::setup() {
|
||||
ota::register_ota_platform(this);
|
||||
#endif
|
||||
|
||||
server_ = socket::socket_ip(SOCK_STREAM, 0);
|
||||
server_ = socket::socket_ip_monitored(SOCK_STREAM, 0); // monitored for incoming connections
|
||||
if (server_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket");
|
||||
this->mark_failed();
|
||||
|
||||
@@ -41,9 +41,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
|
||||
|
||||
class BSDSocketImpl : public Socket {
|
||||
public:
|
||||
BSDSocketImpl(int fd) : fd_(fd) {
|
||||
// Register new socket with the application for select()
|
||||
if (fd_ >= 0) {
|
||||
BSDSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
|
||||
// Register new socket with the application for select() if monitoring requested
|
||||
if (monitor_loop && fd_ >= 0) {
|
||||
App.register_socket_fd(fd_);
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,13 @@ class BSDSocketImpl : public Socket {
|
||||
int fd = ::accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return make_unique<BSDSocketImpl>(fd);
|
||||
return make_unique<BSDSocketImpl>(fd); // Default: not monitored
|
||||
}
|
||||
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
int fd = ::accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return make_unique<BSDSocketImpl>(fd, true); // Monitored for incoming data
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); }
|
||||
int close() override {
|
||||
@@ -151,6 +157,13 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
return std::unique_ptr<Socket>{new BSDSocketImpl(ret)};
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
|
||||
int ret = ::socket(domain, type, protocol);
|
||||
if (ret == -1)
|
||||
return nullptr;
|
||||
return std::unique_ptr<Socket>{new BSDSocketImpl(ret, true)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
|
||||
@@ -606,6 +606,11 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
return std::unique_ptr<Socket>{sock};
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
|
||||
// LWIPRawImpl doesn't use file descriptors, so monitoring is not applicable
|
||||
return socket(domain, type, protocol);
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
|
||||
@@ -34,9 +34,9 @@ std::string format_sockaddr(const struct sockaddr_storage &storage) {
|
||||
|
||||
class LwIPSocketImpl : public Socket {
|
||||
public:
|
||||
LwIPSocketImpl(int fd) : fd_(fd) {
|
||||
// Register new socket with the application for select()
|
||||
if (fd_ >= 0) {
|
||||
LwIPSocketImpl(int fd, bool monitor_loop = false) : fd_(fd) {
|
||||
// Register new socket with the application for select() if monitoring requested
|
||||
if (monitor_loop && fd_ >= 0) {
|
||||
App.register_socket_fd(fd_);
|
||||
}
|
||||
}
|
||||
@@ -50,7 +50,13 @@ class LwIPSocketImpl : public Socket {
|
||||
int fd = lwip_accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return make_unique<LwIPSocketImpl>(fd);
|
||||
return make_unique<LwIPSocketImpl>(fd); // Default: not monitored
|
||||
}
|
||||
std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
int fd = lwip_accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return make_unique<LwIPSocketImpl>(fd, true); // Monitored for incoming data
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return lwip_bind(fd_, addr, addrlen); }
|
||||
int close() override {
|
||||
@@ -123,6 +129,13 @@ std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
return std::unique_ptr<Socket>{new LwIPSocketImpl(ret)};
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol) {
|
||||
int ret = lwip_socket(domain, type, protocol);
|
||||
if (ret == -1)
|
||||
return nullptr;
|
||||
return std::unique_ptr<Socket>{new LwIPSocketImpl(ret, true)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
|
||||
@@ -18,6 +18,14 @@ std::unique_ptr<Socket> socket_ip(int type, int protocol) {
|
||||
#endif /* USE_NETWORK_IPV6 */
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol) {
|
||||
#if USE_NETWORK_IPV6
|
||||
return socket_monitored(AF_INET6, type, protocol);
|
||||
#else
|
||||
return socket_monitored(AF_INET, type, protocol);
|
||||
#endif /* USE_NETWORK_IPV6 */
|
||||
}
|
||||
|
||||
socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port) {
|
||||
#if USE_NETWORK_IPV6
|
||||
if (ip_address.find(':') != std::string::npos) {
|
||||
|
||||
@@ -17,6 +17,11 @@ class Socket {
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
|
||||
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
/// Accept a connection and optionally monitor it in the main loop
|
||||
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
|
||||
virtual std::unique_ptr<Socket> accept_monitored(struct sockaddr *addr, socklen_t *addrlen) {
|
||||
return accept(addr, addrlen); // Default implementation for backward compatibility
|
||||
}
|
||||
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int close() = 0;
|
||||
// not supported yet:
|
||||
@@ -52,9 +57,17 @@ class Socket {
|
||||
/// Create a socket of the given domain, type and protocol.
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
|
||||
|
||||
/// Create a socket and monitor it for data in the main loop
|
||||
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
|
||||
std::unique_ptr<Socket> socket_monitored(int domain, int type, int protocol);
|
||||
|
||||
/// Create a socket in the newest available IP domain (IPv6 or IPv4) of the given type and protocol.
|
||||
std::unique_ptr<Socket> socket_ip(int type, int protocol);
|
||||
|
||||
/// Create a socket in the newest available IP domain and monitor it for data in the main loop
|
||||
/// NOTE: Setting monitor_loop is NOT thread-safe and must only be called from the main loop
|
||||
std::unique_ptr<Socket> socket_ip_monitored(int type, int protocol);
|
||||
|
||||
/// Set a sockaddr to the specified address and port for the IP version used by socket_ip().
|
||||
socklen_t set_sockaddr(struct sockaddr *addr, socklen_t addrlen, const std::string &ip_address, uint16_t port);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user