[ota] Replace std::function callbacks with listener interface (#12167)

This commit is contained in:
J. Nick Koston
2025-12-19 11:19:07 -10:00
committed by GitHub
parent 940afdbb12
commit 988b888c63
21 changed files with 274 additions and 206 deletions

View File

@@ -5,7 +5,7 @@ import logging
from esphome import automation from esphome import automation
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import esp32_ble from esphome.components import esp32_ble, ota
from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32 import add_idf_sdkconfig_option
from esphome.components.esp32_ble import ( from esphome.components.esp32_ble import (
IDF_MAX_CONNECTIONS, IDF_MAX_CONNECTIONS,
@@ -328,7 +328,7 @@ async def to_code(config):
# Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now # Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now
# configured in esp32_ble component based on max_connections setting # configured in esp32_ble component based on max_connections setting
cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts ota.request_ota_state_listeners() # To be notified when an OTA update starts
cg.add_define("USE_ESP32_BLE_CLIENT") cg.add_define("USE_ESP32_BLE_CLIENT")
CORE.add_job(_add_ble_features) CORE.add_job(_add_ble_features)

View File

@@ -71,21 +71,24 @@ void ESP32BLETracker::setup() {
global_esp32_ble_tracker = this; global_esp32_ble_tracker = this;
#ifdef USE_OTA #ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_on_state_callback( ota::get_global_ota_callback()->add_global_state_listener(this);
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->stop_scan();
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
for (auto *client : this->clients_) {
client->disconnect();
}
#endif
}
});
#endif #endif
} }
#ifdef USE_OTA_STATE_LISTENER
void ESP32BLETracker::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->stop_scan();
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
for (auto *client : this->clients_) {
client->disconnect();
}
#endif
}
}
#endif
void ESP32BLETracker::loop() { void ESP32BLETracker::loop() {
if (!this->parent_->is_active()) { if (!this->parent_->is_active()) {
this->ble_was_disabled_ = true; this->ble_was_disabled_ = true;

View File

@@ -22,6 +22,10 @@
#include "esphome/components/esp32_ble/ble_uuid.h" #include "esphome/components/esp32_ble/ble_uuid.h"
#include "esphome/components/esp32_ble/ble_scan_result.h" #include "esphome/components/esp32_ble/ble_scan_result.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
namespace esphome::esp32_ble_tracker { namespace esphome::esp32_ble_tracker {
using namespace esp32_ble; using namespace esp32_ble;
@@ -241,6 +245,9 @@ class ESP32BLETracker : public Component,
public GAPScanEventHandler, public GAPScanEventHandler,
public GATTcEventHandler, public GATTcEventHandler,
public BLEStatusEventHandler, public BLEStatusEventHandler,
#ifdef USE_OTA_STATE_LISTENER
public ota::OTAGlobalStateListener,
#endif
public Parented<ESP32BLE> { public Parented<ESP32BLE> {
public: public:
void set_scan_duration(uint32_t scan_duration) { scan_duration_ = scan_duration; } void set_scan_duration(uint32_t scan_duration) { scan_duration_ = scan_duration; }
@@ -274,6 +281,10 @@ class ESP32BLETracker : public Component,
void gap_scan_event_handler(const BLEScanResult &scan_result) override; void gap_scan_event_handler(const BLEScanResult &scan_result) override;
void ble_before_disabled_event_handler() override; void ble_before_disabled_event_handler() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
/// Add a listener for scanner state changes /// Add a listener for scanner state changes
void add_scanner_state_listener(BLEScannerStateListener *listener) { void add_scanner_state_listener(BLEScannerStateListener *listener) {
this->scanner_state_listeners_.push_back(listener); this->scanner_state_listeners_.push_back(listener);

View File

@@ -41,10 +41,6 @@ static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32
#endif // USE_OTA_PASSWORD #endif // USE_OTA_PASSWORD
void ESPHomeOTAComponent::setup() { void ESPHomeOTAComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK
ota::register_ota_platform(this);
#endif
this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (this->server_ == nullptr) { if (this->server_ == nullptr) {
this->log_socket_error_(LOG_STR("creation")); this->log_socket_error_(LOG_STR("creation"));
@@ -297,8 +293,8 @@ void ESPHomeOTAComponent::handle_data_() {
// accidentally trigger the update process. // accidentally trigger the update process.
this->log_start_(LOG_STR("update")); this->log_start_(LOG_STR("update"));
this->status_set_warning(); this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
#endif #endif
// This will block for a few seconds as it locks flash // This will block for a few seconds as it locks flash
@@ -357,8 +353,8 @@ void ESPHomeOTAComponent::handle_data_() {
last_progress = now; last_progress = now;
float percentage = (total * 100.0f) / ota_size; float percentage = (total * 100.0f) / ota_size;
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage); ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif #endif
// feed watchdog and give other tasks a chance to run // feed watchdog and give other tasks a chance to run
this->yield_and_feed_watchdog_(); this->yield_and_feed_watchdog_();
@@ -387,8 +383,8 @@ void ESPHomeOTAComponent::handle_data_() {
delay(10); delay(10);
ESP_LOGI(TAG, "Update complete"); ESP_LOGI(TAG, "Update complete");
this->status_clear_warning(); this->status_clear_warning();
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0); this->notify_state_(ota::OTA_COMPLETED, 100.0f, 0);
#endif #endif
delay(100); // NOLINT delay(100); // NOLINT
App.safe_reboot(); App.safe_reboot();
@@ -402,8 +398,8 @@ error:
} }
this->status_momentary_error("err", 5000); this->status_momentary_error("err", 5000);
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code)); this->notify_state_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif #endif
} }

View File

@@ -16,12 +16,6 @@ namespace http_request {
static const char *const TAG = "http_request.ota"; static const char *const TAG = "http_request.ota";
void OtaHttpRequestComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK
ota::register_ota_platform(this);
#endif
}
void OtaHttpRequestComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air updates via HTTP request"); }; void OtaHttpRequestComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air updates via HTTP request"); };
void OtaHttpRequestComponent::set_md5_url(const std::string &url) { void OtaHttpRequestComponent::set_md5_url(const std::string &url) {
@@ -48,24 +42,24 @@ void OtaHttpRequestComponent::flash() {
} }
ESP_LOGI(TAG, "Starting update"); ESP_LOGI(TAG, "Starting update");
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
#endif #endif
auto ota_status = this->do_ota_(); auto ota_status = this->do_ota_();
switch (ota_status) { switch (ota_status) {
case ota::OTA_RESPONSE_OK: case ota::OTA_RESPONSE_OK:
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, ota_status); this->notify_state_(ota::OTA_COMPLETED, 100.0f, ota_status);
#endif #endif
delay(10); delay(10);
App.safe_reboot(); App.safe_reboot();
break; break;
default: default:
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_ERROR, 0.0f, ota_status); this->notify_state_(ota::OTA_ERROR, 0.0f, ota_status);
#endif #endif
this->md5_computed_.clear(); // will be reset at next attempt this->md5_computed_.clear(); // will be reset at next attempt
this->md5_expected_.clear(); // will be reset at next attempt this->md5_expected_.clear(); // will be reset at next attempt
@@ -165,8 +159,8 @@ uint8_t OtaHttpRequestComponent::do_ota_() {
last_progress = now; last_progress = now;
float percentage = container->get_bytes_read() * 100.0f / container->content_length; float percentage = container->get_bytes_read() * 100.0f / container->content_length;
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage); ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif #endif
} }
} // while } // while

View File

@@ -24,7 +24,6 @@ enum OtaHttpRequestError : uint8_t {
class OtaHttpRequestComponent : public ota::OTAComponent, public Parented<HttpRequestComponent> { class OtaHttpRequestComponent : public ota::OTAComponent, public Parented<HttpRequestComponent> {
public: public:
void setup() override;
void dump_config() override; void dump_config() override;
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }

View File

@@ -1,5 +1,5 @@
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import update from esphome.components import ota, update
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import CONF_SOURCE from esphome.const import CONF_SOURCE
@@ -38,6 +38,6 @@ async def to_code(config):
cg.add(var.set_source_url(config[CONF_SOURCE])) cg.add(var.set_source_url(config[CONF_SOURCE]))
cg.add_define("USE_OTA_STATE_CALLBACK") ota.request_ota_state_listeners()
await cg.register_component(var, config) await cg.register_component(var, config)

View File

@@ -20,19 +20,19 @@ static const char *const TAG = "http_request.update";
static const size_t MAX_READ_SIZE = 256; static const size_t MAX_READ_SIZE = 256;
void HttpRequestUpdate::setup() { void HttpRequestUpdate::setup() { this->ota_parent_->add_state_listener(this); }
this->ota_parent_->add_on_state_callback([this](ota::OTAState state, float progress, uint8_t err) {
if (state == ota::OTAState::OTA_IN_PROGRESS) { void HttpRequestUpdate::on_ota_state(ota::OTAState state, float progress, uint8_t error) {
this->state_ = update::UPDATE_STATE_INSTALLING; if (state == ota::OTAState::OTA_IN_PROGRESS) {
this->update_info_.has_progress = true; this->state_ = update::UPDATE_STATE_INSTALLING;
this->update_info_.progress = progress; this->update_info_.has_progress = true;
this->publish_state(); this->update_info_.progress = progress;
} else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) { this->publish_state();
this->state_ = update::UPDATE_STATE_AVAILABLE; } else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) {
this->status_set_error(LOG_STR("Failed to install firmware")); this->state_ = update::UPDATE_STATE_AVAILABLE;
this->publish_state(); this->status_set_error(LOG_STR("Failed to install firmware"));
} this->publish_state();
}); }
} }
void HttpRequestUpdate::update() { void HttpRequestUpdate::update() {

View File

@@ -14,7 +14,7 @@
namespace esphome { namespace esphome {
namespace http_request { namespace http_request {
class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent { class HttpRequestUpdate final : public update::UpdateEntity, public PollingComponent, public ota::OTAStateListener {
public: public:
void setup() override; void setup() override;
void update() override; void update() override;
@@ -29,6 +29,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent {
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
void on_ota_state(ota::OTAState state, float progress, uint8_t error) override;
protected: protected:
HttpRequestComponent *request_parent_; HttpRequestComponent *request_parent_;
OtaHttpRequestComponent *ota_parent_; OtaHttpRequestComponent *ota_parent_;

View File

@@ -7,7 +7,7 @@ from urllib.parse import urljoin
from esphome import automation, external_files, git from esphome import automation, external_files, git
from esphome.automation import register_action, register_condition from esphome.automation import register_action, register_condition
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import esp32, microphone, socket from esphome.components import esp32, microphone, ota, socket
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import ( from esphome.const import (
CONF_FILE, CONF_FILE,
@@ -452,7 +452,7 @@ async def to_code(config):
cg.add(var.set_microphone_source(mic_source)) cg.add(var.set_microphone_source(mic_source))
cg.add_define("USE_MICRO_WAKE_WORD") cg.add_define("USE_MICRO_WAKE_WORD")
cg.add_define("USE_OTA_STATE_CALLBACK") ota.request_ota_state_listeners()
esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1") esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1")

View File

@@ -119,18 +119,21 @@ void MicroWakeWord::setup() {
} }
}); });
#ifdef USE_OTA #ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_on_state_callback( ota::get_global_ota_callback()->add_global_state_listener(this);
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->suspend_task_();
} else if (state == ota::OTA_ERROR) {
this->resume_task_();
}
});
#endif #endif
} }
#ifdef USE_OTA_STATE_LISTENER
void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->suspend_task_();
} else if (state == ota::OTA_ERROR) {
this->resume_task_();
}
}
#endif
void MicroWakeWord::inference_task(void *params) { void MicroWakeWord::inference_task(void *params) {
MicroWakeWord *this_mww = (MicroWakeWord *) params; MicroWakeWord *this_mww = (MicroWakeWord *) params;

View File

@@ -9,8 +9,13 @@
#include "esphome/core/automation.h" #include "esphome/core/automation.h"
#include "esphome/core/component.h" #include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/ring_buffer.h" #include "esphome/core/ring_buffer.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
#include <freertos/event_groups.h> #include <freertos/event_groups.h>
#include <frontend.h> #include <frontend.h>
@@ -26,13 +31,22 @@ enum State {
STOPPED, STOPPED,
}; };
class MicroWakeWord : public Component { class MicroWakeWord : public Component
#ifdef USE_OTA_STATE_LISTENER
,
public ota::OTAGlobalStateListener
#endif
{
public: public:
void setup() override; void setup() override;
void loop() override; void loop() override;
float get_setup_priority() const override; float get_setup_priority() const override;
void dump_config() override; void dump_config() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
void start(); void start();
void stop(); void stop();

View File

@@ -13,6 +13,8 @@ from esphome.const import (
from esphome.core import CORE, coroutine_with_priority from esphome.core import CORE, coroutine_with_priority
from esphome.coroutine import CoroPriority from esphome.coroutine import CoroPriority
OTA_STATE_LISTENER_KEY = "ota_state_listener"
CODEOWNERS = ["@esphome/core"] CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5", "safe_mode"] AUTO_LOAD = ["md5", "safe_mode"]
@@ -86,6 +88,7 @@ BASE_OTA_SCHEMA = cv.Schema(
@coroutine_with_priority(CoroPriority.OTA_UPDATES) @coroutine_with_priority(CoroPriority.OTA_UPDATES)
async def to_code(config): async def to_code(config):
cg.add_define("USE_OTA") cg.add_define("USE_OTA")
CORE.add_job(final_step)
if CORE.is_rp2040 and CORE.using_arduino: if CORE.is_rp2040 and CORE.using_arduino:
cg.add_library("Updater", None) cg.add_library("Updater", None)
@@ -119,7 +122,24 @@ async def ota_to_code(var, config):
await automation.build_automation(trigger, [(cg.uint8, "x")], conf) await automation.build_automation(trigger, [(cg.uint8, "x")], conf)
use_state_callback = True use_state_callback = True
if use_state_callback: if use_state_callback:
cg.add_define("USE_OTA_STATE_CALLBACK") request_ota_state_listeners()
def request_ota_state_listeners() -> None:
"""Request that OTA state listeners be compiled in.
Components that need to be notified about OTA state changes (start, progress,
complete, error) should call this function during their code generation.
This enables the add_state_listener() API on OTAComponent.
"""
CORE.data[OTA_STATE_LISTENER_KEY] = True
@coroutine_with_priority(CoroPriority.FINAL)
async def final_step():
"""Final code generation step to configure optional OTA features."""
if CORE.data.get(OTA_STATE_LISTENER_KEY, False):
cg.add_define("USE_OTA_STATE_LISTENER")
FILTER_SOURCE_FILES = filter_source_files_from_platform( FILTER_SOURCE_FILES = filter_source_files_from_platform(

View File

@@ -1,5 +1,5 @@
#pragma once #pragma once
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
#include "ota_backend.h" #include "ota_backend.h"
#include "esphome/core/automation.h" #include "esphome/core/automation.h"
@@ -7,70 +7,64 @@
namespace esphome { namespace esphome {
namespace ota { namespace ota {
class OTAStateChangeTrigger : public Trigger<OTAState> { class OTAStateChangeTrigger final : public Trigger<OTAState>, public OTAStateListener {
public: public:
explicit OTAStateChangeTrigger(OTAComponent *parent) { explicit OTAStateChangeTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (!parent->is_failed()) { void on_ota_state(OTAState state, float progress, uint8_t error) override {
trigger(state); if (!this->parent_->is_failed()) {
} this->trigger(state);
}); }
} }
protected:
OTAComponent *parent_;
}; };
class OTAStartTrigger : public Trigger<> { template<OTAState State> class OTAStateTrigger final : public Trigger<>, public OTAStateListener {
public: public:
explicit OTAStartTrigger(OTAComponent *parent) { explicit OTAStateTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_STARTED && !parent->is_failed()) { void on_ota_state(OTAState state, float progress, uint8_t error) override {
trigger(); if (state == State && !this->parent_->is_failed()) {
} this->trigger();
}); }
} }
protected:
OTAComponent *parent_;
}; };
class OTAProgressTrigger : public Trigger<float> { using OTAStartTrigger = OTAStateTrigger<OTA_STARTED>;
using OTAEndTrigger = OTAStateTrigger<OTA_COMPLETED>;
using OTAAbortTrigger = OTAStateTrigger<OTA_ABORT>;
class OTAProgressTrigger final : public Trigger<float>, public OTAStateListener {
public: public:
explicit OTAProgressTrigger(OTAComponent *parent) { explicit OTAProgressTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_IN_PROGRESS && !parent->is_failed()) { void on_ota_state(OTAState state, float progress, uint8_t error) override {
trigger(progress); if (state == OTA_IN_PROGRESS && !this->parent_->is_failed()) {
} this->trigger(progress);
}); }
} }
protected:
OTAComponent *parent_;
}; };
class OTAEndTrigger : public Trigger<> { class OTAErrorTrigger final : public Trigger<uint8_t>, public OTAStateListener {
public: public:
explicit OTAEndTrigger(OTAComponent *parent) { explicit OTAErrorTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_COMPLETED && !parent->is_failed()) {
trigger();
}
});
}
};
class OTAAbortTrigger : public Trigger<> { void on_ota_state(OTAState state, float progress, uint8_t error) override {
public: if (state == OTA_ERROR && !this->parent_->is_failed()) {
explicit OTAAbortTrigger(OTAComponent *parent) { this->trigger(error);
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { }
if (state == OTA_ABORT && !parent->is_failed()) {
trigger();
}
});
} }
};
class OTAErrorTrigger : public Trigger<uint8_t> { protected:
public: OTAComponent *parent_;
explicit OTAErrorTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_ERROR && !parent->is_failed()) {
trigger(error);
}
});
}
}; };
} // namespace ota } // namespace ota

