samples: tflite-micro: update samples for latest tflite-micro

tflite-micro now uses MicroPrintf instead of MicroErrorReporter. Update
the samples to use this function instead. AllOpsResolver is now removed
from tflite-micro. AllOpsResolver was also removed in the latest
tflite-micro. Use MicroMutableOpResolver and only include the kernels
used instead.

Signed-off-by: Ryan McClelland <ryanmcclelland@meta.com>
This commit is contained in:
Ryan McClelland 2023-04-18 11:35:10 -07:00 committed by Carles Cufí
parent d2f82648cd
commit 65a15e9299
12 changed files with 54 additions and 81 deletions

View file

@ -13,5 +13,6 @@
# limitations under the License.
# ==============================================================================s
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_TENSORFLOW_LITE_MICRO=y
CONFIG_MAIN_STACK_SIZE=2048

View file

@ -16,18 +16,17 @@
#include "main_functions.h"
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>
/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
@ -41,40 +40,32 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
/* Set up logging. Google style is to avoid globals or statics because of
* lifetime uncertainty, but since this has a trivial destructor it's okay.
* NOLINTNEXTLINE(runtime-global-variables)
*/
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
MicroPrintf("Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
/* This pulls in all the operation implementations we need.
/* This pulls in the operation implementations we need.
* NOLINTNEXTLINE(runtime-global-variables)
*/
static tflite::AllOpsResolver resolver;
static tflite::MicroMutableOpResolver <1> resolver;
resolver.AddFullyConnected();
/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
model, resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;
/* Allocate memory from the tensor_arena for the model's tensors. */
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
MicroPrintf("AllocateTensors() failed");
return;
}
@ -106,8 +97,7 @@ void loop(void)
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
static_cast < double > (x));
MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
return;
}
@ -119,7 +109,7 @@ void loop(void)
/* Output the results. A custom HandleOutput function can be implemented
* for each supported hardware target.
*/
HandleOutput(error_reporter, x, y);
HandleOutput(x, y);
/* Increment the inference_counter, and reset it if we have reached
* the total number per cycle

View file

@ -16,11 +16,10 @@
#include "output_handler.hpp"
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
float y_value)
void HandleOutput(float x_value, float y_value)
{
/* Log the current X and Y values */
TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
MicroPrintf("x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
}

View file

@ -18,10 +18,9 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
/* Called by the main loop to produce some output based on the x and y values */
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
float y_value);
void HandleOutput(float x_value, float y_value);
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_ */

View file

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_NEWLIB_LIBC_FLOAT_PRINTF=y
CONFIG_SENSOR=y
CONFIG_NETWORKING=n

View file

@ -33,7 +33,7 @@ float bufz[BUFLEN] = { 0.0f };
bool initial = true;
TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
TfLiteStatus SetupAccelerometer()
{
if (!device_is_ready(sensor)) {
printk("%s: device not ready.\n", sensor->name);
@ -41,18 +41,16 @@ TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
}
if (sensor == NULL) {
TF_LITE_REPORT_ERROR(error_reporter,
"Failed to get accelerometer, name: %s\n",
sensor->name);
MicroPrintf("Failed to get accelerometer, name: %s\n",
sensor->name);
} else {
TF_LITE_REPORT_ERROR(error_reporter, "Got accelerometer, name: %s\n",
sensor->name);
MicroPrintf("Got accelerometer, name: %s\n",
sensor->name);
}
return kTfLiteOk;
}
bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
int length)
bool ReadAccelerometer(float *input, int length)
{
int rc;
struct sensor_value accel[3];
@ -60,7 +58,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
rc = sensor_sample_fetch(sensor);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "Fetch failed\n");
MicroPrintf("Fetch failed\n");
return false;
}
/* Skip if there is no data */
@ -72,7 +70,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
for (int i = 0; i < samples_count; i++) {
rc = sensor_channel_get(sensor, SENSOR_CHAN_ACCEL_XYZ, accel);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "ERROR: Update failed: %d\n", rc);
MicroPrintf("ERROR: Update failed: %d\n", rc);
return false;
}
bufx[begin_index] = (float)sensor_value_to_double(&accel[0]);

View file

