Skip to content

Commit 6860386

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Remove duplicate lowering rule for vector.BroadcastOp.
PiperOrigin-RevId: 845253825
1 parent b8747f9 commit 6860386

File tree

1 file changed

+4
-20
lines changed

1 file changed

+4
-20
lines changed

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -606,32 +606,16 @@ def _broadcasted_iota_op_lowering_rule(
606606
return [fragmented_array_to_ir(a, result_type)]
607607

608608

609-
@_register_lowering(vector.BroadcastOp)
610-
def _vector_splat_op_lowering_rule(
611-
_: LoweringContext, vector_splat_op: vector.BroadcastOp
612-
) -> Sequence[ir.Value]:
613-
614-
out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type)
615-
fragmented_array = fa.FragmentedArray.splat(
616-
vector_splat_op.input,
617-
tuple(out_vec_ty.shape),
618-
layouts.from_layout_attr(vector_splat_op.attributes["out_layouts"][0]),
619-
is_signed=_default_is_signed(out_vec_ty.element_type),
620-
)
621-
return [fragmented_array_to_ir(fragmented_array, out_vec_ty)]
622-
623-
624609
@_register_lowering(vector.BroadcastOp)
625610
def _vector_broadcast_op_lowering_rule(
626-
_: LoweringContext, vector_broadcast_op: vector.BroadcastOp
611+
_: LoweringContext, op: vector.BroadcastOp
627612
) -> Sequence[ir.Value]:
628-
629-
out_vec_ty = ir.VectorType(vector_broadcast_op.vector.type)
613+
out_vec_ty = ir.VectorType(op.vector.type)
630614
fragmented_array = fa.FragmentedArray.splat(
631-
vector_broadcast_op.source,
615+
op.source,
632616
tuple(out_vec_ty.shape),
633617
layouts.from_layout_attr(
634-
vector_broadcast_op.attributes["out_layouts"][0]
618+
op.attributes["out_layouts"][0]
635619
),
636620
is_signed=_default_is_signed(out_vec_ty.element_type),
637621
)

0 commit comments

Comments
 (0)