Skip to content

Commit c1f82e8

Browse files
committed
tflm: update one_hot test to use micro_ops registration
1 parent 1da8669 commit c1f82e8

File tree

3 files changed

+20
-13
lines changed

3 files changed

+20
-13
lines changed

tensorflow/lite/micro/kernels/micro_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ TFLMRegistration Register_MIRROR_PAD();
9595
TFLMRegistration Register_MUL();
9696
TFLMRegistration Register_NEG();
9797
TFLMRegistration Register_NOT_EQUAL();
98-
TFLMRegistration Register_ONE_HOT();
98+
TFLMRegistration* Register_ONE_HOT();
9999
TFLMRegistration Register_PACK();
100100
TFLMRegistration Register_PAD();
101101
TFLMRegistration Register_PADV2();

tensorflow/lite/micro/kernels/one_hot.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ void OneHotComputeImpl(const OneHotContext& op_context) {
8686
tflite::micro::GetTensorShape(op_context.indices);
8787
const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size;
8888

89-
const int depth = *op_context.depth->data.i32;
89+
const int32_t* depth_ptr =
90+
tflite::micro::GetTensorData<int32_t>(op_context.depth);
91+
if (depth_ptr == nullptr) return;
92+
const int depth = *depth_ptr;
9093

9194
const T on_value = *tflite::micro::GetTensorData<T>(op_context.on_value);
9295
const T off_value = *tflite::micro::GetTensorData<T>(op_context.off_value);
@@ -115,13 +118,14 @@ void OneHotCompute(const OneHotContext& op_context) {
115118
}
116119
}
117120

118-
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
119-
const OneHotContext& op_context) {
120-
TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
121-
121+
TfLiteStatus EnsureOutputDimsMatchExpected(TfLiteContext* context,
122+
const OneHotContext& op_context) {
122123
// read depth data
123-
const int depth_val =
124-
*tflite::micro::GetTensorData<int32_t>(op_context.depth);
124+
const int32_t* depth_ptr =
125+
tflite::micro::GetTensorData<int32_t>(op_context.depth);
126+
TF_LITE_ENSURE(context, depth_ptr != nullptr);
127+
128+
const int depth_val = *depth_ptr;
125129
TF_LITE_ENSURE(context, depth_val >= 0);
126130

127131
// Output Tensor evaluation
@@ -143,8 +147,8 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
143147
expected_dim_i = op_context.indices->dims->data[i - 1];
144148
}
145149

146-
// If the size pre-allocated by the TFLM compiler (Offline Memory Planner)
147-
// does not match the actual computed size, an error is raised.
150+
// If the size pre-allocated by the TFLM Memory Planner does not match the
151+
// actual computed size, an error is raised.
148152
TF_LITE_ENSURE_EQ(context, op_context.output->dims->data[i],
149153
expected_dim_i);
150154
}
@@ -196,7 +200,7 @@ TfLiteStatus OneHotPrepare(TfLiteContext* context, TfLiteNode* node) {
196200

197201
// Even if the depth tensor is not a constant, the test predefines the output
198202
// shape, so here we only perform validation.
199-
return ResizeOutputTensor(context, op_context);
203+
return EnsureOutputDimsMatchExpected(context, op_context);
200204
}
201205

202206
TfLiteStatus OneHotEval(TfLiteContext* context, TfLiteNode* node) {
@@ -230,8 +234,10 @@ TfLiteStatus OneHotEval(TfLiteContext* context, TfLiteNode* node) {
230234

231235
} // namespace
232236

233-
TFLMRegistration Register_ONE_HOT() {
234-
return tflite::micro::RegisterOp(OneHotInit, OneHotPrepare, OneHotEval);
237+
const TFLMRegistration* Register_ONE_HOT() {
238+
static TFLMRegistration r =
239+
tflite::micro::RegisterOp(OneHotInit, OneHotPrepare, OneHotEval);
240+
return &r;
235241
}
236242

237243
} // namespace tflite

tensorflow/lite/micro/one_hot_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "tensorflow/lite/c/builtin_op_data.h"
22
#include "tensorflow/lite/c/common.h"
33
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
4+
#include "tensorflow/lite/micro/kernels/micro_ops.h"
45
#include "tensorflow/lite/micro/test_helpers.h"
56
#include "tensorflow/lite/micro/testing/micro_test.h"
67

0 commit comments

Comments
 (0)