@@ -86,7 +86,10 @@ void OneHotComputeImpl(const OneHotContext& op_context) {
8686 tflite::micro::GetTensorShape (op_context.indices );
8787 const int suffix_dim_size = indices_shape.FlatSize () / prefix_dim_size;
8888
89- const int depth = *op_context.depth ->data .i32 ;
89+ const int32_t * depth_ptr =
90+ tflite::micro::GetTensorData<int32_t >(op_context.depth );
91+ if (depth_ptr == nullptr ) return ;
92+ const int depth = *depth_ptr;
9093
9194 const T on_value = *tflite::micro::GetTensorData<T>(op_context.on_value );
9295 const T off_value = *tflite::micro::GetTensorData<T>(op_context.off_value );
@@ -115,13 +118,14 @@ void OneHotCompute(const OneHotContext& op_context) {
115118 }
116119}
117120
118- TfLiteStatus ResizeOutputTensor (TfLiteContext* context,
119- const OneHotContext& op_context) {
120- TF_LITE_ENSURE (context, *op_context.depth ->data .i32 >= 0 );
121-
121+ TfLiteStatus EnsureOutputDimsMatchExpected (TfLiteContext* context,
122+ const OneHotContext& op_context) {
122123 // read depth data
123- const int depth_val =
124- *tflite::micro::GetTensorData<int32_t >(op_context.depth );
124+ const int32_t * depth_ptr =
125+ tflite::micro::GetTensorData<int32_t >(op_context.depth );
126+ TF_LITE_ENSURE (context, depth_ptr != nullptr );
127+
128+ const int depth_val = *depth_ptr;
125129 TF_LITE_ENSURE (context, depth_val >= 0 );
126130
127131 // Output Tensor evaluation
@@ -143,8 +147,8 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
143147 expected_dim_i = op_context.indices ->dims ->data [i - 1 ];
144148 }
145149
146- // If the size pre-allocated by the TFLM compiler (Offline Memory Planner)
147- // does not match the actual computed size, an error is raised.
150+ // If the size pre-allocated by the TFLM Memory Planner does not match the
151+ // actual computed size, an error is raised.
148152 TF_LITE_ENSURE_EQ (context, op_context.output ->dims ->data [i],
149153 expected_dim_i);
150154 }
@@ -196,7 +200,7 @@ TfLiteStatus OneHotPrepare(TfLiteContext* context, TfLiteNode* node) {
196200
197201 // Even if the depth tensor is not a constant, the test predefines the output
198202 // shape, so here we only perform validation.
199- return ResizeOutputTensor (context, op_context);
203+ return EnsureOutputDimsMatchExpected (context, op_context);
200204}
201205
202206TfLiteStatus OneHotEval (TfLiteContext* context, TfLiteNode* node) {
@@ -230,8 +234,10 @@ TfLiteStatus OneHotEval(TfLiteContext* context, TfLiteNode* node) {
230234
231235} // namespace
232236
233- TFLMRegistration Register_ONE_HOT () {
234- return tflite::micro::RegisterOp (OneHotInit, OneHotPrepare, OneHotEval);
237+ const TFLMRegistration* Register_ONE_HOT () {
238+ static TFLMRegistration r =
239+ tflite::micro::RegisterOp (OneHotInit, OneHotPrepare, OneHotEval);
240+ return &r;
235241}
236242
237243} // namespace tflite
0 commit comments