@@ -190,12 +190,54 @@ void RmsNormKernel(const Context& dev_ctx,
   runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
 }
 
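+// Thin wrapper that forwards the fused_rms_norm_quant op signature to the
+// existing RmsNormKernel implementation.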
+template <typename T, typename Context>
+void FusedRmsNormQuantKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& x,
+    const paddle::optional<phi::DenseTensor>& bias,
+    const paddle::optional<phi::DenseTensor>& residual,
+    const phi::DenseTensor& norm_weight,
+    const paddle::optional<phi::DenseTensor>& norm_bias,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    phi::DenseTensor* out,
+    phi::DenseTensor* residual_out,
+    phi::DenseTensor* inv_var) {
+  custom_kernel::RmsNormKernel<T, Context>(dev_ctx,
+                                           x,
+                                           bias,
+                                           residual,
+                                           norm_weight,
+                                           norm_bias,
+                                           epsilon,
+                                           begin_norm_axis,
+                                           quant_scale,
+                                           quant_round_type,
+                                           quant_max_bound,
+                                           quant_min_bound,
+                                           out,
+                                           residual_out,
+                                           inv_var);
+}
 }  // namespace custom_kernel
 
-PD_REGISTER_PLUGIN_KERNEL(fused_rms_norm_quant,
+// Register under the original kernel name rms_norm for compatibility with
+// Paddle 3.2.2.
+PD_REGISTER_PLUGIN_KERNEL(rms_norm,
                           intel_hpu,
                           ALL_LAYOUT,
                           custom_kernel::RmsNormKernel,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
+
+PD_REGISTER_PLUGIN_KERNEL(fused_rms_norm_quant,
+                          intel_hpu,
+                          ALL_LAYOUT,
+                          custom_kernel::FusedRmsNormQuantKernel,
+                          float,
+                          phi::dtype::float16,
+                          phi::dtype::bfloat16) {}
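
For context, here is a minimal self-contained C++ sketch of the pattern this change relies on: one kernel implementation exposed under two registered op names through a thin forwarding wrapper. The toy registry and all names below are illustrative stand-ins, not Paddle's actual PD_REGISTER_PLUGIN_KERNEL machinery.

// Illustrative analogue only: a toy op-name -> kernel registry stands in for
// PD_REGISTER_PLUGIN_KERNEL; the function names and float parameter are
// hypothetical.
#include <functional>
#include <iostream>
#include <map>
#include <string>

// The "real" kernel, analogous to custom_kernel::RmsNormKernel.
void RmsNormImpl(float epsilon) {
  std::cout << "rms_norm impl, epsilon=" << epsilon << "\n";
}

// Thin wrapper, analogous to FusedRmsNormQuantKernel: no extra logic, it
// only gives the same implementation a second entry point.
void FusedRmsNormQuantWrapper(float epsilon) { RmsNormImpl(epsilon); }

// Toy registry mapping op names to kernels.
std::map<std::string, std::function<void(float)>>& Registry() {
  static std::map<std::string, std::function<void(float)>> registry;
  return registry;
}

int main() {
  // Both op names resolve to the same underlying implementation, mirroring
  // what the diff above does for rms_norm / fused_rms_norm_quant.
  Registry()["rms_norm"] = RmsNormImpl;
  Registry()["fused_rms_norm_quant"] = FusedRmsNormQuantWrapper;

  Registry()["rms_norm"](1e-6f);
  Registry()["fused_rms_norm_quant"](1e-6f);
  return 0;
}

Keeping the wrapper trivial means both registered names stay behaviorally identical, while leaving a single place to change if a dedicated fused-quant implementation is added later.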