From 988b888c6308d3b819918a6ec50a31aca76b8e43 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 19 Dec 2025 11:19:07 -1000 Subject: [PATCH] [ota] Replace std::function callbacks with listener interface (#12167) --- .../components/esp32_ble_tracker/__init__.py | 4 +- .../esp32_ble_tracker/esp32_ble_tracker.cpp | 27 +++--- .../esp32_ble_tracker/esp32_ble_tracker.h | 11 +++ .../components/esphome/ota/ota_esphome.cpp | 20 ++-- .../http_request/ota/ota_http_request.cpp | 22 ++--- .../http_request/ota/ota_http_request.h | 1 - .../http_request/update/__init__.py | 4 +- .../update/http_request_update.cpp | 26 +++--- .../http_request/update/http_request_update.h | 4 +- .../components/micro_wake_word/__init__.py | 4 +- .../micro_wake_word/micro_wake_word.cpp | 21 +++-- .../micro_wake_word/micro_wake_word.h | 16 +++- esphome/components/ota/__init__.py | 22 ++++- esphome/components/ota/automation.h | 92 +++++++++---------- esphome/components/ota/ota_backend.cpp | 9 +- esphome/components/ota/ota_backend.h | 91 ++++++++++-------- .../speaker/media_player/__init__.py | 4 +- .../media_player/speaker_media_player.cpp | 42 +++++---- .../media_player/speaker_media_player.h | 18 +++- .../web_server/ota/ota_web_server.cpp | 40 ++++---- esphome/core/defines.h | 2 +- 21 files changed, 274 insertions(+), 206 deletions(-) diff --git a/esphome/components/esp32_ble_tracker/__init__.py b/esphome/components/esp32_ble_tracker/__init__.py index 4e25434aad..37e74672ed 100644 --- a/esphome/components/esp32_ble_tracker/__init__.py +++ b/esphome/components/esp32_ble_tracker/__init__.py @@ -5,7 +5,7 @@ import logging from esphome import automation import esphome.codegen as cg -from esphome.components import esp32_ble +from esphome.components import esp32_ble, ota from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32_ble import ( IDF_MAX_CONNECTIONS, @@ -328,7 +328,7 @@ async def to_code(config): # Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now # configured in esp32_ble component based on max_connections setting - cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts + ota.request_ota_state_listeners() # To be notified when an OTA update starts cg.add_define("USE_ESP32_BLE_CLIENT") CORE.add_job(_add_ble_features) diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index cb83eb5a0d..47da2e3570 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -71,21 +71,24 @@ void ESP32BLETracker::setup() { global_esp32_ble_tracker = this; -#ifdef USE_OTA - ota::get_global_ota_callback()->add_on_state_callback( - [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { - if (state == ota::OTA_STARTED) { - this->stop_scan(); -#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT - for (auto *client : this->clients_) { - client->disconnect(); - } -#endif - } - }); +#ifdef USE_OTA_STATE_LISTENER + ota::get_global_ota_callback()->add_global_state_listener(this); #endif } +#ifdef USE_OTA_STATE_LISTENER +void ESP32BLETracker::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->stop_scan(); +#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT + for (auto *client : this->clients_) { + client->disconnect(); + } +#endif + } +} +#endif + void ESP32BLETracker::loop() { if (!this->parent_->is_active()) { this->ble_was_disabled_ = true; diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 92d13a62ad..b64e36279c 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -22,6 +22,10 @@ #include "esphome/components/esp32_ble/ble_uuid.h" #include "esphome/components/esp32_ble/ble_scan_result.h" +#ifdef USE_OTA_STATE_LISTENER +#include "esphome/components/ota/ota_backend.h" +#endif + namespace esphome::esp32_ble_tracker { using namespace esp32_ble; @@ -241,6 +245,9 @@ class ESP32BLETracker : public Component, public GAPScanEventHandler, public GATTcEventHandler, public BLEStatusEventHandler, +#ifdef USE_OTA_STATE_LISTENER + public ota::OTAGlobalStateListener, +#endif public Parented { public: void set_scan_duration(uint32_t scan_duration) { scan_duration_ = scan_duration; } @@ -274,6 +281,10 @@ class ESP32BLETracker : public Component, void gap_scan_event_handler(const BLEScanResult &scan_result) override; void ble_before_disabled_event_handler() override; +#ifdef USE_OTA_STATE_LISTENER + void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override; +#endif + /// Add a listener for scanner state changes void add_scanner_state_listener(BLEScannerStateListener *listener) { this->scanner_state_listeners_.push_back(listener); diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 6cfd543553..b589a6119f 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -41,10 +41,6 @@ static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32 #endif // USE_OTA_PASSWORD void ESPHomeOTAComponent::setup() { -#ifdef USE_OTA_STATE_CALLBACK - ota::register_ota_platform(this); -#endif - this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections if (this->server_ == nullptr) { this->log_socket_error_(LOG_STR("creation")); @@ -297,8 +293,8 @@ void ESPHomeOTAComponent::handle_data_() { // accidentally trigger the update process. this->log_start_(LOG_STR("update")); this->status_set_warning(); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_STARTED, 0.0f, 0); #endif // This will block for a few seconds as it locks flash @@ -357,8 +353,8 @@ void ESPHomeOTAComponent::handle_data_() { last_progress = now; float percentage = (total * 100.0f) / ota_size; ESP_LOGD(TAG, "Progress: %0.1f%%", percentage); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0); #endif // feed watchdog and give other tasks a chance to run this->yield_and_feed_watchdog_(); @@ -387,8 +383,8 @@ void ESPHomeOTAComponent::handle_data_() { delay(10); ESP_LOGI(TAG, "Update complete"); this->status_clear_warning(); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_COMPLETED, 100.0f, 0); #endif delay(100); // NOLINT App.safe_reboot(); @@ -402,8 +398,8 @@ error: } this->status_momentary_error("err", 5000); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast(error_code)); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif } diff --git a/esphome/components/http_request/ota/ota_http_request.cpp b/esphome/components/http_request/ota/ota_http_request.cpp index b257518e06..058579752e 100644 --- a/esphome/components/http_request/ota/ota_http_request.cpp +++ b/esphome/components/http_request/ota/ota_http_request.cpp @@ -16,12 +16,6 @@ namespace http_request { static const char *const TAG = "http_request.ota"; -void OtaHttpRequestComponent::setup() { -#ifdef USE_OTA_STATE_CALLBACK - ota::register_ota_platform(this); -#endif -} - void OtaHttpRequestComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air updates via HTTP request"); }; void OtaHttpRequestComponent::set_md5_url(const std::string &url) { @@ -48,24 +42,24 @@ void OtaHttpRequestComponent::flash() { } ESP_LOGI(TAG, "Starting update"); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_STARTED, 0.0f, 0); #endif auto ota_status = this->do_ota_(); switch (ota_status) { case ota::OTA_RESPONSE_OK: -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, ota_status); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_COMPLETED, 100.0f, ota_status); #endif delay(10); App.safe_reboot(); break; default: -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_ERROR, 0.0f, ota_status); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_ERROR, 0.0f, ota_status); #endif this->md5_computed_.clear(); // will be reset at next attempt this->md5_expected_.clear(); // will be reset at next attempt @@ -165,8 +159,8 @@ uint8_t OtaHttpRequestComponent::do_ota_() { last_progress = now; float percentage = container->get_bytes_read() * 100.0f / container->content_length; ESP_LOGD(TAG, "Progress: %0.1f%%", percentage); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); +#ifdef USE_OTA_STATE_LISTENER + this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0); #endif } } // while diff --git a/esphome/components/http_request/ota/ota_http_request.h b/esphome/components/http_request/ota/ota_http_request.h index 6a86b4ab43..8735189e99 100644 --- a/esphome/components/http_request/ota/ota_http_request.h +++ b/esphome/components/http_request/ota/ota_http_request.h @@ -24,7 +24,6 @@ enum OtaHttpRequestError : uint8_t { class OtaHttpRequestComponent : public ota::OTAComponent, public Parented { public: - void setup() override; void dump_config() override; float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } diff --git a/esphome/components/http_request/update/__init__.py b/esphome/components/http_request/update/__init__.py index abb4b2a430..d84d80109a 100644 --- a/esphome/components/http_request/update/__init__.py +++ b/esphome/components/http_request/update/__init__.py @@ -1,5 +1,5 @@ import esphome.codegen as cg -from esphome.components import update +from esphome.components import ota, update import esphome.config_validation as cv from esphome.const import CONF_SOURCE @@ -38,6 +38,6 @@ async def to_code(config): cg.add(var.set_source_url(config[CONF_SOURCE])) - cg.add_define("USE_OTA_STATE_CALLBACK") + ota.request_ota_state_listeners() await cg.register_component(var, config) diff --git a/esphome/components/http_request/update/http_request_update.cpp b/esphome/components/http_request/update/http_request_update.cpp index 22cad625d1..a9392ad736 100644 --- a/esphome/components/http_request/update/http_request_update.cpp +++ b/esphome/components/http_request/update/http_request_update.cpp @@ -20,19 +20,19 @@ static const char *const TAG = "http_request.update"; static const size_t MAX_READ_SIZE = 256; -void HttpRequestUpdate::setup() { - this->ota_parent_->add_on_state_callback([this](ota::OTAState state, float progress, uint8_t err) { - if (state == ota::OTAState::OTA_IN_PROGRESS) { - this->state_ = update::UPDATE_STATE_INSTALLING; - this->update_info_.has_progress = true; - this->update_info_.progress = progress; - this->publish_state(); - } else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) { - this->state_ = update::UPDATE_STATE_AVAILABLE; - this->status_set_error(LOG_STR("Failed to install firmware")); - this->publish_state(); - } - }); +void HttpRequestUpdate::setup() { this->ota_parent_->add_state_listener(this); } + +void HttpRequestUpdate::on_ota_state(ota::OTAState state, float progress, uint8_t error) { + if (state == ota::OTAState::OTA_IN_PROGRESS) { + this->state_ = update::UPDATE_STATE_INSTALLING; + this->update_info_.has_progress = true; + this->update_info_.progress = progress; + this->publish_state(); + } else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) { + this->state_ = update::UPDATE_STATE_AVAILABLE; + this->status_set_error(LOG_STR("Failed to install firmware")); + this->publish_state(); + } } void HttpRequestUpdate::update() { diff --git a/esphome/components/http_request/update/http_request_update.h b/esphome/components/http_request/update/http_request_update.h index e05fdb0cc2..cf34ace18e 100644 --- a/esphome/components/http_request/update/http_request_update.h +++ b/esphome/components/http_request/update/http_request_update.h @@ -14,7 +14,7 @@ namespace esphome { namespace http_request { -class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent { +class HttpRequestUpdate final : public update::UpdateEntity, public PollingComponent, public ota::OTAStateListener { public: void setup() override; void update() override; @@ -29,6 +29,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent { float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } + void on_ota_state(ota::OTAState state, float progress, uint8_t error) override; + protected: HttpRequestComponent *request_parent_; OtaHttpRequestComponent *ota_parent_; diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 575fb97799..0d478f749b 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin from esphome import automation, external_files, git from esphome.automation import register_action, register_condition import esphome.codegen as cg -from esphome.components import esp32, microphone, socket +from esphome.components import esp32, microphone, ota, socket import esphome.config_validation as cv from esphome.const import ( CONF_FILE, @@ -452,7 +452,7 @@ async def to_code(config): cg.add(var.set_microphone_source(mic_source)) cg.add_define("USE_MICRO_WAKE_WORD") - cg.add_define("USE_OTA_STATE_CALLBACK") + ota.request_ota_state_listeners() esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1") diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index ec8fa34da4..b8377ead38 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -119,18 +119,21 @@ void MicroWakeWord::setup() { } }); -#ifdef USE_OTA - ota::get_global_ota_callback()->add_on_state_callback( - [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { - if (state == ota::OTA_STARTED) { - this->suspend_task_(); - } else if (state == ota::OTA_ERROR) { - this->resume_task_(); - } - }); +#ifdef USE_OTA_STATE_LISTENER + ota::get_global_ota_callback()->add_global_state_listener(this); #endif } +#ifdef USE_OTA_STATE_LISTENER +void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->suspend_task_(); + } else if (state == ota::OTA_ERROR) { + this->resume_task_(); + } +} +#endif + void MicroWakeWord::inference_task(void *params) { MicroWakeWord *this_mww = (MicroWakeWord *) params; diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index d46c40e48b..84261eaa5b 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -9,8 +9,13 @@ #include "esphome/core/automation.h" #include "esphome/core/component.h" +#include "esphome/core/defines.h" #include "esphome/core/ring_buffer.h" +#ifdef USE_OTA_STATE_LISTENER +#include "esphome/components/ota/ota_backend.h" +#endif + #include #include @@ -26,13 +31,22 @@ enum State { STOPPED, }; -class MicroWakeWord : public Component { +class MicroWakeWord : public Component +#ifdef USE_OTA_STATE_LISTENER + , + public ota::OTAGlobalStateListener +#endif +{ public: void setup() override; void loop() override; float get_setup_priority() const override; void dump_config() override; +#ifdef USE_OTA_STATE_LISTENER + void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override; +#endif + void start(); void stop(); diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index be1b6da241..8bed9cee42 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -13,6 +13,8 @@ from esphome.const import ( from esphome.core import CORE, coroutine_with_priority from esphome.coroutine import CoroPriority +OTA_STATE_LISTENER_KEY = "ota_state_listener" + CODEOWNERS = ["@esphome/core"] AUTO_LOAD = ["md5", "safe_mode"] @@ -86,6 +88,7 @@ BASE_OTA_SCHEMA = cv.Schema( @coroutine_with_priority(CoroPriority.OTA_UPDATES) async def to_code(config): cg.add_define("USE_OTA") + CORE.add_job(final_step) if CORE.is_rp2040 and CORE.using_arduino: cg.add_library("Updater", None) @@ -119,7 +122,24 @@ async def ota_to_code(var, config): await automation.build_automation(trigger, [(cg.uint8, "x")], conf) use_state_callback = True if use_state_callback: - cg.add_define("USE_OTA_STATE_CALLBACK") + request_ota_state_listeners() + + +def request_ota_state_listeners() -> None: + """Request that OTA state listeners be compiled in. + + Components that need to be notified about OTA state changes (start, progress, + complete, error) should call this function during their code generation. + This enables the add_state_listener() API on OTAComponent. + """ + CORE.data[OTA_STATE_LISTENER_KEY] = True + + +@coroutine_with_priority(CoroPriority.FINAL) +async def final_step(): + """Final code generation step to configure optional OTA features.""" + if CORE.data.get(OTA_STATE_LISTENER_KEY, False): + cg.add_define("USE_OTA_STATE_LISTENER") FILTER_SOURCE_FILES = filter_source_files_from_platform( diff --git a/esphome/components/ota/automation.h b/esphome/components/ota/automation.h index 7e1a60f3ce..92c0050ba0 100644 --- a/esphome/components/ota/automation.h +++ b/esphome/components/ota/automation.h @@ -1,5 +1,5 @@ #pragma once -#ifdef USE_OTA_STATE_CALLBACK +#ifdef USE_OTA_STATE_LISTENER #include "ota_backend.h" #include "esphome/core/automation.h" @@ -7,70 +7,64 @@ namespace esphome { namespace ota { -class OTAStateChangeTrigger : public Trigger { +class OTAStateChangeTrigger final : public Trigger, public OTAStateListener { public: - explicit OTAStateChangeTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (!parent->is_failed()) { - trigger(state); - } - }); + explicit OTAStateChangeTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); } + + void on_ota_state(OTAState state, float progress, uint8_t error) override { + if (!this->parent_->is_failed()) { + this->trigger(state); + } } + + protected: + OTAComponent *parent_; }; -class OTAStartTrigger : public Trigger<> { +template class OTAStateTrigger final : public Trigger<>, public OTAStateListener { public: - explicit OTAStartTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (state == OTA_STARTED && !parent->is_failed()) { - trigger(); - } - }); + explicit OTAStateTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); } + + void on_ota_state(OTAState state, float progress, uint8_t error) override { + if (state == State && !this->parent_->is_failed()) { + this->trigger(); + } } + + protected: + OTAComponent *parent_; }; -class OTAProgressTrigger : public Trigger { +using OTAStartTrigger = OTAStateTrigger; +using OTAEndTrigger = OTAStateTrigger; +using OTAAbortTrigger = OTAStateTrigger; + +class OTAProgressTrigger final : public Trigger, public OTAStateListener { public: - explicit OTAProgressTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (state == OTA_IN_PROGRESS && !parent->is_failed()) { - trigger(progress); - } - }); + explicit OTAProgressTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); } + + void on_ota_state(OTAState state, float progress, uint8_t error) override { + if (state == OTA_IN_PROGRESS && !this->parent_->is_failed()) { + this->trigger(progress); + } } + + protected: + OTAComponent *parent_; }; -class OTAEndTrigger : public Trigger<> { +class OTAErrorTrigger final : public Trigger, public OTAStateListener { public: - explicit OTAEndTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (state == OTA_COMPLETED && !parent->is_failed()) { - trigger(); - } - }); - } -}; + explicit OTAErrorTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); } -class OTAAbortTrigger : public Trigger<> { - public: - explicit OTAAbortTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (state == OTA_ABORT && !parent->is_failed()) { - trigger(); - } - }); + void on_ota_state(OTAState state, float progress, uint8_t error) override { + if (state == OTA_ERROR && !this->parent_->is_failed()) { + this->trigger(error); + } } -}; -class OTAErrorTrigger : public Trigger { - public: - explicit OTAErrorTrigger(OTAComponent *parent) { - parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { - if (state == OTA_ERROR && !parent->is_failed()) { - trigger(error); - } - }); - } + protected: + OTAComponent *parent_; }; } // namespace ota diff --git a/esphome/components/ota/ota_backend.cpp b/esphome/components/ota/ota_backend.cpp index 30de4ec4b3..8fb9f67214 100644 --- a/esphome/components/ota/ota_backend.cpp +++ b/esphome/components/ota/ota_backend.cpp @@ -3,7 +3,7 @@ namespace esphome { namespace ota { -#ifdef USE_OTA_STATE_CALLBACK +#ifdef USE_OTA_STATE_LISTENER OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) OTAGlobalCallback *get_global_ota_callback() { @@ -13,7 +13,12 @@ OTAGlobalCallback *get_global_ota_callback() { return global_ota_callback; } -void register_ota_platform(OTAComponent *ota_caller) { get_global_ota_callback()->register_ota(ota_caller); } +void OTAComponent::notify_state_(OTAState state, float progress, uint8_t error) { + for (auto *listener : this->state_listeners_) { + listener->on_ota_state(state, progress, error); + } + get_global_ota_callback()->notify_ota_state(state, progress, error, this); +} #endif } // namespace ota diff --git a/esphome/components/ota/ota_backend.h b/esphome/components/ota/ota_backend.h index 64ee0b9f7c..e03afd4fc6 100644 --- a/esphome/components/ota/ota_backend.h +++ b/esphome/components/ota/ota_backend.h @@ -4,8 +4,8 @@ #include "esphome/core/defines.h" #include "esphome/core/helpers.h" -#ifdef USE_OTA_STATE_CALLBACK -#include "esphome/core/automation.h" +#ifdef USE_OTA_STATE_LISTENER +#include #endif namespace esphome { @@ -60,62 +60,75 @@ class OTABackend { virtual bool supports_compression() = 0; }; -class OTAComponent : public Component { -#ifdef USE_OTA_STATE_CALLBACK +/** Listener interface for OTA state changes. + * + * Components can implement this interface to receive OTA state updates + * without the overhead of std::function callbacks. + */ +class OTAStateListener { public: - void add_on_state_callback(std::function &&callback) { - this->state_callback_.add(std::move(callback)); - } + virtual ~OTAStateListener() = default; + virtual void on_ota_state(OTAState state, float progress, uint8_t error) = 0; +}; + +class OTAComponent : public Component { +#ifdef USE_OTA_STATE_LISTENER + public: + void add_state_listener(OTAStateListener *listener) { this->state_listeners_.push_back(listener); } protected: - /** Extended callback manager with deferred call support. + void notify_state_(OTAState state, float progress, uint8_t error); + + /** Notify state with deferral to main loop (for thread safety). * - * This adds a call_deferred() method for thread-safe execution from other tasks. + * This should be used by OTA implementations that run in separate tasks + * (like web_server OTA) to ensure listeners execute in the main loop. */ - class StateCallbackManager : public CallbackManager { - public: - StateCallbackManager(OTAComponent *component) : component_(component) {} + void notify_state_deferred_(OTAState state, float progress, uint8_t error) { + this->defer([this, state, progress, error]() { this->notify_state_(state, progress, error); }); + } - /** Call callbacks with deferral to main loop (for thread safety). - * - * This should be used by OTA implementations that run in separate tasks - * (like web_server OTA) to ensure callbacks execute in the main loop. - */ - void call_deferred(ota::OTAState state, float progress, uint8_t error) { - component_->defer([this, state, progress, error]() { this->call(state, progress, error); }); - } - - private: - OTAComponent *component_; - }; - - StateCallbackManager state_callback_{this}; + std::vector state_listeners_; #endif }; -#ifdef USE_OTA_STATE_CALLBACK +#ifdef USE_OTA_STATE_LISTENER + +/** Listener interface for global OTA state changes (includes OTA component pointer). + * + * Used by OTAGlobalCallback to aggregate state from multiple OTA components. + */ +class OTAGlobalStateListener { + public: + virtual ~OTAGlobalStateListener() = default; + virtual void on_ota_global_state(OTAState state, float progress, uint8_t error, OTAComponent *component) = 0; +}; + +/** Global callback that aggregates OTA state from all OTA components. + * + * OTA components call notify_ota_state() directly with their pointer, + * which forwards the event to all registered global listeners. + */ class OTAGlobalCallback { public: - void register_ota(OTAComponent *ota_caller) { - ota_caller->add_on_state_callback([this, ota_caller](OTAState state, float progress, uint8_t error) { - this->state_callback_.call(state, progress, error, ota_caller); - }); - } - void add_on_state_callback(std::function &&callback) { - this->state_callback_.add(std::move(callback)); + void add_global_state_listener(OTAGlobalStateListener *listener) { this->global_listeners_.push_back(listener); } + + void notify_ota_state(OTAState state, float progress, uint8_t error, OTAComponent *component) { + for (auto *listener : this->global_listeners_) { + listener->on_ota_global_state(state, progress, error, component); + } } protected: - CallbackManager state_callback_{}; + std::vector global_listeners_; }; OTAGlobalCallback *get_global_ota_callback(); -void register_ota_platform(OTAComponent *ota_caller); // OTA implementations should use: -// - state_callback_.call() when already in main loop (e.g., esphome OTA) -// - state_callback_.call_deferred() when in separate task (e.g., web_server OTA) -// This ensures proper callback execution in all contexts. +// - notify_state_() when already in main loop (e.g., esphome OTA) +// - notify_state_deferred_() when in separate task (e.g., web_server OTA) +// This ensures proper listener execution in all contexts. #endif std::unique_ptr make_ota_backend(); diff --git a/esphome/components/speaker/media_player/__init__.py b/esphome/components/speaker/media_player/__init__.py index 062bff92f8..4ca57f2c4a 100644 --- a/esphome/components/speaker/media_player/__init__.py +++ b/esphome/components/speaker/media_player/__init__.py @@ -6,7 +6,7 @@ from pathlib import Path from esphome import automation, external_files import esphome.codegen as cg -from esphome.components import audio, esp32, media_player, network, psram, speaker +from esphome.components import audio, esp32, media_player, network, ota, psram, speaker import esphome.config_validation as cv from esphome.const import ( CONF_BUFFER_SIZE, @@ -342,7 +342,7 @@ async def to_code(config): var = await media_player.new_media_player(config) await cg.register_component(var, config) - cg.add_define("USE_OTA_STATE_CALLBACK") + ota.request_ota_state_listeners() cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE])) diff --git a/esphome/components/speaker/media_player/speaker_media_player.cpp b/esphome/components/speaker/media_player/speaker_media_player.cpp index b45a78010a..5722aab195 100644 --- a/esphome/components/speaker/media_player/speaker_media_player.cpp +++ b/esphome/components/speaker/media_player/speaker_media_player.cpp @@ -66,25 +66,8 @@ void SpeakerMediaPlayer::setup() { this->set_mute_state_(false); } -#ifdef USE_OTA - ota::get_global_ota_callback()->add_on_state_callback( - [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { - if (state == ota::OTA_STARTED) { - if (this->media_pipeline_ != nullptr) { - this->media_pipeline_->suspend_tasks(); - } - if (this->announcement_pipeline_ != nullptr) { - this->announcement_pipeline_->suspend_tasks(); - } - } else if (state == ota::OTA_ERROR) { - if (this->media_pipeline_ != nullptr) { - this->media_pipeline_->resume_tasks(); - } - if (this->announcement_pipeline_ != nullptr) { - this->announcement_pipeline_->resume_tasks(); - } - } - }); +#ifdef USE_OTA_STATE_LISTENER + ota::get_global_ota_callback()->add_global_state_listener(this); #endif this->announcement_pipeline_ = @@ -300,6 +283,27 @@ void SpeakerMediaPlayer::watch_media_commands_() { } } +#ifdef USE_OTA_STATE_LISTENER +void SpeakerMediaPlayer::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, + ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + if (this->media_pipeline_ != nullptr) { + this->media_pipeline_->suspend_tasks(); + } + if (this->announcement_pipeline_ != nullptr) { + this->announcement_pipeline_->suspend_tasks(); + } + } else if (state == ota::OTA_ERROR) { + if (this->media_pipeline_ != nullptr) { + this->media_pipeline_->resume_tasks(); + } + if (this->announcement_pipeline_ != nullptr) { + this->announcement_pipeline_->resume_tasks(); + } + } +} +#endif + void SpeakerMediaPlayer::loop() { this->watch_media_commands_(); diff --git a/esphome/components/speaker/media_player/speaker_media_player.h b/esphome/components/speaker/media_player/speaker_media_player.h index 967772d1a5..f1c564b63d 100644 --- a/esphome/components/speaker/media_player/speaker_media_player.h +++ b/esphome/components/speaker/media_player/speaker_media_player.h @@ -5,14 +5,18 @@ #include "audio_pipeline.h" #include "esphome/components/audio/audio.h" - #include "esphome/components/media_player/media_player.h" #include "esphome/components/speaker/speaker.h" #include "esphome/core/automation.h" #include "esphome/core/component.h" +#include "esphome/core/defines.h" #include "esphome/core/preferences.h" +#ifdef USE_OTA_STATE_LISTENER +#include "esphome/components/ota/ota_backend.h" +#endif + #include #include #include @@ -39,12 +43,22 @@ struct VolumeRestoreState { bool is_muted; }; -class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer { +class SpeakerMediaPlayer : public Component, + public media_player::MediaPlayer +#ifdef USE_OTA_STATE_LISTENER + , + public ota::OTAGlobalStateListener +#endif +{ public: float get_setup_priority() const override { return esphome::setup_priority::PROCESSOR; } void setup() override; void loop() override; +#ifdef USE_OTA_STATE_LISTENER + void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override; +#endif + // MediaPlayer implementations media_player::MediaPlayerTraits get_traits() override; bool is_muted() const override { return this->is_muted_; } diff --git a/esphome/components/web_server/ota/ota_web_server.cpp b/esphome/components/web_server/ota/ota_web_server.cpp index 7929f3647f..f612aa056c 100644 --- a/esphome/components/web_server/ota/ota_web_server.cpp +++ b/esphome/components/web_server/ota/ota_web_server.cpp @@ -84,9 +84,9 @@ void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) { } else { ESP_LOGD(TAG, "OTA in progress: %" PRIu32 " bytes read", this->ota_read_length_); } -#ifdef USE_OTA_STATE_CALLBACK - // Report progress - use call_deferred since we're in web server task - this->parent_->state_callback_.call_deferred(ota::OTA_IN_PROGRESS, percentage, 0); +#ifdef USE_OTA_STATE_LISTENER + // Report progress - use notify_state_deferred_ since we're in web server task + this->parent_->notify_state_deferred_(ota::OTA_IN_PROGRESS, percentage, 0); #endif this->last_ota_progress_ = now; } @@ -114,9 +114,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf // Initialize OTA on first call this->ota_init_(filename.c_str()); -#ifdef USE_OTA_STATE_CALLBACK - // Notify OTA started - use call_deferred since we're in web server task - this->parent_->state_callback_.call_deferred(ota::OTA_STARTED, 0.0f, 0); +#ifdef USE_OTA_STATE_LISTENER + // Notify OTA started - use notify_state_deferred_ since we're in web server task + this->parent_->notify_state_deferred_(ota::OTA_STARTED, 0.0f, 0); #endif // Platform-specific pre-initialization @@ -134,9 +134,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf this->ota_backend_ = ota::make_ota_backend(); if (!this->ota_backend_) { ESP_LOGE(TAG, "Failed to create OTA backend"); -#ifdef USE_OTA_STATE_CALLBACK - this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, - static_cast(ota::OTA_RESPONSE_ERROR_UNKNOWN)); +#ifdef USE_OTA_STATE_LISTENER + this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, + static_cast(ota::OTA_RESPONSE_ERROR_UNKNOWN)); #endif return; } @@ -148,8 +148,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf if (error_code != ota::OTA_RESPONSE_OK) { ESP_LOGE(TAG, "OTA begin failed: %d", error_code); this->ota_backend_.reset(); -#ifdef USE_OTA_STATE_CALLBACK - this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast(error_code)); +#ifdef USE_OTA_STATE_LISTENER + this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif return; } @@ -166,8 +166,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf ESP_LOGE(TAG, "OTA write failed: %d", error_code); this->ota_backend_->abort(); this->ota_backend_.reset(); -#ifdef USE_OTA_STATE_CALLBACK - this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast(error_code)); +#ifdef USE_OTA_STATE_LISTENER + this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif return; } @@ -186,15 +186,15 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf error_code = this->ota_backend_->end(); if (error_code == ota::OTA_RESPONSE_OK) { this->ota_success_ = true; -#ifdef USE_OTA_STATE_CALLBACK - // Report completion before reboot - use call_deferred since we're in web server task - this->parent_->state_callback_.call_deferred(ota::OTA_COMPLETED, 100.0f, 0); +#ifdef USE_OTA_STATE_LISTENER + // Report completion before reboot - use notify_state_deferred_ since we're in web server task + this->parent_->notify_state_deferred_(ota::OTA_COMPLETED, 100.0f, 0); #endif this->schedule_ota_reboot_(); } else { ESP_LOGE(TAG, "OTA end failed: %d", error_code); -#ifdef USE_OTA_STATE_CALLBACK - this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast(error_code)); +#ifdef USE_OTA_STATE_LISTENER + this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif } this->ota_backend_.reset(); @@ -232,10 +232,6 @@ void WebServerOTAComponent::setup() { // AsyncWebServer takes ownership of the handler and will delete it when the server is destroyed base->add_handler(new OTARequestHandler(this)); // NOLINT -#ifdef USE_OTA_STATE_CALLBACK - // Register with global OTA callback system - ota::register_ota_platform(this); -#endif } void WebServerOTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Web Server OTA"); } diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 4cbe683723..0c12b29eb7 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -146,7 +146,7 @@ #define USE_OTA_PASSWORD #define USE_OTA_SHA256 #define ALLOW_OTA_DOWNGRADE_MD5 -#define USE_OTA_STATE_CALLBACK +#define USE_OTA_STATE_LISTENER #define USE_OTA_VERSION 2 #define USE_TIME_TIMEZONE #define USE_WIFI