Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions tests/filecheck/backend/riscv/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ builtin.module {
%c1 = rv32.li 1 : !riscv.reg
%c2 = rv32.li 2 : !riscv.reg
%c3 = rv32.li 3 : !riscv.reg
%c_neg1 = rv32.li -1 : !riscv.reg

// Don't optimise out unused immediates
"test.op"(%zero, %zero_rv64, %c0, %c1, %c2, %c3) : (!riscv.reg<zero>, !riscv.reg<zero>, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg) -> ()
"test.op"(%zero, %zero_rv64, %c0, %c1, %c2, %c3, %c_neg1) : (!riscv.reg<zero>, !riscv.reg<zero>, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg) -> ()

%load_zero_zero = rv32.li 0 : !riscv.reg<zero>
"test.op"(%load_zero_zero) : (!riscv.reg<zero>) -> ()
Expand Down Expand Up @@ -104,6 +105,16 @@ builtin.module {
%shift_right_immediate = riscv.srli %shift_left_immediate, 3 : (!riscv.reg<a0>) -> !riscv.reg<a0>
"test.op"(%shift_right_immediate) : (!riscv.reg<a0>) -> ()

// Check shifts with signed numbers
%srli_neg = riscv.srli %c_neg1, 1 : (!riscv.reg) -> !riscv.reg<a0>
"test.op"(%srli_neg) : (!riscv.reg<a0>) -> ()

%srai_neg = riscv.srai %c_neg1, 1 : (!riscv.reg) -> !riscv.reg<a0>
"test.op"(%srai_neg) : (!riscv.reg<a0>) -> ()

%slli_neg = riscv.slli %c_neg1, 1 : (!riscv.reg) -> !riscv.reg<a0>
"test.op"(%slli_neg) : (!riscv.reg<a0>) -> ()

%load_float_ptr = riscv.addi %i2, 8 : (!riscv.reg) -> !riscv.reg
%load_float_known_offset = riscv.flw %load_float_ptr, 4 : (!riscv.reg) -> !riscv.freg<fa0>
"test.op"(%load_float_known_offset) : (!riscv.freg<fa0>) -> ()
Expand Down Expand Up @@ -203,7 +214,8 @@ builtin.module {
// CHECK-NEXT: %c1 = rv32.li 1 : !riscv.reg
// CHECK-NEXT: %c2 = rv32.li 2 : !riscv.reg
// CHECK-NEXT: %c3 = rv32.li 3 : !riscv.reg
// CHECK-NEXT: "test.op"(%zero, %zero_rv64, %c0_1, %c1, %c2, %c3) : (!riscv.reg<zero>, !riscv.reg<zero>, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg) -> ()
// CHECK-NEXT: %c_neg1 = rv32.li -1 : !riscv.reg
// CHECK-NEXT: "test.op"(%zero, %zero_rv64, %c0_1, %c1, %c2, %c3, %c_neg1) : (!riscv.reg<zero>, !riscv.reg<zero>, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg, !riscv.reg) -> ()

// CHECK-NEXT: %load_zero_zero = rv32.get_register : !riscv.reg<zero>
// CHECK-NEXT: "test.op"(%load_zero_zero) : (!riscv.reg<zero>) -> ()
Expand Down Expand Up @@ -279,6 +291,15 @@ builtin.module {
// CHECK-NEXT: %shift_right_immediate = rv32.li 4 : !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%shift_right_immediate) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %srli_neg = rv32.li 2147483647 : !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%srli_neg) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %srai_neg = rv32.li -1 : !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%srai_neg) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %slli_neg = rv32.li -2 : !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%slli_neg) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %load_float_known_offset = riscv.flw %i2, 12 : (!riscv.reg) -> !riscv.freg<fa0>
// CHECK-NEXT: "test.op"(%load_float_known_offset) : (!riscv.freg<fa0>) -> ()

Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/riscv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from .abstract_ops import (
AssemblyInstructionArg as AssemblyInstructionArg,
)
from .abstract_ops import (
RdRsImmShiftOperation as RdRsImmShiftOperation,
)
from .abstract_ops import (
RdRsRsFloatOperationWithFastMath as RdRsRsFloatOperationWithFastMath,
)
Expand Down
18 changes: 17 additions & 1 deletion xdsl/dialects/riscv/abstract_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from xdsl.backend.register_type import RegisterAllocatedMemoryEffect, RegisterType
from xdsl.dialects.builtin import (
I32,
IntegerAttr,
IntegerType,
ModuleOp,
Expand Down Expand Up @@ -726,9 +727,10 @@ class ImmShiftOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrai
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.riscv import (
ShiftbyZero,
ShiftConstantFolding,
)

return (ShiftbyZero(),)
return (ShiftbyZero(), ShiftConstantFolding())


class RdRsImmShiftOperation(RISCVInstruction, ABC):
Expand Down Expand Up @@ -779,6 +781,20 @@ def __init__(
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.rs1, self.immediate

@abstractmethod
def py_operation(self, rs1: IntegerAttr[I32]) -> IntegerAttr[I32]:
"""
Performs a python function corresponding to this operation.

If `i := py_operation(rs1)` is an IntegerAttr[I32], then this operation can be
canonicalized to a constant with value `i` when the inputs are constants
with values `rs1`. The immediate value is retrieved from the `immediate` attribute of the operation.
"""

raise NotImplementedError(
"RdRsImmShiftOperation py_operation is not yet implemented"
)


class RdRsImmBitManipOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC):
"""
Expand Down
31 changes: 9 additions & 22 deletions xdsl/dialects/riscv/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,6 @@ class XoriOp(RdRsImmIntegerOperation):
traits = traits_def(XoriOpHasCanonicalizationPatternsTrait())


class SlliOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.riscv import (
ShiftLeftImmediate,
)

return (ShiftLeftImmediate(),)


@irdl_op_definition
class SlliOp(RdRsImmShiftOperation):
"""
Expand All @@ -258,17 +248,8 @@ class SlliOp(RdRsImmShiftOperation):

name = "riscv.slli"

traits = traits_def(SlliOpHasCanonicalizationPatternsTrait())


class SrliOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.riscv import (
ShiftRightImmediate,
)

return (ShiftRightImmediate(),)
def py_operation(self, rs1: IntegerAttr[I32]) -> IntegerAttr[I32]:
return IntegerAttr(rs1.value.data << self.immediate.value.data, i32)


@irdl_op_definition
Expand All @@ -284,7 +265,10 @@ class SrliOp(RdRsImmShiftOperation):

name = "riscv.srli"

traits = traits_def(SrliOpHasCanonicalizationPatternsTrait())
def py_operation(self, rs1: IntegerAttr[I32]) -> IntegerAttr[I32]:
return IntegerAttr(
(rs1.value.data % 0x100000000) >> self.immediate.value.data, i32
)


@irdl_op_definition
Expand All @@ -300,6 +284,9 @@ class SraiOp(RdRsImmShiftOperation):

name = "riscv.srai"

def py_operation(self, rs1: IntegerAttr[I32]) -> IntegerAttr[I32]:
return IntegerAttr(rs1.value.data >> self.immediate.value.data, i32)


@irdl_op_definition
class AddiwOp(RdRsImmIntegerOperation):
Expand Down
40 changes: 17 additions & 23 deletions xdsl/transforms/canonicalization_patterns/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,42 +296,36 @@ def match_and_rewrite(self, op: riscv.XoriOp, rewriter: PatternRewriter) -> None
)


class ShiftLeftImmediate(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv.SlliOp, rewriter: PatternRewriter) -> None:
if (rs1 := get_constant_value(op.rs1)) is not None:
rd = op.rd.type
rewriter.replace_op(
op,
rv32.LiOp(rs1.value.data << op.immediate.value.data, rd=rd),
)


class ShiftRightImmediate(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv.SrliOp, rewriter: PatternRewriter) -> None:
if (rs1 := get_constant_value(op.rs1)) is not None:
rd = op.rd.type
rewriter.replace_op(
op,
rv32.LiOp(rs1.value.data >> op.immediate.value.data, rd=rd),
)


class ShiftbyZero(RewritePattern):
"""
shift(x, 0) -> x
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: riscv.SlliOp | riscv.SrliOp | riscv.SraiOp, rewriter: PatternRewriter
self, op: riscv.RdRsImmShiftOperation, rewriter: PatternRewriter
) -> None:
# check if the shift amount is zero
if op.immediate.value.data == 0:
rewriter.replace_op(op, riscv.MVOp(op.rs1, rd=op.rd.type))


class ShiftConstantFolding(RewritePattern):
"""
shift(c1, c2) -> c3
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: riscv.RdRsImmShiftOperation, rewriter: PatternRewriter
) -> None:
if (rs1 := get_constant_value(op.rs1)) is not None:
rd = op.rd.type
val = cast(IntegerAttr[I32], rs1)
result = op.py_operation(val)
rewriter.replace_op(op, rv32.LiOp(result, rd=rd))


class LoadWordWithKnownOffset(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv.LwOp, rewriter: PatternRewriter) -> None:
Expand Down