@@ -28,9 +28,6 @@ limitations under the License.
2828
2929namespace tflite {
3030
31- const int kMaxNumberOfAxis = 5 ;
32- const int kMaxNumberOfReducedAxis = 2 ;
33-
3431namespace {
3532
3633TfLiteStatus PrepareSimple (TfLiteContext* context, TfLiteNode* node,
@@ -80,7 +77,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
8077
8178template <typename T>
8279TfLiteStatus QuantizedMeanOrSum (TfLiteContext* context, TfLiteNode* node,
83- int * temp_index , int * resolved_axis,
80+ int * input_iter , int * resolved_axis,
8481 int32_t * temp_sum, OpDataReduce* op_data,
8582 bool compute_sum) {
8683 const TfLiteEvalTensor* input = tflite::micro::GetEvalInput (context, node, 0 );
@@ -96,7 +93,7 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
9693 op_data->multiplier , op_data->shift , op_data->output_zp ,
9794 &output->dims ->data [0 ], output->dims ->size ,
9895 tflite::micro::GetTensorData<int >(axis), op_data->num_axis ,
99- params->keep_dims , temp_index , resolved_axis, temp_sum, compute_sum);
96+ params->keep_dims , input_iter , resolved_axis, temp_sum, compute_sum);
10097 TF_LITE_ENSURE (context, result);
10198
10299 return kTfLiteOk ;
@@ -105,11 +102,11 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
105102template <typename integer_type>
106103TfLiteStatus EvalIntegerMean (TfLiteContext* context, TfLiteNode* node,
107104 int num_axis, OpDataReduce* op_data,
108- int * temp_index , int * resolved_axis) {
105+ int * input_iter , int * resolved_axis) {
109106 int32_t * temp_sum = static_cast <int32_t *>(
110- context->GetScratchBuffer (context, op_data->temp_buffer_idx ));
107+ context->GetScratchBuffer (context, op_data->scratch_accumulator_idx ));
111108
112- QuantizedMeanOrSum<integer_type>(context, node, temp_index , resolved_axis,
109+ QuantizedMeanOrSum<integer_type>(context, node, input_iter , resolved_axis,
113110 temp_sum, op_data, /* compute_sum=*/ false );
114111
115112 return kTfLiteOk ;
@@ -155,10 +152,10 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
155152
156153 // Interpret an axis tensor with null dimensions as a scalar
157154 int num_axis = static_cast <int >(ElementCount (*axis->dims ));
158- int * temp_buffer = static_cast <int *>(
159- context->GetScratchBuffer (context, op_data->temp_buffer_idx ));
155+ int * input_iter = static_cast <int *>(
156+ context->GetScratchBuffer (context, op_data->scratch_input_iter_idx ));
160157 int * resolved_axis = static_cast <int *>(
161- context->GetScratchBuffer (context, op_data->resolved_axis_idx ));
158+ context->GetScratchBuffer (context, op_data->scratch_resolved_axis_idx ));
162159 switch (input->type ) {
163160 case kTfLiteFloat32 : {
164161 MinMaxReducerCompare<float > reducer (evalType);
@@ -169,7 +166,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
169166 input->dims ->size , tflite::micro::GetTensorData<float >(output),
170167 output->dims ->data , output->dims ->size ,
171168 tflite::micro::GetTensorData<int >(axis), num_axis,
172- params->keep_dims , temp_buffer , resolved_axis,
169+ params->keep_dims , input_iter , resolved_axis,
173170 reducer.initialValue (), reducer.compare ()));
174171 } break ;
175172 case kTfLiteInt8 : {
@@ -184,7 +181,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
184181 input->dims ->size , tflite::micro::GetTensorData<int8_t >(output),
185182 output->dims ->data , output->dims ->size ,
186183 tflite::micro::GetTensorData<int >(axis), num_axis,
187- params->keep_dims , temp_buffer , resolved_axis,
184+ params->keep_dims , input_iter , resolved_axis,
188185 reducer.initialValue (), reducer.compare ()));
189186 } break ;
190187 default :
@@ -211,12 +208,11 @@ TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
211208 op_data->output_zp = output->params .zero_point ;
212209 op_data->output_scale = output->params .scale ;
213210 op_data->num_output_elements = NumElements (output);
214-
215211 context->RequestScratchBufferInArena (context, sizeof (int ) * input->dims ->size ,
216- &op_data->temp_buffer_idx );
212+ &op_data->scratch_input_iter_idx );
217213 context->RequestScratchBufferInArena (
218214 context, sizeof (int ) * static_cast <int >(ElementCount (*axis->dims )),
219- &op_data->resolved_axis_idx );
215+ &op_data->scratch_resolved_axis_idx );
220216
221217 micro_context->DeallocateTempTfLiteTensor (input);
222218 micro_context->DeallocateTempTfLiteTensor (output);
@@ -236,17 +232,22 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
236232 QuantizeMultiplier (real_multiplier, &op_data->multiplier , &op_data->shift );
237233 }
238234
239- int output_size = NumElements (output);
240235 op_data->num_axis = NumElements (axis);
236+ op_data->num_output_elements = NumElements (output);
241237
242238 if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16 ) {
243- context->RequestScratchBufferInArena (context, output_size * sizeof (int32_t ),
244- &op_data->temp_buffer_idx );
239+ context->RequestScratchBufferInArena (
240+ context, sizeof (int32_t ) * op_data->num_output_elements ,
241+ &op_data->scratch_accumulator_idx );
245242 op_data->input_zp = input->params .zero_point ;
246243 op_data->input_scale = input->params .scale ;
247244 op_data->output_zp = output->params .zero_point ;
248245 op_data->output_scale = output->params .scale ;
249246 }
247+ context->RequestScratchBufferInArena (context, sizeof (int ) * input->dims ->size ,
248+ &op_data->scratch_input_iter_idx );
249+ context->RequestScratchBufferInArena (context, sizeof (int ) * op_data->num_axis ,
250+ &op_data->scratch_resolved_axis_idx );
250251
251252 TF_LITE_ENSURE_OK (
252253 context,
@@ -274,12 +275,11 @@ TfLiteStatus PrepareAllHelper(TfLiteContext* context, TfLiteNode* node,
274275 op_data->output_zp = output->params .zero_point ;
275276 op_data->output_scale = output->params .scale ;
276277 op_data->num_output_elements = NumElements (output);
277-
278278 context->RequestScratchBufferInArena (context, sizeof (int ) * input->dims ->size ,
279- &op_data->temp_buffer_idx );
279+ &op_data->scratch_input_iter_idx );
280280 context->RequestScratchBufferInArena (
281281 context, sizeof (int ) * static_cast <int >(ElementCount (*axis->dims )),
282- &op_data->resolved_axis_idx );
282+ &op_data->scratch_resolved_axis_idx );
283283
284284 micro_context->DeallocateTempTfLiteTensor (input);
285285 micro_context->DeallocateTempTfLiteTensor (output);
@@ -296,8 +296,10 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
296296 reinterpret_cast <TfLiteReducerParams*>(node->builtin_data );
297297
298298 int num_axis = static_cast <int >(ElementCount (*axis->dims ));
299- int temp_index[kMaxNumberOfAxis ];
300- int resolved_axis[kMaxNumberOfReducedAxis ];
299+ int * input_iter = static_cast <int *>(
300+ context->GetScratchBuffer (context, op_data->scratch_input_iter_idx ));
301+ int * resolved_axis = static_cast <int *>(
302+ context->GetScratchBuffer (context, op_data->scratch_resolved_axis_idx ));
301303
302304 switch (input->type ) {
303305 case kTfLiteFloat32 : {
@@ -326,19 +328,19 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
326328 input->dims ->size , tflite::micro::GetTensorData<float >(output),
327329 output->dims ->data , output->dims ->size ,
328330 tflite::micro::GetTensorData<int >(axis), num_axis,
329- params->keep_dims , temp_index , resolved_axis,
331+ params->keep_dims , input_iter , resolved_axis,
330332 tflite::micro::GetTensorData<float >(output)));
331333 }
332334 } break ;
333335 case kTfLiteInt8 : {
334336 TF_LITE_ENSURE_OK (
335337 context, EvalIntegerMean<int8_t >(context, node, num_axis, op_data,
336- temp_index , resolved_axis));
338+ input_iter , resolved_axis));
337339 } break ;
338340 case kTfLiteInt16 : {
339341 TF_LITE_ENSURE_OK (
340342 context, EvalIntegerMean<int16_t >(context, node, num_axis, op_data,
341- temp_index , resolved_axis));
343+ input_iter , resolved_axis));
342344 } break ;
343345 default :
344346 TF_LITE_ENSURE_MSG (context, false ,
@@ -369,8 +371,10 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
369371
370372 // Interpret an axis tensor with null dimensions as a scalar.
371373 int num_axis = static_cast <int >(ElementCount (*axis->dims ));
372- int temp_index[kMaxNumberOfAxis ];
373- int resolved_axis[kMaxNumberOfReducedAxis ];
374+ int * input_iter = static_cast <int *>(
375+ context->GetScratchBuffer (context, op_data->scratch_input_iter_idx ));
376+ int * resolved_axis = static_cast <int *>(
377+ context->GetScratchBuffer (context, op_data->scratch_resolved_axis_idx ));
374378
375379 switch (input->type ) {
376380 case kTfLiteFloat32 : {
@@ -381,21 +385,21 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
381385 input->dims ->size , tflite::micro::GetTensorData<float >(output),
382386 output->dims ->data , output->dims ->size ,
383387 tflite::micro::GetTensorData<int >(axis), num_axis,
384- params->keep_dims , temp_index , resolved_axis, /* init_value=*/ 0 .f ,
388+ params->keep_dims , input_iter , resolved_axis, /* init_value=*/ 0 .f ,
385389 [](const float current, const float in) -> float {
386390 return in + current;
387391 }));
388392 } break ;
389393 case kTfLiteInt8 : {
390394 int32_t * temp_sum = static_cast <int32_t *>(
391- context->GetScratchBuffer (context, op_data->temp_buffer_idx ));
392- QuantizedMeanOrSum<int8_t >(context, node, temp_index , resolved_axis,
395+ context->GetScratchBuffer (context, op_data->scratch_accumulator_idx ));
396+ QuantizedMeanOrSum<int8_t >(context, node, input_iter , resolved_axis,
393397 temp_sum, op_data, /* compute_sum=*/ true );
394398 } break ;
395399 case kTfLiteInt16 : {
396400 int32_t * temp_sum = static_cast <int32_t *>(
397- context->GetScratchBuffer (context, op_data->temp_buffer_idx ));
398- QuantizedMeanOrSum<int16_t >(context, node, temp_index , resolved_axis,
401+ context->GetScratchBuffer (context, op_data->scratch_accumulator_idx ));
402+ QuantizedMeanOrSum<int16_t >(context, node, input_iter , resolved_axis,
399403 temp_sum, op_data, /* compute_sum=*/ true );
400404 } break ;
401405 default :
@@ -416,10 +420,10 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,
416420
417421 // Interpret an axis tensor with null dimensions as a scalar
418422 int num_axis = static_cast <int >(ElementCount (*axis->dims ));
419- int * temp_buffer = static_cast <int *>(
420- context->GetScratchBuffer (context, op_data->temp_buffer_idx ));
423+ int * input_iter = static_cast <int *>(
424+ context->GetScratchBuffer (context, op_data->scratch_input_iter_idx ));
421425 int * resolved_axis = static_cast <int *>(
422- context->GetScratchBuffer (context, op_data->resolved_axis_idx ));
426+ context->GetScratchBuffer (context, op_data->scratch_resolved_axis_idx ));
423427 switch (input->type ) {
424428 case kTfLiteBool :
425429 TF_LITE_ENSURE (
@@ -429,7 +433,7 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,
429433 input->dims ->size , tflite::micro::GetTensorData<bool >(output),
430434 output->dims ->data , output->dims ->size ,
431435 tflite::micro::GetTensorData<int >(axis), num_axis,
432- params->keep_dims , temp_buffer , resolved_axis, true ,
436+ params->keep_dims , input_iter , resolved_axis, true ,
433437 [](const bool current, const bool in) -> bool {
434438 return in && current;
435439 }));
0 commit comments