View File

@@ -3,7 +3,7 @@
namespace esphome { namespace esphome {
namespace ota { namespace ota {
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
OTAGlobalCallback *get_global_ota_callback() { OTAGlobalCallback *get_global_ota_callback() {
@@ -13,7 +13,12 @@ OTAGlobalCallback *get_global_ota_callback() {
return global_ota_callback; return global_ota_callback;
} }
void register_ota_platform(OTAComponent *ota_caller) { get_global_ota_callback()->register_ota(ota_caller); } void OTAComponent::notify_state_(OTAState state, float progress, uint8_t error) {
for (auto *listener : this->state_listeners_) {
listener->on_ota_state(state, progress, error);
}
get_global_ota_callback()->notify_ota_state(state, progress, error, this);
}
#endif #endif
} // namespace ota } // namespace ota

View File

@@ -4,8 +4,8 @@
#include "esphome/core/defines.h" #include "esphome/core/defines.h"
#include "esphome/core/helpers.h" #include "esphome/core/helpers.h"
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
#include "esphome/core/automation.h" #include <vector>
#endif #endif
namespace esphome { namespace esphome {
@@ -60,62 +60,75 @@ class OTABackend {
virtual bool supports_compression() = 0; virtual bool supports_compression() = 0;
}; };
class OTAComponent : public Component { /** Listener interface for OTA state changes.
#ifdef USE_OTA_STATE_CALLBACK *
* Components can implement this interface to receive OTA state updates
* without the overhead of std::function callbacks.
*/
class OTAStateListener {
public: public:
void add_on_state_callback(std::function<void(ota::OTAState, float, uint8_t)> &&callback) { virtual ~OTAStateListener() = default;
this->state_callback_.add(std::move(callback)); virtual void on_ota_state(OTAState state, float progress, uint8_t error) = 0;
} };
class OTAComponent : public Component {
#ifdef USE_OTA_STATE_LISTENER
public:
void add_state_listener(OTAStateListener *listener) { this->state_listeners_.push_back(listener); }
protected: protected:
/** Extended callback manager with deferred call support. void notify_state_(OTAState state, float progress, uint8_t error);
/** Notify state with deferral to main loop (for thread safety).
* *
* This adds a call_deferred() method for thread-safe execution from other tasks. * This should be used by OTA implementations that run in separate tasks
* (like web_server OTA) to ensure listeners execute in the main loop.
*/ */
class StateCallbackManager : public CallbackManager<void(OTAState, float, uint8_t)> { void notify_state_deferred_(OTAState state, float progress, uint8_t error) {
public: this->defer([this, state, progress, error]() { this->notify_state_(state, progress, error); });
StateCallbackManager(OTAComponent *component) : component_(component) {} }
/** Call callbacks with deferral to main loop (for thread safety). std::vector<OTAStateListener *> state_listeners_;
*
* This should be used by OTA implementations that run in separate tasks
* (like web_server OTA) to ensure callbacks execute in the main loop.
*/
void call_deferred(ota::OTAState state, float progress, uint8_t error) {
component_->defer([this, state, progress, error]() { this->call(state, progress, error); });
}
private:
OTAComponent *component_;
};
StateCallbackManager state_callback_{this};
#endif #endif
}; };
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
/** Listener interface for global OTA state changes (includes OTA component pointer).
*
* Used by OTAGlobalCallback to aggregate state from multiple OTA components.
*/
class OTAGlobalStateListener {
public:
virtual ~OTAGlobalStateListener() = default;
virtual void on_ota_global_state(OTAState state, float progress, uint8_t error, OTAComponent *component) = 0;
};
/** Global callback that aggregates OTA state from all OTA components.
*
* OTA components call notify_ota_state() directly with their pointer,
* which forwards the event to all registered global listeners.
*/
class OTAGlobalCallback { class OTAGlobalCallback {
public: public:
void register_ota(OTAComponent *ota_caller) { void add_global_state_listener(OTAGlobalStateListener *listener) { this->global_listeners_.push_back(listener); }
ota_caller->add_on_state_callback([this, ota_caller](OTAState state, float progress, uint8_t error) {
this->state_callback_.call(state, progress, error, ota_caller); void notify_ota_state(OTAState state, float progress, uint8_t error, OTAComponent *component) {
}); for (auto *listener : this->global_listeners_) {
} listener->on_ota_global_state(state, progress, error, component);
void add_on_state_callback(std::function<void(OTAState, float, uint8_t, OTAComponent *)> &&callback) { }
this->state_callback_.add(std::move(callback));
} }
protected: protected:
CallbackManager<void(OTAState, float, uint8_t, OTAComponent *)> state_callback_{}; std::vector<OTAGlobalStateListener *> global_listeners_;
}; };
OTAGlobalCallback *get_global_ota_callback(); OTAGlobalCallback *get_global_ota_callback();
void register_ota_platform(OTAComponent *ota_caller);
// OTA implementations should use: // OTA implementations should use:
// - state_callback_.call() when already in main loop (e.g., esphome OTA) // - notify_state_() when already in main loop (e.g., esphome OTA)
// - state_callback_.call_deferred() when in separate task (e.g., web_server OTA) // - notify_state_deferred_() when in separate task (e.g., web_server OTA)
// This ensures proper callback execution in all contexts. // This ensures proper listener execution in all contexts.
#endif #endif
std::unique_ptr<ota::OTABackend> make_ota_backend(); std::unique_ptr<ota::OTABackend> make_ota_backend();

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from esphome import automation, external_files from esphome import automation, external_files
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import audio, esp32, media_player, network, psram, speaker from esphome.components import audio, esp32, media_player, network, ota, psram, speaker
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import ( from esphome.const import (
CONF_BUFFER_SIZE, CONF_BUFFER_SIZE,
@@ -342,7 +342,7 @@ async def to_code(config):
var = await media_player.new_media_player(config) var = await media_player.new_media_player(config)
await cg.register_component(var, config) await cg.register_component(var, config)
cg.add_define("USE_OTA_STATE_CALLBACK") ota.request_ota_state_listeners()
cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE])) cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE]))

View File

@@ -66,25 +66,8 @@ void SpeakerMediaPlayer::setup() {
this->set_mute_state_(false); this->set_mute_state_(false);
} }
#ifdef USE_OTA #ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_on_state_callback( ota::get_global_ota_callback()->add_global_state_listener(this);
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->suspend_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->suspend_tasks();
}
} else if (state == ota::OTA_ERROR) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->resume_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->resume_tasks();
}
}
});
#endif #endif
this->announcement_pipeline_ = this->announcement_pipeline_ =
@@ -300,6 +283,27 @@ void SpeakerMediaPlayer::watch_media_commands_() {
} }
} }
#ifdef USE_OTA_STATE_LISTENER
void SpeakerMediaPlayer::on_ota_global_state(ota::OTAState state, float progress, uint8_t error,
ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->suspend_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->suspend_tasks();
}
} else if (state == ota::OTA_ERROR) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->resume_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->resume_tasks();
}
}
}
#endif
void SpeakerMediaPlayer::loop() { void SpeakerMediaPlayer::loop() {
this->watch_media_commands_(); this->watch_media_commands_();

View File

@@ -5,14 +5,18 @@
#include "audio_pipeline.h" #include "audio_pipeline.h"
#include "esphome/components/audio/audio.h" #include "esphome/components/audio/audio.h"
#include "esphome/components/media_player/media_player.h" #include "esphome/components/media_player/media_player.h"
#include "esphome/components/speaker/speaker.h" #include "esphome/components/speaker/speaker.h"
#include "esphome/core/automation.h" #include "esphome/core/automation.h"
#include "esphome/core/component.h" #include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/preferences.h" #include "esphome/core/preferences.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
#include <deque> #include <deque>
#include <freertos/FreeRTOS.h> #include <freertos/FreeRTOS.h>
#include <freertos/queue.h> #include <freertos/queue.h>
@@ -39,12 +43,22 @@ struct VolumeRestoreState {
bool is_muted; bool is_muted;
}; };
class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer { class SpeakerMediaPlayer : public Component,
public media_player::MediaPlayer
#ifdef USE_OTA_STATE_LISTENER
,
public ota::OTAGlobalStateListener
#endif
{
public: public:
float get_setup_priority() const override { return esphome::setup_priority::PROCESSOR; } float get_setup_priority() const override { return esphome::setup_priority::PROCESSOR; }
void setup() override; void setup() override;
void loop() override; void loop() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
// MediaPlayer implementations // MediaPlayer implementations
media_player::MediaPlayerTraits get_traits() override; media_player::MediaPlayerTraits get_traits() override;
bool is_muted() const override { return this->is_muted_; } bool is_muted() const override { return this->is_muted_; }

View File

@@ -84,9 +84,9 @@ void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) {
} else { } else {
ESP_LOGD(TAG, "OTA in progress: %" PRIu32 " bytes read", this->ota_read_length_); ESP_LOGD(TAG, "OTA in progress: %" PRIu32 " bytes read", this->ota_read_length_);
} }
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
// Report progress - use call_deferred since we're in web server task // Report progress - use notify_state_deferred_ since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_IN_PROGRESS, percentage, 0); this->parent_->notify_state_deferred_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif #endif
this->last_ota_progress_ = now; this->last_ota_progress_ = now;
} }
@@ -114,9 +114,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
// Initialize OTA on first call // Initialize OTA on first call
this->ota_init_(filename.c_str()); this->ota_init_(filename.c_str());
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
// Notify OTA started - use call_deferred since we're in web server task // Notify OTA started - use notify_state_deferred_ since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_STARTED, 0.0f, 0); this->parent_->notify_state_deferred_(ota::OTA_STARTED, 0.0f, 0);
#endif #endif
// Platform-specific pre-initialization // Platform-specific pre-initialization
@@ -134,9 +134,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
this->ota_backend_ = ota::make_ota_backend(); this->ota_backend_ = ota::make_ota_backend();
if (!this->ota_backend_) { if (!this->ota_backend_) {
ESP_LOGE(TAG, "Failed to create OTA backend"); ESP_LOGE(TAG, "Failed to create OTA backend");
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f,
static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN)); static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN));
#endif #endif
return; return;
} }
@@ -148,8 +148,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
if (error_code != ota::OTA_RESPONSE_OK) { if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA begin failed: %d", error_code); ESP_LOGE(TAG, "OTA begin failed: %d", error_code);
this->ota_backend_.reset(); this->ota_backend_.reset();
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code)); this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif #endif
return; return;
} }
@@ -166,8 +166,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
ESP_LOGE(TAG, "OTA write failed: %d", error_code); ESP_LOGE(TAG, "OTA write failed: %d", error_code);
this->ota_backend_->abort(); this->ota_backend_->abort();
this->ota_backend_.reset(); this->ota_backend_.reset();
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code)); this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif #endif
return; return;
} }
@@ -186,15 +186,15 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
error_code = this->ota_backend_->end(); error_code = this->ota_backend_->end();
if (error_code == ota::OTA_RESPONSE_OK) { if (error_code == ota::OTA_RESPONSE_OK) {
this->ota_success_ = true; this->ota_success_ = true;
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
// Report completion before reboot - use call_deferred since we're in web server task // Report completion before reboot - use notify_state_deferred_ since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_COMPLETED, 100.0f, 0); this->parent_->notify_state_deferred_(ota::OTA_COMPLETED, 100.0f, 0);
#endif #endif
this->schedule_ota_reboot_(); this->schedule_ota_reboot_();
} else { } else {
ESP_LOGE(TAG, "OTA end failed: %d", error_code); ESP_LOGE(TAG, "OTA end failed: %d", error_code);
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_LISTENER
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code)); this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif #endif
} }
this->ota_backend_.reset(); this->ota_backend_.reset();
@@ -232,10 +232,6 @@ void WebServerOTAComponent::setup() {
// AsyncWebServer takes ownership of the handler and will delete it when the server is destroyed // AsyncWebServer takes ownership of the handler and will delete it when the server is destroyed
base->add_handler(new OTARequestHandler(this)); // NOLINT base->add_handler(new OTARequestHandler(this)); // NOLINT
#ifdef USE_OTA_STATE_CALLBACK
// Register with global OTA callback system
ota::register_ota_platform(this);
#endif
} }
void WebServerOTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Web Server OTA"); } void WebServerOTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Web Server OTA"); }

View File

@@ -146,7 +146,7 @@
#define USE_OTA_PASSWORD #define USE_OTA_PASSWORD
#define USE_OTA_SHA256 #define USE_OTA_SHA256
#define ALLOW_OTA_DOWNGRADE_MD5 #define ALLOW_OTA_DOWNGRADE_MD5
#define USE_OTA_STATE_CALLBACK #define USE_OTA_STATE_LISTENER
#define USE_OTA_VERSION 2 #define USE_OTA_VERSION 2
#define USE_TIME_TIMEZONE #define USE_TIME_TIMEZONE
#define USE_WIFI #define USE_WIFI