samples: tflite-micro: update samples for latest tflite-micro
tflite-micro now uses MicroPrintf instead of MicroErrorReporter. Update the samples to use this function instead. AllOpsResolver is now removed from tflite-micro. AllOpsResolver was also removed in the latest tflite-micro. Use MicroMutableOpResolver and only include the kernels used instead. Signed-off-by: Ryan McClelland <ryanmcclelland@meta.com>
This commit is contained in:
parent
d2f82648cd
commit
65a15e9299
|
@ -13,5 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================s
|
||||
CONFIG_CPP=y
|
||||
CONFIG_STD_CPP17=y
|
||||
CONFIG_TENSORFLOW_LITE_MICRO=y
|
||||
CONFIG_MAIN_STACK_SIZE=2048
|
||||
|
|
|
@ -16,18 +16,17 @@
|
|||
|
||||
#include "main_functions.h"
|
||||
|
||||
#include <tensorflow/lite/micro/all_ops_resolver.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
#include "constants.h"
|
||||
#include "model.hpp"
|
||||
#include "output_handler.hpp"
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/system_setup.h>
|
||||
#include <tensorflow/lite/schema/schema_generated.h>
|
||||
|
||||
/* Globals, used for compatibility with Arduino-style sketches. */
|
||||
namespace {
|
||||
tflite::ErrorReporter *error_reporter = nullptr;
|
||||
const tflite::Model *model = nullptr;
|
||||
tflite::MicroInterpreter *interpreter = nullptr;
|
||||
TfLiteTensor *input = nullptr;
|
||||
|
@ -41,40 +40,32 @@ namespace {
|
|||
/* The name of this function is important for Arduino compatibility. */
|
||||
void setup(void)
|
||||
{
|
||||
/* Set up logging. Google style is to avoid globals or statics because of
|
||||
* lifetime uncertainty, but since this has a trivial destructor it's okay.
|
||||
* NOLINTNEXTLINE(runtime-global-variables)
|
||||
*/
|
||||
static tflite::MicroErrorReporter micro_error_reporter;
|
||||
|
||||
error_reporter = µ_error_reporter;
|
||||
|
||||
/* Map the model into a usable data structure. This doesn't involve any
|
||||
* copying or parsing, it's a very lightweight operation.
|
||||
*/
|
||||
model = tflite::GetModel(g_model);
|
||||
if (model->version() != TFLITE_SCHEMA_VERSION) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Model provided is schema version %d not equal "
|
||||
"to supported version %d.",
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
MicroPrintf("Model provided is schema version %d not equal "
|
||||
"to supported version %d.",
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
return;
|
||||
}
|
||||
|
||||
/* This pulls in all the operation implementations we need.
|
||||
/* This pulls in the operation implementations we need.
|
||||
* NOLINTNEXTLINE(runtime-global-variables)
|
||||
*/
|
||||
static tflite::AllOpsResolver resolver;
|
||||
static tflite::MicroMutableOpResolver <1> resolver;
|
||||
resolver.AddFullyConnected();
|
||||
|
||||
/* Build an interpreter to run the model with. */
|
||||
static tflite::MicroInterpreter static_interpreter(
|
||||
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
|
||||
model, resolver, tensor_arena, kTensorArenaSize);
|
||||
interpreter = &static_interpreter;
|
||||
|
||||
/* Allocate memory from the tensor_arena for the model's tensors. */
|
||||
TfLiteStatus allocate_status = interpreter->AllocateTensors();
|
||||
if (allocate_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
|
||||
MicroPrintf("AllocateTensors() failed");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -106,8 +97,7 @@ void loop(void)
|
|||
/* Run inference, and report any error */
|
||||
TfLiteStatus invoke_status = interpreter->Invoke();
|
||||
if (invoke_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
|
||||
static_cast < double > (x));
|
||||
MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -119,7 +109,7 @@ void loop(void)
|
|||
/* Output the results. A custom HandleOutput function can be implemented
|
||||
* for each supported hardware target.
|
||||
*/
|
||||
HandleOutput(error_reporter, x, y);
|
||||
HandleOutput(x, y);
|
||||
|
||||
/* Increment the inference_counter, and reset it if we have reached
|
||||
* the total number per cycle
|
||||
|
|
|
@ -16,11 +16,10 @@
|
|||
|
||||
#include "output_handler.hpp"
|
||||
|
||||
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
|
||||
float y_value)
|
||||
void HandleOutput(float x_value, float y_value)
|
||||
{
|
||||
/* Log the current X and Y values */
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
|
||||
static_cast < double > (x_value),
|
||||
static_cast < double > (y_value));
|
||||
MicroPrintf("x_value: %f, y_value: %f\n",
|
||||
static_cast < double > (x_value),
|
||||
static_cast < double > (y_value));
|
||||
}
|
||||
|
|
|
@ -18,10 +18,9 @@
|
|||
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
|
||||
|
||||
#include <tensorflow/lite/c/common.h>
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
|
||||
/* Called by the main loop to produce some output based on the x and y values */
|
||||
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
|
||||
float y_value);
|
||||
void HandleOutput(float x_value, float y_value);
|
||||
|
||||
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_ */
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
CONFIG_CPP=y
|
||||
CONFIG_STD_CPP17=y
|
||||
CONFIG_NEWLIB_LIBC_FLOAT_PRINTF=y
|
||||
CONFIG_SENSOR=y
|
||||
CONFIG_NETWORKING=n
|
||||
|
|
|
@ -33,7 +33,7 @@ float bufz[BUFLEN] = { 0.0f };
|
|||
|
||||
bool initial = true;
|
||||
|
||||
TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
|
||||
TfLiteStatus SetupAccelerometer()
|
||||
{
|
||||
if (!device_is_ready(sensor)) {
|
||||
printk("%s: device not ready.\n", sensor->name);
|
||||
|
@ -41,18 +41,16 @@ TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
|
|||
}
|
||||
|
||||
if (sensor == NULL) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Failed to get accelerometer, name: %s\n",
|
||||
sensor->name);
|
||||
MicroPrintf("Failed to get accelerometer, name: %s\n",
|
||||
sensor->name);
|
||||
} else {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Got accelerometer, name: %s\n",
|
||||
sensor->name);
|
||||
MicroPrintf("Got accelerometer, name: %s\n",
|
||||
sensor->name);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
|
||||
int length)
|
||||
bool ReadAccelerometer(float *input, int length)
|
||||
{
|
||||
int rc;
|
||||
struct sensor_value accel[3];
|
||||
|
@ -60,7 +58,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
|
|||
|
||||
rc = sensor_sample_fetch(sensor);
|
||||
if (rc < 0) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Fetch failed\n");
|
||||
MicroPrintf("Fetch failed\n");
|
||||
return false;
|
||||
}
|
||||
/* Skip if there is no data */
|
||||
|
@ -72,7 +70,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
|
|||
for (int i = 0; i < samples_count; i++) {
|
||||
rc = sensor_channel_get(sensor, SENSOR_CHAN_ACCEL_XYZ, accel);
|
||||
if (rc < 0) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "ERROR: Update failed: %d\n", rc);
|
||||
MicroPrintf("ERROR: Update failed: %d\n", rc);
|
||||
return false;
|
||||
}
|
||||
bufx[begin_index] = (float)sensor_value_to_double(&accel[0]);
|
||||
|
|
|
@ -20,11 +20,10 @@
|
|||
#define kChannelNumber 3
|
||||
|
||||
#include <tensorflow/lite/c/c_api_types.h>
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
|
||||
extern int begin_index;
|
||||
extern TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter);
|
||||
extern bool ReadAccelerometer(tflite::ErrorReporter *error_reporter,
|
||||
float *input, int length);
|
||||
extern TfLiteStatus SetupAccelerometer();
|
||||
extern bool ReadAccelerometer(float *input, int length);
|
||||
|
||||
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_ACCELEROMETER_HANDLER_H_ */
|
||||
|
|
|
@ -21,14 +21,13 @@
|
|||
#include "gesture_predictor.hpp"
|
||||
#include "magic_wand_model_data.hpp"
|
||||
#include "output_handler.hpp"
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
#include <tensorflow/lite/schema/schema_generated.h>
|
||||
|
||||
/* Globals, used for compatibility with Arduino-style sketches. */
|
||||
namespace {
|
||||
tflite::ErrorReporter *error_reporter = nullptr;
|
||||
const tflite::Model *model = nullptr;
|
||||
tflite::MicroInterpreter *interpreter = nullptr;
|
||||
TfLiteTensor *model_input = nullptr;
|
||||
|
@ -45,22 +44,14 @@ namespace {
|
|||
/* The name of this function is important for Arduino compatibility. */
|
||||
void setup(void)
|
||||
{
|
||||
/* Set up logging. Google style is to avoid globals or statics because of
|
||||
* lifetime uncertainty, but since this has a trivial destructor it's okay.
|
||||
*/
|
||||
static tflite::MicroErrorReporter micro_error_reporter; /* NOLINT */
|
||||
|
||||
error_reporter = µ_error_reporter;
|
||||
|
||||
/* Map the model into a usable data structure. This doesn't involve any
|
||||
* copying or parsing, it's a very lightweight operation.
|
||||
*/
|
||||
model = tflite::GetModel(g_magic_wand_model_data);
|
||||
if (model->version() != TFLITE_SCHEMA_VERSION) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Model provided is schema version %d not equal "
|
||||
"to supported version %d.",
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
MicroPrintf("Model provided is schema version %d not equal "
|
||||
"to supported version %d.",
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -79,7 +70,7 @@ void setup(void)
|
|||
|
||||
/* Build an interpreter to run the model with. */
|
||||
static tflite::MicroInterpreter static_interpreter(
|
||||
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
|
||||
model, micro_op_resolver, tensor_arena, kTensorArenaSize);
|
||||
interpreter = &static_interpreter;
|
||||
|
||||
/* Allocate memory from the tensor_arena for the model's tensors. */
|
||||
|
@ -91,16 +82,15 @@ void setup(void)
|
|||
(model_input->dims->data[1] != 128) ||
|
||||
(model_input->dims->data[2] != kChannelNumber) ||
|
||||
(model_input->type != kTfLiteFloat32)) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Bad input tensor parameters in model");
|
||||
MicroPrintf("Bad input tensor parameters in model");
|
||||
return;
|
||||
}
|
||||
|
||||
input_length = model_input->bytes / sizeof(float);
|
||||
|
||||
TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
|
||||
TfLiteStatus setup_status = SetupAccelerometer();
|
||||
if (setup_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Set up failed\n");
|
||||
MicroPrintf("Set up failed\n");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,7 +98,7 @@ void loop(void)
|
|||
{
|
||||
/* Attempt to read new data from the accelerometer. */
|
||||
bool got_data =
|
||||
ReadAccelerometer(error_reporter, model_input->data.f, input_length);
|
||||
ReadAccelerometer(model_input->data.f, input_length);
|
||||
|
||||
/* If there was no new data, wait until next time. */
|
||||
if (!got_data) {
|
||||
|
@ -118,13 +108,12 @@ void loop(void)
|
|||
/* Run inference, and report any error */
|
||||
TfLiteStatus invoke_status = interpreter->Invoke();
|
||||
if (invoke_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n",
|
||||
begin_index);
|
||||
MicroPrintf("Invoke failed on index: %d\n", begin_index);
|
||||
return;
|
||||
}
|
||||
/* Analyze the results to obtain a prediction */
|
||||
int gesture_index = PredictGesture(interpreter->output(0)->data.f);
|
||||
|
||||
/* Produce an output */
|
||||
HandleOutput(error_reporter, gesture_index);
|
||||
HandleOutput(gesture_index);
|
||||
}
|
||||
|
|
|
@ -16,24 +16,21 @@
|
|||
|
||||
#include "output_handler.hpp"
|
||||
|
||||
void HandleOutput(tflite::ErrorReporter *error_reporter, int kind)
|
||||
void HandleOutput(int kind)
|
||||
{
|
||||
/* light (red: wing, blue: ring, green: slope) */
|
||||
if (kind == 0) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
MicroPrintf(
|
||||
"WING:\n\r* * *\n\r * * * "
|
||||
"*\n\r * * * *\n\r * * * *\n\r * * "
|
||||
"* *\n\r * *\n\r");
|
||||
} else if (kind == 1) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
MicroPrintf(
|
||||
"RING:\n\r *\n\r * *\n\r * *\n\r "
|
||||
" * *\n\r * *\n\r * *\n\r "
|
||||
" *\n\r");
|
||||
} else if (kind == 2) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
MicroPrintf(
|
||||
"SLOPE:\n\r *\n\r *\n\r *\n\r *\n\r "
|
||||
"*\n\r *\n\r *\n\r * * * * * * * *\n\r");
|
||||
}
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_
|
||||
|
||||
#include <tensorflow/lite/c/common.h>
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
|
||||
void HandleOutput(tflite::ErrorReporter *error_reporter, int kind);
|
||||
void HandleOutput(int kind);
|
||||
|
||||
#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_ */
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#application default configuration
|
||||
# include TFLM based on CMSIS NN optimization and ETHOSU acceleration
|
||||
CONFIG_CPP=y
|
||||
CONFIG_STD_CPP17=y
|
||||
CONFIG_TENSORFLOW_LITE_MICRO=y
|
||||
CONFIG_ARM_ETHOS_U=y
|
||||
CONFIG_HEAP_MEM_POOL_SIZE=16384
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
|
||||
#include "inference_process.hpp"
|
||||
|
||||
#include <tensorflow/lite/micro/all_ops_resolver.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
|
||||
#include <tensorflow/lite/micro/micro_error_reporter.h>
|
||||
#include <tensorflow/lite/micro/micro_log.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/micro_profiler.h>
|
||||
#include <tensorflow/lite/schema/schema_generated.h>
|
||||
|
@ -118,11 +118,10 @@ bool InferenceProcess::runJob(InferenceJob &job)
|
|||
}
|
||||
|
||||
/* Create the TFL micro interpreter */
|
||||
tflite::AllOpsResolver resolver;
|
||||
tflite::MicroErrorReporter errorReporter;
|
||||
tflite::MicroMutableOpResolver <1> resolver;
|
||||
resolver.AddEthosU();
|
||||
|
||||
tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize,
|
||||
&errorReporter);
|
||||
tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize);
|
||||
|
||||
/* Allocate tensors */
|
||||
TfLiteStatus allocate_status = interpreter.AllocateTensors();
|
||||
|
|
Loading…
Reference in a new issue