@@ -606,32 +606,16 @@ def _broadcasted_iota_op_lowering_rule(
606606 return [fragmented_array_to_ir (a , result_type )]
607607
608608
609- @_register_lowering (vector .BroadcastOp )
610- def _vector_splat_op_lowering_rule (
611- _ : LoweringContext , vector_splat_op : vector .BroadcastOp
612- ) -> Sequence [ir .Value ]:
613-
614- out_vec_ty = ir .VectorType (vector_splat_op .aggregate .type )
615- fragmented_array = fa .FragmentedArray .splat (
616- vector_splat_op .input ,
617- tuple (out_vec_ty .shape ),
618- layouts .from_layout_attr (vector_splat_op .attributes ["out_layouts" ][0 ]),
619- is_signed = _default_is_signed (out_vec_ty .element_type ),
620- )
621- return [fragmented_array_to_ir (fragmented_array , out_vec_ty )]
622-
623-
624609@_register_lowering (vector .BroadcastOp )
625610def _vector_broadcast_op_lowering_rule (
626- _ : LoweringContext , vector_broadcast_op : vector .BroadcastOp
611+ _ : LoweringContext , op : vector .BroadcastOp
627612) -> Sequence [ir .Value ]:
628-
629- out_vec_ty = ir .VectorType (vector_broadcast_op .vector .type )
613+ out_vec_ty = ir .VectorType (op .vector .type )
630614 fragmented_array = fa .FragmentedArray .splat (
631- vector_broadcast_op .source ,
615+ op .source ,
632616 tuple (out_vec_ty .shape ),
633617 layouts .from_layout_attr (
634- vector_broadcast_op .attributes ["out_layouts" ][0 ]
618+ op .attributes ["out_layouts" ][0 ]
635619 ),
636620 is_signed = _default_is_signed (out_vec_ty .element_type ),
637621 )
0 commit comments