Skip to content

Commit 5bdb940

Browse files
committed
Refactor: Use scratch buffers for reduce kernel
1 parent b63a3ac commit 5bdb940

File tree

2 files changed

+45
-43
lines changed

2 files changed

+45
-43
lines changed

tensorflow/lite/micro/kernels/reduce.h

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,14 +24,12 @@ limitations under the License.
2424

2525
namespace tflite {
2626

27-
extern const int kMaxNumberOfAxis;
28-
extern const int kMaxNumberOfReducedAxis;
29-
3027
struct OpDataReduce {
3128
int32_t multiplier;
3229
int shift;
33-
int temp_buffer_idx;
34-
int resolved_axis_idx;
30+
int scratch_accumulator_idx;
31+
int scratch_resolved_axis_idx;
32+
int scratch_input_iter_idx;
3533
int input_zp;
3634
float input_scale;
3735
int output_zp;

tensorflow/lite/micro/kernels/reduce_common.cc

Lines changed: 42 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -28,9 +28,6 @@ limitations under the License.
2828

2929
namespace tflite {
3030

31-
const int kMaxNumberOfAxis = 5;
32-
const int kMaxNumberOfReducedAxis = 2;
33-
3431
namespace {
3532

3633
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
@@ -80,7 +77,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
8077

8178
template <typename T>
8279
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
83-
int* temp_index, int* resolved_axis,
80+
int* input_iter, int* resolved_axis,
8481
int32_t* temp_sum, OpDataReduce* op_data,
8582
bool compute_sum) {
8683
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
@@ -96,7 +93,7 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
9693
op_data->multiplier, op_data->shift, op_data->output_zp,
9794
&output->dims->data[0], output->dims->size,
9895
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
99-
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
96+
params->keep_dims, input_iter, resolved_axis, temp_sum, compute_sum);
10097
TF_LITE_ENSURE(context, result);
10198

10299
return kTfLiteOk;
@@ -105,11 +102,11 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
105102
template <typename integer_type>
106103
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
107104
int num_axis, OpDataReduce* op_data,
108-
int* temp_index, int* resolved_axis) {
105+
int* input_iter, int* resolved_axis) {
109106
int32_t* temp_sum = static_cast<int32_t*>(
110-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
107+
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));
111108

112-
QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
109+
QuantizedMeanOrSum<integer_type>(context, node, input_iter, resolved_axis,
113110
temp_sum, op_data, /*compute_sum=*/false);
114111

115112
return kTfLiteOk;
@@ -155,10 +152,10 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
155152

156153
// Interpret an axis tensor with null dimensions as a scalar
157154
int num_axis = static_cast<int>(ElementCount(*axis->dims));
158-
int* temp_buffer = static_cast<int*>(
159-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
155+
int* input_iter = static_cast<int*>(
156+
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
160157
int* resolved_axis = static_cast<int*>(
161-
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
158+
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
162159
switch (input->type) {
163160
case kTfLiteFloat32: {
164161
MinMaxReducerCompare<float> reducer(evalType);
@@ -169,7 +166,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
169166
input->dims->size, tflite::micro::GetTensorData<float>(output),
170167
output->dims->data, output->dims->size,
171168
tflite::micro::GetTensorData<int>(axis), num_axis,
172-
params->keep_dims, temp_buffer, resolved_axis,
169+
params->keep_dims, input_iter, resolved_axis,
173170
reducer.initialValue(), reducer.compare()));
174171
} break;
175172
case kTfLiteInt8: {
@@ -184,7 +181,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
184181
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
185182
output->dims->data, output->dims->size,
186183
tflite::micro::GetTensorData<int>(axis), num_axis,
187-
params->keep_dims, temp_buffer, resolved_axis,
184+
params->keep_dims, input_iter, resolved_axis,
188185
reducer.initialValue(), reducer.compare()));
189186
} break;
190187
default:
@@ -211,12 +208,11 @@ TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
211208
op_data->output_zp = output->params.zero_point;
212209
op_data->output_scale = output->params.scale;
213210
op_data->num_output_elements = NumElements(output);
214-
215211
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
216-
&op_data->temp_buffer_idx);
212+
&op_data->scratch_input_iter_idx);
217213
context->RequestScratchBufferInArena(
218214
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
219-
&op_data->resolved_axis_idx);
215+
&op_data->scratch_resolved_axis_idx);
220216

221217
micro_context->DeallocateTempTfLiteTensor(input);
222218
micro_context->DeallocateTempTfLiteTensor(output);
@@ -236,17 +232,22 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
236232
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
237233
}
238234

239-
int output_size = NumElements(output);
240235
op_data->num_axis = NumElements(axis);
236+
op_data->num_output_elements = NumElements(output);
241237

242238
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
243-
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
244-
&op_data->temp_buffer_idx);
239+
context->RequestScratchBufferInArena(
240+
context, sizeof(int32_t) * op_data->num_output_elements,
241+
&op_data->scratch_accumulator_idx);
245242
op_data->input_zp = input->params.zero_point;
246243
op_data->input_scale = input->params.scale;
247244
op_data->output_zp = output->params.zero_point;
248245
op_data->output_scale = output->params.scale;
249246
}
247+
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
248+
&op_data->scratch_input_iter_idx);
249+
context->RequestScratchBufferInArena(context, sizeof(int) * op_data->num_axis,
250+
&op_data->scratch_resolved_axis_idx);
250251