@ -20,11 +20,10 @@
#define kChannelNumber 3
#include <tensorflow/lite/c/c_api_types.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
extern int begin_index;
extern TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter);
extern bool ReadAccelerometer(tflite::ErrorReporter *error_reporter,
float *input, int length);
extern TfLiteStatus SetupAccelerometer();
extern bool ReadAccelerometer(float *input, int length);
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_ACCELEROMETER_HANDLER_H_ */

View file

@ -21,14 +21,13 @@
#include "gesture_predictor.hpp"
#include "magic_wand_model_data.hpp"
#include "output_handler.hpp"
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/schema/schema_generated.h>
/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *model_input = nullptr;
@ -45,22 +44,14 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
/* Set up logging. Google style is to avoid globals or statics because of
* lifetime uncertainty, but since this has a trivial destructor it's okay.
*/
static tflite::MicroErrorReporter micro_error_reporter; /* NOLINT */
error_reporter = &micro_error_reporter;
/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
MicroPrintf("Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
@ -79,7 +70,7 @@ void setup(void)
/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
model, micro_op_resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;
/* Allocate memory from the tensor_arena for the model's tensors. */
@ -91,16 +82,15 @@ void setup(void)
(model_input->dims->data[1] != 128) ||
(model_input->dims->data[2] != kChannelNumber) ||
(model_input->type != kTfLiteFloat32)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad input tensor parameters in model");
MicroPrintf("Bad input tensor parameters in model");
return;
}
input_length = model_input->bytes / sizeof(float);
TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
TfLiteStatus setup_status = SetupAccelerometer();
if (setup_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Set up failed\n");
MicroPrintf("Set up failed\n");
}
}
@ -108,7 +98,7 @@ void loop(void)
{
/* Attempt to read new data from the accelerometer. */
bool got_data =
ReadAccelerometer(error_reporter, model_input->data.f, input_length);
ReadAccelerometer(model_input->data.f, input_length);
/* If there was no new data, wait until next time. */
if (!got_data) {
@ -118,13 +108,12 @@ void loop(void)
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n",
begin_index);
MicroPrintf("Invoke failed on index: %d\n", begin_index);
return;
}
/* Analyze the results to obtain a prediction */
int gesture_index = PredictGesture(interpreter->output(0)->data.f);
/* Produce an output */
HandleOutput(error_reporter, gesture_index);
HandleOutput(gesture_index);
}

View file

@ -16,24 +16,21 @@
#include "output_handler.hpp"
void HandleOutput(tflite::ErrorReporter *error_reporter, int kind)
void HandleOutput(int kind)
{
/* light (red: wing, blue: ring, green: slope) */
if (kind == 0) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"WING:\n\r* * *\n\r * * * "
"*\n\r * * * *\n\r * * * *\n\r * * "
"* *\n\r * *\n\r");
} else if (kind == 1) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"RING:\n\r *\n\r * *\n\r * *\n\r "
" * *\n\r * *\n\r * *\n\r "
" *\n\r");
} else if (kind == 2) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"SLOPE:\n\r *\n\r *\n\r *\n\r *\n\r "
"*\n\r *\n\r *\n\r * * * * * * * *\n\r");
}

View file

@ -18,8 +18,8 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
void HandleOutput(tflite::ErrorReporter *error_reporter, int kind);
void HandleOutput(int kind);
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_ */

View file

@ -1,6 +1,7 @@
#application default configuration
# include TFLM based on CMSIS NN optimization and ETHOSU acceleration
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_TENSORFLOW_LITE_MICRO=y
CONFIG_ARM_ETHOS_U=y
CONFIG_HEAP_MEM_POOL_SIZE=16384

View file

@ -6,9 +6,9 @@
#include "inference_process.hpp"
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_profiler.h>
#include <tensorflow/lite/schema/schema_generated.h>
@ -118,11 +118,10 @@ bool InferenceProcess::runJob(InferenceJob &job)
}
/* Create the TFL micro interpreter */
tflite::AllOpsResolver resolver;
tflite::MicroErrorReporter errorReporter;
tflite::MicroMutableOpResolver <1> resolver;
resolver.AddEthosU();
tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize,
&errorReporter);
tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize);
/* Allocate tensors */
TfLiteStatus allocate_status = interpreter.AllocateTensors();