diff --git a/libraries/ArduinoOTA/ArduinoOTA.cpp b/libraries/ArduinoOTA/ArduinoOTA.cpp index 744c208cc1..9b55733bb6 100644 --- a/libraries/ArduinoOTA/ArduinoOTA.cpp +++ b/libraries/ArduinoOTA/ArduinoOTA.cpp @@ -1,9 +1,23 @@ -#include -#include +#define LWIP_OPEN_SRC +#include #include #include "ArduinoOTA.h" #include "MD5Builder.h" +extern "C" { + #include "osapi.h" + #include "ets_sys.h" + #include "user_interface.h" +} + +#include "lwip/opt.h" +#include "lwip/udp.h" +#include "lwip/inet.h" +#include "lwip/igmp.h" +#include "lwip/mem.h" +#include "include/UdpContext.h" +#include + //#define OTA_DEBUG 1 ArduinoOTAClass::ArduinoOTAClass() @@ -16,9 +30,17 @@ ArduinoOTAClass::ArduinoOTAClass() , _end_callback(NULL) , _progress_callback(NULL) , _error_callback(NULL) +, _udp_ota(0) { } +ArduinoOTAClass::~ArduinoOTAClass(){ + if(_udp_ota){ + _udp_ota->unref(); + _udp_ota = 0; + } +} + void ArduinoOTAClass::onStart(OTA_CALLBACK(fn)) { _start_callback = fn; } @@ -35,9 +57,6 @@ void ArduinoOTAClass::onError(OTA_CALLBACK_ERROR(fn)) { _error_callback = fn; } -ArduinoOTAClass::~ArduinoOTAClass() { -} - void ArduinoOTAClass::setPort(uint16_t port) { if (!_initialized && !_port && port) { _port = port; @@ -59,7 +78,6 @@ void ArduinoOTAClass::setPassword(const char * password) { void ArduinoOTAClass::begin() { if (_initialized) return; - _initialized = true; if (!_hostname.length()) { char tmp[15]; @@ -70,7 +88,17 @@ void ArduinoOTAClass::begin() { _port = 8266; } - _udp_ota.begin(_port); + if(_udp_ota){ + _udp_ota->unref(); + _udp_ota = 0; + } + + _udp_ota = new UdpContext; + _udp_ota->ref(); + + if(!_udp_ota->listen(*IP_ADDR_ANY, _port)) + return; + _udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this)); MDNS.begin(_hostname.c_str()); if (_password.length()) { @@ -78,12 +106,123 @@ void ArduinoOTAClass::begin() { } else { MDNS.enableArduino(_port); } + _initialized = true; _state = OTA_IDLE; #if OTA_DEBUG Serial.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port); #endif } +int ArduinoOTAClass::parseInt(){ + char data[16]; + uint8_t index = 0; + char value; + while(_udp_ota->peek() == ' ') _udp_ota->read(); + while(true){ + value = _udp_ota->peek(); + if(value < '0' || value > '9'){ + data[index++] = '\0'; + return atoi(data); + } + data[index++] = _udp_ota->read(); + } + return 0; +} + +String ArduinoOTAClass::readStringUntil(char end){ + String res = ""; + char value; + while(true){ + value = _udp_ota->read(); + if(value == '\0' || value == end){ + return res; + } + res += value; + } + return res; +} + +void ArduinoOTAClass::_onRx(){ + if(!_udp_ota->next()) return; + ip_addr_t ota_ip; + + if (_state == OTA_IDLE) { + int cmd = parseInt(); + if (cmd != U_FLASH && cmd != U_SPIFFS) + return; + _ota_ip = _udp_ota->getRemoteAddress(); + _cmd = cmd; + _ota_port = parseInt(); + _size = parseInt(); + _udp_ota->read(); + _md5 = readStringUntil('\n'); + _md5.trim(); + if(_md5.length() != 32) + return; + + ota_ip.addr = (uint32_t)_ota_ip; + + if (_password.length()){ + MD5Builder nonce_md5; + nonce_md5.begin(); + nonce_md5.add(String(micros())); + nonce_md5.calculate(); + _nonce = nonce_md5.toString(); + + char auth_req[38]; + sprintf(auth_req, "AUTH %s", _nonce.c_str()); + _udp_ota->append((const char *)auth_req, strlen(auth_req)); + _udp_ota->send(&ota_ip, _udp_ota->getRemotePort()); + _state = OTA_WAITAUTH; + return; + } else { + _udp_ota->append("OK", 2); + _udp_ota->send(&ota_ip, _udp_ota->getRemotePort()); + _state = OTA_RUNUPDATE; + } + } else if (_state == OTA_WAITAUTH) { + int cmd = parseInt(); + if (cmd != U_AUTH) { + _state = OTA_IDLE; + return; + } + _udp_ota->read(); + String cnonce = readStringUntil(' '); + String response = readStringUntil('\n'); + if (cnonce.length() != 32 || response.length() != 32) { + _state = OTA_IDLE; + return; + } + + MD5Builder _passmd5; + _passmd5.begin(); + _passmd5.add(_password); + _passmd5.calculate(); + String passmd5 = _passmd5.toString(); + + String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce; + MD5Builder _challengemd5; + _challengemd5.begin(); + _challengemd5.add(challenge); + _challengemd5.calculate(); + String result = _challengemd5.toString(); + + ota_ip.addr = (uint32_t)_ota_ip; + if(result.equals(response)){ + _udp_ota->append("OK", 2); + _udp_ota->send(&ota_ip, _udp_ota->getRemotePort()); + _state = OTA_RUNUPDATE; + } else { + _udp_ota->append("Authentication Failed", 21); + _udp_ota->send(&ota_ip, _udp_ota->getRemotePort()); + if (_error_callback) _error_callback(OTA_AUTH_ERROR); + _state = OTA_IDLE; + } + } + + while(_udp_ota->next()) _udp_ota->flush(); +} + void ArduinoOTAClass::_runUpdate() { if (!Update.begin(_size, _cmd)) { #if OTA_DEBUG @@ -92,7 +231,7 @@ void ArduinoOTAClass::_runUpdate() { if (_error_callback) { _error_callback(OTA_BEGIN_ERROR); } - _udp_ota.begin(_port); + _udp_ota->listen(*IP_ADDR_ANY, _port); _state = OTA_IDLE; return; } @@ -112,7 +251,7 @@ void ArduinoOTAClass::_runUpdate() { #if OTA_DEBUG Serial.printf("Connect Failed\n"); #endif - _udp_ota.begin(_port); + _udp_ota->listen(*IP_ADDR_ANY, _port); if (_error_callback) { _error_callback(OTA_CONNECT_ERROR); } @@ -128,7 +267,7 @@ void ArduinoOTAClass::_runUpdate() { #if OTA_DEBUG Serial.printf("Recieve Failed\n"); #endif - _udp_ota.begin(_port); + _udp_ota->listen(*IP_ADDR_ANY, _port); if (_error_callback) { _error_callback(OTA_RECIEVE_ERROR); } @@ -156,7 +295,7 @@ void ArduinoOTAClass::_runUpdate() { } ESP.restart(); } else { - _udp_ota.begin(_port); + _udp_ota->listen(*IP_ADDR_ANY, _port); if (_error_callback) { _error_callback(OTA_END_ERROR); } @@ -169,94 +308,9 @@ void ArduinoOTAClass::_runUpdate() { } void ArduinoOTAClass::handle() { - if (!_udp_ota) { - _udp_ota.begin(_port); -#if OTA_DEBUG - Serial.println("OTA restarted"); -#endif - } - - if (!_udp_ota.parsePacket()) return; - - if (_state == OTA_IDLE) { - int cmd = _udp_ota.parseInt(); - if (cmd != U_FLASH && cmd != U_SPIFFS) - return; - _ota_ip = _udp_ota.remoteIP(); - _cmd = cmd; - _ota_port = _udp_ota.parseInt(); - _size = _udp_ota.parseInt(); - _udp_ota.read(); - _md5 = _udp_ota.readStringUntil('\n'); - _md5.trim(); - if(_md5.length() != 32) - return; - -#if OTA_DEBUG - Serial.print("Update Start: ip:"); - Serial.print(_ota_ip); - Serial.printf(", port:%d, size:%d, md5:%s\n", _ota_port, _size, _md5.c_str()); -#endif - - _udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort()); - if (_password.length()){ - MD5Builder nonce_md5; - nonce_md5.begin(); - nonce_md5.add(String(micros())); - nonce_md5.calculate(); - _nonce = nonce_md5.toString(); - _udp_ota.printf("AUTH %s", _nonce.c_str()); - _udp_ota.endPacket(); - _state = OTA_WAITAUTH; - return; - } else { - _udp_ota.print("OK"); - _udp_ota.endPacket(); - _state = OTA_RUNUPDATE; - } - } else if (_state == OTA_WAITAUTH) { - int cmd = _udp_ota.parseInt(); - if (cmd != U_AUTH) { - _state = OTA_IDLE; - return; - } - _udp_ota.read(); - String cnonce = _udp_ota.readStringUntil(' '); - String response = _udp_ota.readStringUntil('\n'); - if (cnonce.length() != 32 || response.length() != 32) { - _state = OTA_IDLE; - return; - } - - MD5Builder _passmd5; - _passmd5.begin(); - _passmd5.add(_password); - _passmd5.calculate(); - String passmd5 = _passmd5.toString(); - - String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce; - MD5Builder _challengemd5; - _challengemd5.begin(); - _challengemd5.add(challenge); - _challengemd5.calculate(); - String result = _challengemd5.toString(); - - if(result.equals(response)){ - _udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort()); - _udp_ota.print("OK"); - _udp_ota.endPacket(); - _state = OTA_RUNUPDATE; - } else { - _udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort()); - _udp_ota.print("Authentication Failed"); - _udp_ota.endPacket(); - if (_error_callback) _error_callback(OTA_AUTH_ERROR); - _state = OTA_IDLE; - } - } - if (_state == OTA_RUNUPDATE) { _runUpdate(); + _state = OTA_IDLE; } } diff --git a/libraries/ArduinoOTA/ArduinoOTA.h b/libraries/ArduinoOTA/ArduinoOTA.h index 5d5161e5eb..bdb839397d 100644 --- a/libraries/ArduinoOTA/ArduinoOTA.h +++ b/libraries/ArduinoOTA/ArduinoOTA.h @@ -1,7 +1,10 @@ #ifndef __ARDUINO_OTA_H #define __ARDUINO_OTA_H -class WiFiUDP; +#include +#include + +class UdpContext; #define OTA_CALLBACK(callback) void (*callback)() #define OTA_CALLBACK_PROGRESS(callback) void (*callback)(unsigned int, unsigned int) @@ -41,7 +44,7 @@ class ArduinoOTAClass String _password; String _hostname; String _nonce; - WiFiUDP _udp_ota; + UdpContext *_udp_ota; bool _initialized; ota_state_t _state; int _size; @@ -56,6 +59,9 @@ class ArduinoOTAClass OTA_CALLBACK_PROGRESS(_progress_callback); void _runUpdate(void); + void _onRx(void); + int parseInt(void); + String readStringUntil(char end); }; extern ArduinoOTAClass ArduinoOTA;