251252
TF_LITE_ENSURE_OK(
252253
context,
@@ -274,12 +275,11 @@ TfLiteStatus PrepareAllHelper(TfLiteContext* context, TfLiteNode* node,
274275
op_data->output_zp = output->params.zero_point;
275276
op_data->output_scale = output->params.scale;
276277
op_data->num_output_elements = NumElements(output);
277-
278278
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
279-
&op_data->temp_buffer_idx);
279+
&op_data->scratch_input_iter_idx);
280280
context->RequestScratchBufferInArena(
281281
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
282-
&op_data->resolved_axis_idx);
282+
&op_data->scratch_resolved_axis_idx);
283283

284284
micro_context->DeallocateTempTfLiteTensor(input);
285285
micro_context->DeallocateTempTfLiteTensor(output);
@@ -296,8 +296,10 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
296296
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
297297

298298
int num_axis = static_cast<int>(ElementCount(*axis->dims));
299-
int temp_index[kMaxNumberOfAxis];
300-
int resolved_axis[kMaxNumberOfReducedAxis];
299+
int* input_iter = static_cast<int*>(
300+
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
301+
int* resolved_axis = static_cast<int*>(
302+
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
301303

302304
switch (input->type) {
303305
case kTfLiteFloat32: {
@@ -326,19 +328,19 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
326328
input->dims->size, tflite::micro::GetTensorData<float>(output),
327329
output->dims->data, output->dims->size,
328330
tflite::micro::GetTensorData<int>(axis), num_axis,
329-
params->keep_dims, temp_index, resolved_axis,
331+
params->keep_dims, input_iter, resolved_axis,
330332
tflite::micro::GetTensorData<float>(output)));
331333
}
332334
} break;
333335
case kTfLiteInt8: {
334336
TF_LITE_ENSURE_OK(
335337
context, EvalIntegerMean<int8_t>(context, node, num_axis, op_data,
336-
temp_index, resolved_axis));
338+
input_iter, resolved_axis));
337339
} break;
338340
case kTfLiteInt16: {
339341
TF_LITE_ENSURE_OK(
340342
context, EvalIntegerMean<int16_t>(context, node, num_axis, op_data,
341-
temp_index, resolved_axis));
343+
input_iter, resolved_axis));
342344
} break;
343345
default:
344346
TF_LITE_ENSURE_MSG(context, false,
@@ -369,8 +371,10 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
369371

370372
// Interpret an axis tensor with null dimensions as a scalar.
371373
int num_axis = static_cast<int>(ElementCount(*axis->dims));
372-
int temp_index[kMaxNumberOfAxis];
373-
int resolved_axis[kMaxNumberOfReducedAxis];
374+
int* input_iter = static_cast<int*>(
375+
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
376+
int* resolved_axis = static_cast<int*>(
377+
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
374378

375379
switch (input->type) {
376380
case kTfLiteFloat32: {
@@ -381,21 +385,21 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
381385
input->dims->size, tflite::micro::GetTensorData<float>(output),
382386
output->dims->data, output->dims->size,
383387
tflite::micro::GetTensorData<int>(axis), num_axis,
384-
params->keep_dims, temp_index, resolved_axis, /*init_value=*/0.f,
388+
params->keep_dims, input_iter, resolved_axis, /*init_value=*/0.f,
385389
[](const float current, const float in) -> float {
386390
return in + current;
387391
}));
388392
} break;
389393
case kTfLiteInt8: {
390394
int32_t* temp_sum = static_cast<int32_t*>(
391-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
392-
QuantizedMeanOrSum<int8_t>(context, node, temp_index, resolved_axis,
395+
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));
396+
QuantizedMeanOrSum<int8_t>(context, node, input_iter, resolved_axis,
393397
temp_sum, op_data, /*compute_sum=*/true);
394398
} break;
395399
case kTfLiteInt16: {
396400
int32_t* temp_sum = static_cast<int32_t*>(
397-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
398-
QuantizedMeanOrSum<int16_t>(context, node, temp_index, resolved_axis,
401+
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));
402+
QuantizedMeanOrSum<int16_t>(context, node, input_iter, resolved_axis,
399403
temp_sum, op_data, /*compute_sum=*/true);
400404
} break;
401405
default:
@@ -416,10 +420,10 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,
416420

417421
// Interpret an axis tensor with null dimensions as a scalar
418422
int num_axis = static_cast<int>(ElementCount(*axis->dims));
419-
int* temp_buffer = static_cast<int*>(
420-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
423+
int* input_iter = static_cast<int*>(
424+
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
421425
int* resolved_axis = static_cast<int*>(
422-
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
426+
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
423427
switch (input->type) {
424428
case kTfLiteBool:
425429
TF_LITE_ENSURE(
@@ -429,7 +433,7 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,
429433
input->dims->size, tflite::micro::GetTensorData<bool>(output),
430434
output->dims->data, output->dims->size,
431435
tflite::micro::GetTensorData<int>(axis), num_axis,
432-
params->keep_dims, temp_buffer, resolved_axis, true,
436+
params->keep_dims, input_iter, resolved_axis, true,
433437
[](const bool current, const bool in) -> bool {
434438
return in && current;
435439
}));

0 commit comments

Comments
 (0)