Skip to content

Commit

Permalink
Merge pull request #1056 from me-no-dev/async-ota
Browse files Browse the repository at this point in the history
Async ota
  • Loading branch information
igrr committed Nov 22, 2015
2 parents b5ca4fe + fe9dc91 commit 3d26810
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 99 deletions.
248 changes: 151 additions & 97 deletions libraries/ArduinoOTA/ArduinoOTA.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
#include <ESP8266WiFi.h>
#include <ESP8266mDNS.h>
#define LWIP_OPEN_SRC
#include <functional>
#include <WiFiUdp.h>
#include "ArduinoOTA.h"
#include "MD5Builder.h"

extern "C" {
#include "osapi.h"
#include "ets_sys.h"
#include "user_interface.h"
}

#include "lwip/opt.h"
#include "lwip/udp.h"
#include "lwip/inet.h"
#include "lwip/igmp.h"
#include "lwip/mem.h"
#include "include/UdpContext.h"
#include <ESP8266mDNS.h>

//#define OTA_DEBUG 1

ArduinoOTAClass::ArduinoOTAClass()
Expand All @@ -16,9 +30,17 @@ ArduinoOTAClass::ArduinoOTAClass()
, _end_callback(NULL)
, _progress_callback(NULL)
, _error_callback(NULL)
, _udp_ota(0)
{
}

ArduinoOTAClass::~ArduinoOTAClass(){
if(_udp_ota){
_udp_ota->unref();
_udp_ota = 0;
}
}

void ArduinoOTAClass::onStart(OTA_CALLBACK(fn)) {
_start_callback = fn;
}
Expand All @@ -35,9 +57,6 @@ void ArduinoOTAClass::onError(OTA_CALLBACK_ERROR(fn)) {
_error_callback = fn;
}

ArduinoOTAClass::~ArduinoOTAClass() {
}

