diff --git a/ESP32_AP-Flasher/include/wifimanager.h b/ESP32_AP-Flasher/include/wifimanager.h index 7aaafba7..fce867ea 100644 --- a/ESP32_AP-Flasher/include/wifimanager.h +++ b/ESP32_AP-Flasher/include/wifimanager.h @@ -1,3 +1,11 @@ +#pragma once +#include + +#include +#include +#include +#include + #ifndef WIFI_MANAGER_H #define WIFI_MANAGER_H @@ -26,6 +34,8 @@ class WifiManager { const int SERIAL_BUFFER_SIZE = 64; char serialBuffer[64]; int serialIndex = 0; + uint8_t x_buffer[16]; + uint8_t x_position = 0; String WiFi_SSID(); String WiFi_psk(); @@ -47,3 +57,72 @@ class WifiManager { }; #endif + +// **** Improv Wi-Fi **** +// https://www.improv-wifi.com/ +// https://github.com/jnthas/improv-wifi-demo +// http://www.apache.org/licenses/LICENSE-2.0 + +namespace improv { + +enum Error : uint8_t { + ERROR_NONE = 0x00, + ERROR_INVALID_RPC = 0x01, + ERROR_UNKNOWN_RPC = 0x02, + ERROR_UNABLE_TO_CONNECT = 0x03, + ERROR_NOT_AUTHORIZED = 0x04, + ERROR_UNKNOWN = 0xFF, +}; + +enum State : uint8_t { + STATE_STOPPED = 0x00, + STATE_AWAITING_AUTHORIZATION = 0x01, + STATE_AUTHORIZED = 0x02, + STATE_PROVISIONING = 0x03, + STATE_PROVISIONED = 0x04, +}; + +enum Command : uint8_t { + UNKNOWN = 0x00, + WIFI_SETTINGS = 0x01, + IDENTIFY = 0x02, + GET_CURRENT_STATE = 0x02, + GET_DEVICE_INFO = 0x03, + GET_WIFI_NETWORKS = 0x04, + BAD_CHECKSUM = 0xFF, +}; + +static const uint8_t CAPABILITY_IDENTIFY = 0x01; +static const uint8_t IMPROV_SERIAL_VERSION = 1; + +enum ImprovSerialType : uint8_t { + TYPE_CURRENT_STATE = 0x01, + TYPE_ERROR_STATE = 0x02, + TYPE_RPC = 0x03, + TYPE_RPC_RESPONSE = 0x04 +}; + +struct ImprovCommand { + Command command; + std::string ssid; + std::string password; +}; + +ImprovCommand parse_improv_data(const std::vector &data, bool check_checksum = true); +ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum = true); + +bool parse_improv_serial_byte(size_t position, uint8_t byte, const uint8_t *buffer, + std::function &&callback, std::function &&on_error); + +std::vector build_rpc_response(Command command, const std::vector &datum, + bool add_checksum = true); +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum = true); + +} // namespace improv + +void set_state(improv::State state); +void send_response(std::vector &response); +void set_error(improv::Error error); +void getAvailableWifiNetworks(); +bool onCommandCallback(improv::ImprovCommand cmd); +void onErrorCallback(improv::Error err); diff --git a/ESP32_AP-Flasher/src/wifimanager.cpp b/ESP32_AP-Flasher/src/wifimanager.cpp index 22bc83ca..362ad7af 100644 --- a/ESP32_AP-Flasher/src/wifimanager.cpp +++ b/ESP32_AP-Flasher/src/wifimanager.cpp @@ -158,47 +158,53 @@ void WifiManager::pollSerial() { while (Serial.available() > 0) { char receivedChar = Serial.read(); - if (receivedChar == 27) { - memset(serialBuffer, 0, sizeof(serialBuffer)); - serialIndex = 0; - Serial.println(); - continue; - } - - if (receivedChar == 8) { - if (serialIndex > 0) { - serialIndex--; - serialBuffer[serialIndex] = '\0'; - Serial.print("\r"); - Serial.print(serialBuffer); - } - continue; - } - if (receivedChar == '\r') { - continue; - } - - if (receivedChar == '\n') { - serialBuffer[serialIndex] = '\0'; - String command = String(serialBuffer); - - if (command.startsWith("ssid ")) { - _ssid = command.substring(5); - Serial.println("\rSSID set to: " + _ssid); - } else if (command.startsWith("pass ")) { - _pass = command.substring(5); - Serial.println("\rPassword set to: " + _pass); - } else if (command.startsWith("connect")) { - connectToWifi(_ssid, _pass, true); - } - memset(serialBuffer, 0, sizeof(serialBuffer)); - serialIndex = 0; + if (parse_improv_serial_byte(x_position, receivedChar, x_buffer, onCommandCallback, onErrorCallback)) { + x_buffer[x_position++] = receivedChar; } else { - if (serialIndex < SERIAL_BUFFER_SIZE - 1) { - serialBuffer[serialIndex] = receivedChar; - serialIndex++; - Serial.print("\r"); - Serial.print(serialBuffer); + x_position = 0; + + if (receivedChar == 27) { + memset(serialBuffer, 0, sizeof(serialBuffer)); + serialIndex = 0; + Serial.println(); + continue; + } + + if (receivedChar == 8) { + if (serialIndex > 0) { + serialIndex--; + serialBuffer[serialIndex] = '\0'; + Serial.print("\r"); + Serial.print(serialBuffer); + } + continue; + } + if (receivedChar == '\r') { + continue; + } + + if (receivedChar == '\n') { + serialBuffer[serialIndex] = '\0'; + String command = String(serialBuffer); + + if (command.startsWith("ssid ")) { + _ssid = command.substring(5); + Serial.println("\rSSID set to: " + _ssid); + } else if (command.startsWith("pass ")) { + _pass = command.substring(5); + Serial.println("\rPassword set to: " + _pass); + } else if (command.startsWith("connect")) { + connectToWifi(_ssid, _pass, true); + } + memset(serialBuffer, 0, sizeof(serialBuffer)); + serialIndex = 0; + } else { + if (serialIndex < SERIAL_BUFFER_SIZE - 1) { + serialBuffer[serialIndex] = receivedChar; + serialIndex++; + Serial.print("\r"); + Serial.print(serialBuffer); + } } } } @@ -264,3 +270,295 @@ void WifiManager::WiFiEvent(WiFiEvent_t event) { break; } } + +// *** Improv + +#define STR_IMPL(x) #x +#define STR(x) STR_IMPL(x) + +#ifndef BUILD_ENV_NAME +#define BUILD_ENV_NAME unknown +#endif +#ifndef BUILD_TIME +#define BUILD_TIME 0 +#endif +#ifndef BUILD_VERSION +#define BUILD_VERSION custom +#endif + +std::vector getLocalUrl() { + return { String("http://" + WiFi.localIP().toString()).c_str() }; +} + +void onErrorCallback(improv::Error err) { +} + +bool onCommandCallback(improv::ImprovCommand cmd) { + switch (cmd.command) { + case improv::Command::GET_CURRENT_STATE: { + if ((WiFi.status() == WL_CONNECTED)) { + set_state(improv::State::STATE_PROVISIONED); + std::vector data = improv::build_rpc_response(improv::GET_CURRENT_STATE, getLocalUrl(), false); + send_response(data); + } else { + set_state(improv::State::STATE_AUTHORIZED); + } + break; + } + + case improv::Command::WIFI_SETTINGS: { + if (cmd.ssid.length() == 0) { + set_error(improv::Error::ERROR_INVALID_RPC); + break; + } + + set_state(improv::STATE_PROVISIONING); + + WifiManager wm; + if (wm.connectToWifi(String(cmd.ssid.c_str()), String(cmd.password.c_str()), true)) { + set_state(improv::STATE_PROVISIONED); + std::vector data = improv::build_rpc_response(improv::WIFI_SETTINGS, getLocalUrl(), false); + send_response(data); + } else { + set_state(improv::STATE_STOPPED); + set_error(improv::Error::ERROR_UNABLE_TO_CONNECT); + } + + break; + } + + case improv::Command::GET_DEVICE_INFO: { + std::vector infos = { + // Firmware name + "OpenEPaperLink", + // Firmware version + STR(BUILD_VERSION), + // Hardware chip/variant + STR(BUILD_ENV_NAME), + // Device name + "Access Point"}; + std::vector data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos, false); + send_response(data); + break; + } + + case improv::Command::GET_WIFI_NETWORKS: { + getAvailableWifiNetworks(); + break; + } + + default: { + set_error(improv::ERROR_UNKNOWN_RPC); + return false; + } + } + + return true; +} + +void getAvailableWifiNetworks() { + int networkNum = WiFi.scanNetworks(); + + for (int id = 0; id < networkNum; ++id) { + std::vector data = improv::build_rpc_response( + improv::GET_WIFI_NETWORKS, {WiFi.SSID(id), String(WiFi.RSSI(id)), (WiFi.encryptionType(id) == WIFI_AUTH_OPEN ? "NO" : "YES")}, false); + send_response(data); + delay(1); + } + // final response + std::vector data = + improv::build_rpc_response(improv::GET_WIFI_NETWORKS, std::vector{}, false); + send_response(data); +} + +void set_state(improv::State state) { + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(11); + data[6] = improv::IMPROV_SERIAL_VERSION; + data[7] = improv::TYPE_CURRENT_STATE; + data[8] = 1; + data[9] = state; + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data[10] = checksum; + + Serial.write(data.data(), data.size()); +} + +void send_response(std::vector &response) { + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(9); + data[6] = improv::IMPROV_SERIAL_VERSION; + data[7] = improv::TYPE_RPC_RESPONSE; + data[8] = response.size(); + data.insert(data.end(), response.begin(), response.end()); + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data.push_back(checksum); + + Serial.write(data.data(), data.size()); +} + +void set_error(improv::Error error) { + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(11); + data[6] = improv::IMPROV_SERIAL_VERSION; + data[7] = improv::TYPE_ERROR_STATE; + data[8] = 1; + data[9] = error; + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data[10] = checksum; + + Serial.write(data.data(), data.size()); +} + +// **** improv **** + +namespace improv { + +ImprovCommand parse_improv_data(const std::vector &data, bool check_checksum) { + return parse_improv_data(data.data(), data.size(), check_checksum); +} + +ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum) { + ImprovCommand improv_command; + Command command = (Command)data[0]; + uint8_t data_length = data[1]; + + if (data_length != length - 2 - check_checksum) { + improv_command.command = UNKNOWN; + return improv_command; + } + + if (check_checksum) { + uint8_t checksum = data[length - 1]; + + uint32_t calculated_checksum = 0; + for (uint8_t i = 0; i < length - 1; i++) { + calculated_checksum += data[i]; + } + + if ((uint8_t)calculated_checksum != checksum) { + improv_command.command = BAD_CHECKSUM; + return improv_command; + } + } + + if (command == WIFI_SETTINGS) { + uint8_t ssid_length = data[2]; + uint8_t ssid_start = 3; + size_t ssid_end = ssid_start + ssid_length; + + uint8_t pass_length = data[ssid_end]; + size_t pass_start = ssid_end + 1; + size_t pass_end = pass_start + pass_length; + + std::string ssid(data + ssid_start, data + ssid_end); + std::string password(data + pass_start, data + pass_end); + return {.command = command, .ssid = ssid, .password = password}; + } + + improv_command.command = command; + return improv_command; +} + +bool parse_improv_serial_byte(size_t position, uint8_t byte, const uint8_t *buffer, + std::function &&callback, std::function &&on_error) { + if (position == 0) + return byte == 'I'; + if (position == 1) + return byte == 'M'; + if (position == 2) + return byte == 'P'; + if (position == 3) + return byte == 'R'; + if (position == 4) + return byte == 'O'; + if (position == 5) + return byte == 'V'; + + if (position == 6) + return byte == IMPROV_SERIAL_VERSION; + + if (position <= 8) + return true; + + uint8_t type = buffer[7]; + uint8_t data_len = buffer[8]; + + if (position <= 8 + data_len) + return true; + + if (position == 8 + data_len + 1) { + uint8_t checksum = 0x00; + for (size_t i = 0; i < position; i++) + checksum += buffer[i]; + + if (checksum != byte) { + on_error(ERROR_INVALID_RPC); + return false; + } + + if (type == TYPE_RPC) { + auto command = parse_improv_data(&buffer[9], data_len, false); + return callback(command); + } + } + + return false; +} + +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum) { + std::vector out; + uint32_t length = 0; + out.push_back(command); + for (const auto &str : datum) { + uint8_t len = str.length(); + length += len + 1; + out.push_back(len); + out.insert(out.end(), str.begin(), str.end()); + } + out.insert(out.begin() + 1, length); + + if (add_checksum) { + uint32_t calculated_checksum = 0; + + for (uint8_t byte : out) { + calculated_checksum += byte; + } + out.push_back(calculated_checksum); + } + return out; +} + +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum) { + std::vector out; + uint32_t length = 0; + out.push_back(command); + for (const auto &str : datum) { + uint8_t len = str.length(); + length += len; + out.push_back(len); + out.insert(out.end(), str.begin(), str.end()); + } + out.insert(out.begin() + 1, length); + + if (add_checksum) { + uint32_t calculated_checksum = 0; + + for (uint8_t byte : out) { + calculated_checksum += byte; + } + out.push_back(calculated_checksum); + } + return out; +} + +} // namespace improv