Commit 0c1ef1a

[Intel HPU] add the original kernel name rms_norm for compatibility with Paddle 3.2.2 (#2333)
1 parent: 2a5cb18

1 file changed: backends/intel_hpu/kernels/rms_norm_kernel.cc (43 additions, 1 deletion)
@@ -190,12 +190,54 @@ void RmsNormKernel(const Context& dev_ctx,
   runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
 }
 
+template <typename T, typename Context>
+void FusedRmsNormQuantKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& x,
+    const paddle::optional<phi::DenseTensor>& bias,
+    const paddle::optional<phi::DenseTensor>& residual,
+    const phi::DenseTensor& norm_weight,
+    const paddle::optional<phi::DenseTensor>& norm_bias,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    phi::DenseTensor* out,
+    phi::DenseTensor* residual_out,
+    phi::DenseTensor* inv_var) {
+  custom_kernel::RmsNormKernel<T, Context>(dev_ctx,
+                                           x,
+                                           bias,
+                                           residual,
+                                           norm_weight,
+                                           norm_bias,
+                                           epsilon,
+                                           begin_norm_axis,
+                                           quant_scale,
+                                           quant_round_type,
+                                           quant_max_bound,
+                                           quant_min_bound,
+                                           out,
+                                           residual_out,
+                                           inv_var);
+}
 } // namespace custom_kernel
 
-PD_REGISTER_PLUGIN_KERNEL(fused_rms_norm_quant,
+// Add the original kernel name rms_norm for compatibility with Paddle 3.2.2
+PD_REGISTER_PLUGIN_KERNEL(rms_norm,
                           intel_hpu,
                           ALL_LAYOUT,
                           custom_kernel::RmsNormKernel,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
+
+PD_REGISTER_PLUGIN_KERNEL(fused_rms_norm_quant,
+                          intel_hpu,
+                          ALL_LAYOUT,
+                          custom_kernel::FusedRmsNormQuantKernel,
+                          float,
+                          phi::dtype::float16,
+                          phi::dtype::bfloat16) {}
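
For context, the kernel registered under both names implements RMS normalization: each row along the normalized (trailing) axes is scaled by the reciprocal of the root mean square of its elements (plus epsilon) and then multiplied by the per-channel norm weight. The sketch below is a plain CPU reference of that math only, with a hypothetical helper name; it is not the intel_hpu kernel (which dispatches to a device runner, see runner.Run above) and it ignores the optional bias/residual inputs and the quantization parameters, which FusedRmsNormQuantKernel simply forwards to RmsNormKernel.

// Hypothetical reference implementation of the rms_norm math (CPU, float only),
// shown for illustration; names here are not part of the commit.
#include <cmath>
#include <cstddef>
#include <vector>

void RmsNormReference(const std::vector<float>& x,       // rows * cols, row-major
                      const std::vector<float>& weight,  // cols per-channel scales
                      std::size_t rows,
                      std::size_t cols,
                      float epsilon,
                      std::vector<float>* out) {
  out->resize(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    // Mean of squares over the normalized axis.
    float sum_sq = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) {
      const float v = x[r * cols + c];
      sum_sq += v * v;
    }
    const float inv_rms =
        1.0f / std::sqrt(sum_sq / static_cast<float>(cols) + epsilon);
    // y = x * inv_rms * weight (per-channel scale).
    for (std::size_t c = 0; c < cols; ++c) {
      (*out)[r * cols + c] = x[r * cols + c] * inv_rms * weight[c];
    }
  }
}

Registering the same implementation under both names keeps Paddle 3.2.2, which dispatches the op as rms_norm, and callers of fused_rms_norm_quant routed to the same underlying RmsNormKernel on the intel_hpu backend.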