void ArduinoOTAClass::setPort(uint16_t port) {
if (!_initialized && !_port && port) {
_port = port;
Expand All @@ -59,7 +78,6 @@ void ArduinoOTAClass::setPassword(const char * password) {
void ArduinoOTAClass::begin() {
if (_initialized)
return;
_initialized = true;

if (!_hostname.length()) {
char tmp[15];
Expand All @@ -70,20 +88,141 @@ void ArduinoOTAClass::begin() {
_port = 8266;
}

_udp_ota.begin(_port);
if(_udp_ota){
_udp_ota->unref();
_udp_ota = 0;
}

_udp_ota = new UdpContext;
_udp_ota->ref();

if(!_udp_ota->listen(*IP_ADDR_ANY, _port))
return;
_udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
MDNS.begin(_hostname.c_str());

if (_password.length()) {
MDNS.enableArduino(_port, true);
} else {
MDNS.enableArduino(_port);
}
_initialized = true;
_state = OTA_IDLE;
#if OTA_DEBUG
Serial.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
#endif
}

int ArduinoOTAClass::parseInt(){
char data[16];
uint8_t index = 0;
char value;
while(_udp_ota->peek() == ' ') _udp_ota->read();
while(true){
value = _udp_ota->peek();
if(value < '0' || value > '9'){
data[index++] = '\0';
return atoi(data);
}
data[index++] = _udp_ota->read();
}
return 0;
}

String ArduinoOTAClass::readStringUntil(char end){
String res = "";
char value;
while(true){
value = _udp_ota->read();
if(value == '\0' || value == end){
return res;
}
res += value;
}
return res;
}

void ArduinoOTAClass::_onRx(){
if(!_udp_ota->next()) return;
ip_addr_t ota_ip;

if (_state == OTA_IDLE) {
int cmd = parseInt();
if (cmd != U_FLASH && cmd != U_SPIFFS)
return;
_ota_ip = _udp_ota->getRemoteAddress();
_cmd = cmd;
_ota_port = parseInt();
_size = parseInt();
_udp_ota->read();
_md5 = readStringUntil('\n');
_md5.trim();
if(_md5.length() != 32)
return;

ota_ip.addr = (uint32_t)_ota_ip;

if (_password.length()){
MD5Builder nonce_md5;
nonce_md5.begin();
nonce_md5.add(String(micros()));
nonce_md5.calculate();
_nonce = nonce_md5.toString();

char auth_req[38];
sprintf(auth_req, "AUTH %s", _nonce.c_str());
_udp_ota->append((const char *)auth_req, strlen(auth_req));
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
_state = OTA_WAITAUTH;
return;
} else {
_udp_ota->append("OK", 2);
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
_state = OTA_RUNUPDATE;
}
} else if (_state == OTA_WAITAUTH) {
int cmd = parseInt();
if (cmd != U_AUTH) {
_state = OTA_IDLE;
return;
}
_udp_ota->read();
String cnonce = readStringUntil(' ');
String response = readStringUntil('\n');
if (cnonce.length() != 32 || response.length() != 32) {
_state = OTA_IDLE;
return;
}

MD5Builder _passmd5;
_passmd5.begin();
_passmd5.add(_password);
_passmd5.calculate();
String passmd5 = _passmd5.toString();

String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
MD5Builder _challengemd5;
_challengemd5.begin();
_challengemd5.add(challenge);
_challengemd5.calculate();
String result = _challengemd5.toString();

ota_ip.addr = (uint32_t)_ota_ip;
if(result.equals(response)){
_udp_ota->append("OK", 2);
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
_state = OTA_RUNUPDATE;
} else {
_udp_ota->append("Authentication Failed", 21);
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
_state = OTA_IDLE;
}
}

while(_udp_ota->next()) _udp_ota->flush();
}

void ArduinoOTAClass::_runUpdate() {
if (!Update.begin(_size, _cmd)) {
#if OTA_DEBUG
Expand All @@ -92,7 +231,7 @@ void ArduinoOTAClass::_runUpdate() {
if (_error_callback) {
_error_callback(OTA_BEGIN_ERROR);
}
_udp_ota.begin(_port);
_udp_ota->listen(*IP_ADDR_ANY, _port);
_state = OTA_IDLE;
return;
}
Expand All @@ -112,7 +251,7 @@ void ArduinoOTAClass::_runUpdate() {
#if OTA_DEBUG
Serial.printf("Connect Failed\n");
#endif
_udp_ota.begin(_port);
_udp_ota->listen(*IP_ADDR_ANY, _port);
if (_error_callback) {
_error_callback(OTA_CONNECT_ERROR);
}
Expand All @@ -128,7 +267,7 @@ void ArduinoOTAClass::_runUpdate() {
#if OTA_DEBUG
Serial.printf("Recieve Failed\n");
#endif
_udp_ota.begin(_port);
_udp_ota->listen(*IP_ADDR_ANY, _port);
if (_error_callback) {
_error_callback(OTA_RECIEVE_ERROR);
}
Expand Down Expand Up @@ -156,7 +295,7 @@ void ArduinoOTAClass::_runUpdate() {
}
ESP.restart();
} else {
_udp_ota.begin(_port);
_udp_ota->listen(*IP_ADDR_ANY, _port);
if (_error_callback) {
_error_callback(OTA_END_ERROR);
}
Expand All @@ -169,94 +308,9 @@ void ArduinoOTAClass::_runUpdate() {
}

void ArduinoOTAClass::handle() {
if (!_udp_ota) {
_udp_ota.begin(_port);
#if OTA_DEBUG
Serial.println("OTA restarted");
#endif
}

if (!_udp_ota.parsePacket()) return;

if (_state == OTA_IDLE) {
int cmd = _udp_ota.parseInt();
if (cmd != U_FLASH && cmd != U_SPIFFS)
return;
_ota_ip = _udp_ota.remoteIP();
_cmd = cmd;
_ota_port = _udp_ota.parseInt();
_size = _udp_ota.parseInt();
_udp_ota.read();
_md5 = _udp_ota.readStringUntil('\n');
_md5.trim();
if(_md5.length() != 32)
return;

#if OTA_DEBUG
Serial.print("Update Start: ip:");
Serial.print(_ota_ip);
Serial.printf(", port:%d, size:%d, md5:%s\n", _ota_port, _size, _md5.c_str());
#endif

_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
if (_password.length()){
MD5Builder nonce_md5;
nonce_md5.begin();
nonce_md5.add(String(micros()));
nonce_md5.calculate();
_nonce = nonce_md5.toString();
_udp_ota.printf("AUTH %s", _nonce.c_str());
_udp_ota.endPacket();
_state = OTA_WAITAUTH;
return;
} else {
_udp_ota.print("OK");
_udp_ota.endPacket();
_state = OTA_RUNUPDATE;
}
} else if (_state == OTA_WAITAUTH) {
int cmd = _udp_ota.parseInt();
if (cmd != U_AUTH) {
_state = OTA_IDLE;
return;
}
_udp_ota.read();
String cnonce = _udp_ota.readStringUntil(' ');
String response = _udp_ota.readStringUntil('\n');
if (cnonce.length() != 32 || response.length() != 32) {
_state = OTA_IDLE;
return;
}

MD5Builder _passmd5;
_passmd5.begin();
_passmd5.add(_password);
_passmd5.calculate();
String passmd5 = _passmd5.toString();

String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
MD5Builder _challengemd5;
_challengemd5.begin();
_challengemd5.add(challenge);
_challengemd5.calculate();
String result = _challengemd5.toString();

if(result.equals(response)){
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
_udp_ota.print("OK");
_udp_ota.endPacket();
_state = OTA_RUNUPDATE;
} else {
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
_udp_ota.print("Authentication Failed");
_udp_ota.endPacket();
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
_state = OTA_IDLE;
}
}

if (_state == OTA_RUNUPDATE) {
_runUpdate();
_state = OTA_IDLE;
}
}

Expand Down
10 changes: 8 additions & 2 deletions libraries/ArduinoOTA/ArduinoOTA.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#ifndef __ARDUINO_OTA_H
#define __ARDUINO_OTA_H

class WiFiUDP;
#include <ESP8266WiFi.h>
#include <WiFiUdp.h>

class UdpContext;

#define OTA_CALLBACK(callback) void (*callback)()
#define OTA_CALLBACK_PROGRESS(callback) void (*callback)(unsigned int, unsigned int)
Expand Down Expand Up @@ -41,7 +44,7 @@ class ArduinoOTAClass
String _password;
String _hostname;
String _nonce;
WiFiUDP _udp_ota;
UdpContext *_udp_ota;
bool _initialized;
ota_state_t _state;
int _size;
Expand All @@ -56,6 +59,9 @@ class ArduinoOTAClass
OTA_CALLBACK_PROGRESS(_progress_callback);

void _runUpdate(void);
void _onRx(void);
int parseInt(void);
String readStringUntil(char end);
};

extern ArduinoOTAClass ArduinoOTA;
Expand Down

0 comments on commit 3d26810

Please sign in to comment.