Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,21 @@ def test_Cmpi_from_mnemonic(self, input: str):


def test_constant_construction():
c1 = ConstantOp(IntegerAttr(1, i32))
attr1 = IntegerAttr(1, i32)
c1 = ConstantOp(attr1)
assert c1.value.type == i32
constantlike1 = c1.get_trait(ConstantLike)
assert constantlike1 is not None
assert constantlike1.get_constant_value(c1) == IntegerAttr(1, i32)
assert ConstantLike.get_constant_value(c1.result) == attr1

c3 = ConstantOp(FloatAttr(1.0, f32))
assert c3.value.type == f32
constantlike3 = c3.get_trait(ConstantLike)
assert constantlike3 is not None
assert constantlike3.get_constant_value(c3) == FloatAttr(1.0, f32)
attr2 = FloatAttr(1.0, f32)
c2 = ConstantOp(attr2)
assert c2.value.type == f32
assert ConstantLike.get_constant_value(c2.result) == attr2

value_type = TensorType(i32, [2, 2])
c5 = ConstantOp(DenseIntOrFPElementsAttr.from_list(value_type, [1, 2, 3, 4]))
assert c5.value.type == value_type
constantlike5 = c5.get_trait(ConstantLike)
assert constantlike5 is not None
assert constantlike5.get_constant_value(c5) == DenseIntOrFPElementsAttr.from_list(
value_type, [1, 2, 3, 4]
)
attr3 = DenseIntOrFPElementsAttr.from_list(value_type, [1, 2, 3, 4])
c3 = ConstantOp(attr3)
assert c3.value.type == value_type
assert ConstantLike.get_constant_value(c3.result) == attr3


@pytest.mark.parametrize(
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@


def test_constant_construction():
value = ArrayAttr([IntAttr(42), IntAttr(43)])
c1 = complex.ConstantOp(
value=ArrayAttr([IntAttr(42), IntAttr(43)]),
value=value,
result_type=complex.ComplexType(i32),
)
constantlike = c1.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(c1) == ArrayAttr([IntAttr(42), IntAttr(43)])

assert ConstantLike.get_constant_value(c1.complex) == value


class Test_constant_op_helper_constructors:
Expand Down
9 changes: 4 additions & 5 deletions tests/dialects/test_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@


def test_const_class_construction():
constant_op = arith.ConstantOp(IntegerAttr.from_int_and_width(42, 64))
value = IntegerAttr.from_int_and_width(42, 64)
constant_op = arith.ConstantOp(value)

const_class = equivalence.ConstantClassOp(constant_op.result)
trait = const_class.get_trait(ConstantLike)
assert trait is not None
assert trait.get_constant_value(const_class) == IntegerAttr.from_int_and_width(
42, 64
)
assert ConstantLike.get_constant_value(constant_op.result) == value

non_constant_op = test.TestOp(result_types=(i32,))
with pytest.raises(
DiagnosticException,
match="The argument of a ConstantClass must be a constant-like operation.",
match="The argument of a ConstantClass must be a `ConstantLike` operation implementing `HasFolderInterface`.",
):
equivalence.ConstantClassOp(non_constant_op.results[0])
27 changes: 10 additions & 17 deletions tests/dialects/test_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,18 @@ def test_asm_section():

def test_get_constant_value():
# Test 32-bit LiOp
li_op = rv32.LiOp(1)
li_val = get_constant_value(li_op.rd)
assert li_val == IntegerAttr.from_int_and_width(1, 32)
# LiOp implements ConstantLikeInterface so it also has a get_constant_value method:
constantlike = li_op.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(li_op) == IntegerAttr.from_int_and_width(
1, 32
)
one_32 = IntegerAttr.from_int_and_width(1, 32)
li_op_32 = rv32.LiOp(1)
li_val_32 = get_constant_value(li_op_32.rd)
assert li_val_32 == one_32
assert ConstantLike.get_constant_value(li_op_32.rd) == one_32

# Test 64-bit LiOp
li_op_64 = rv64.LiOp(1)
li_val_64 = get_constant_value(li_op_64.rd)
assert li_val_64 == IntegerAttr.from_int_and_width(1, 64)
constantlike = li_op_64.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(li_op_64) == IntegerAttr.from_int_and_width(
1, 64
)
one_rv64 = IntegerAttr.from_int_and_width(1, 64)
li_op_rv64 = rv64.LiOp(1)
li_val_rv64 = get_constant_value(li_op_rv64.rd)
assert li_val_rv64 == one_rv64
assert ConstantLike.get_constant_value(li_op_rv64.rd) == one_rv64

zero_op = riscv.GetRegisterOp(riscv.Registers.ZERO)
zero_val = get_constant_value(zero_op.res)
Expand Down
20 changes: 5 additions & 15 deletions tests/dialects/test_smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,12 @@ def test_constant_bool():
op = ConstantBoolOp(True)
assert op.value is True
assert op.value_attr == IntegerAttr(-1, 1)
constantlike = op.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(op) == IntegerAttr(-1, 1)
assert ConstantLike.get_constant_value(op.result) == IntegerAttr(-1, 1)

op = ConstantBoolOp(False)
assert op.value is False
assert op.value_attr == IntegerAttr(0, 1)
constantlike = op.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(op) == IntegerAttr(0, 1)
assert ConstantLike.get_constant_value(op.result) == IntegerAttr(0, 1)


def test_bv_type():
Expand Down Expand Up @@ -189,23 +185,17 @@ def test_bv_constant_op():
op = BvConstantOp(bv_attr)
assert op.value == bv_attr
assert op.result.type == BitVectorType(32)
constantlike = op.get_trait(ConstantLike)
assert constantlike is not None
assert constantlike.get_constant_value(op) == bv_attr
assert ConstantLike.get_constant_value(op.result) == bv_attr

op2 = BvConstantOp(42, 32)
assert op2.value == bv_attr
assert op2.result.type == BitVectorType(32)
constantlike2 = op2.get_trait(ConstantLike)
assert constantlike2 is not None
assert constantlike2.get_constant_value(op2) == bv_attr
assert ConstantLike.get_constant_value(op2.result) == bv_attr

op3 = BvConstantOp(42, BitVectorType(32))
assert op3.value == bv_attr
assert op3.result.type == BitVectorType(32)
constantlike3 = op3.get_trait(ConstantLike)
assert constantlike3 is not None
assert constantlike3.get_constant_value(op3) == bv_attr
assert ConstantLike.get_constant_value(op3.result) == bv_attr


@pytest.mark.parametrize("op_type", [BVNotOp, BVNegOp])
Expand Down
15 changes: 8 additions & 7 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
VectorType,
)
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.interfaces import ConstantLikeInterface, HasFolderInterface
from xdsl.interfaces import HasFolderInterface
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand Down Expand Up @@ -58,6 +58,7 @@
from xdsl.traits import (
Commutative,
ConditionallySpeculatable,
ConstantLike,
HasCanonicalizationPatternsTrait,
NoMemoryEffect,
Pure,
Expand Down Expand Up @@ -137,7 +138,7 @@ def __init__(self, flags: None | Sequence[IntegerOverflowFlag] | Literal["none"]


@irdl_op_definition
class ConstantOp(IRDLOperation, ConstantLikeInterface):
class ConstantOp(IRDLOperation, HasFolderInterface):
name = "arith.constant"
_T: ClassVar = VarConstraint("T", AnyAttr())
result = result_def(_T)
Expand All @@ -148,7 +149,7 @@ class ConstantOp(IRDLOperation, ConstantLikeInterface):
| ParamAttrConstraint(DenseResourceAttr, (AnyAttr(), _T))
)

traits = traits_def(Pure())
traits = traits_def(Pure(), ConstantLike())

assembly_format = "attr-dict $value"

Expand Down Expand Up @@ -180,8 +181,8 @@ def from_int_and_width(
},
)

def get_constant_value(self) -> Attribute:
return self.value
def fold(self) -> Sequence[SSAValue | Attribute] | None:
return (self.value,)


class SignlessIntegerBinaryOperation(IRDLOperation, HasFolderInterface, abc.ABC):
Expand Down Expand Up @@ -228,8 +229,8 @@ def is_right_unit(attr: IntegerAttr) -> bool:
return False

def fold(self):
lhs = self.get_constant(self.lhs)
rhs = self.get_constant(self.rhs)
lhs = ConstantLike.get_constant_value(self.lhs)
rhs = ConstantLike.get_constant_value(self.rhs)
if lhs is not None and rhs is not None:
if isa(lhs, IntegerAttr) and isa(rhs, IntegerAttr):
assert lhs.type == rhs.type
Expand Down
12 changes: 6 additions & 6 deletions xdsl/dialects/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
IntegerAttr,
IntegerType,
)
from xdsl.interfaces import ConstantLikeInterface
from xdsl.interfaces import HasFolderInterface
from xdsl.ir import (
Attribute,
Dialect,
Expand Down Expand Up @@ -47,7 +47,7 @@
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import Pure
from xdsl.traits import ConstantLike, Pure
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -292,7 +292,7 @@ class ConjOp(ComplexUnaryComplexResultOperation):


@irdl_op_definition
class ConstantOp(IRDLOperation, ConstantLikeInterface):
class ConstantOp(IRDLOperation, HasFolderInterface):
name = "complex.constant"
T: ClassVar = VarConstraint("T", AnyFloatConstr | base(IntegerType))
value = prop_def(
Expand All @@ -312,15 +312,15 @@ class ConstantOp(IRDLOperation, ConstantLikeInterface):
# have any complex result type, not just floating point:
complex = result_def(ComplexType.constr(T))

traits = traits_def(Pure())
traits = traits_def(Pure(), ConstantLike())

assembly_format = "$value attr-dict `:` type($complex)"

def __init__(self, value: ArrayAttr, result_type: ComplexType):
super().__init__(properties={"value": value}, result_types=[result_type])

def get_constant_value(self) -> Attribute:
return self.value
def fold(self) -> tuple[Attribute]:
return (self.value,)

@staticmethod
def from_floats(value: tuple[float, float], type: AnyFloat) -> ConstantOp:
Expand Down
16 changes: 8 additions & 8 deletions xdsl/dialects/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import ClassVar

from xdsl.dialects.builtin import IntAttr
from xdsl.interfaces import ConstantLikeInterface
from xdsl.interfaces import HasFolderInterface
from xdsl.ir import Attribute, Dialect, OpResult, Region, SSAValue
from xdsl.irdl import (
AnyAttr,
Expand Down Expand Up @@ -43,7 +43,7 @@


@irdl_op_definition
class ConstantClassOp(IRDLOperation, ConstantLikeInterface):
class ConstantClassOp(IRDLOperation, HasFolderInterface):
"""An e-class representing a known constant value.
For non-constant e-classes, use [ClassOp][xdsl.dialects.equivalence.ClassOp].
"""
Expand All @@ -55,22 +55,22 @@ class ConstantClassOp(IRDLOperation, ConstantLikeInterface):
assembly_format = (
"$arguments ` ` `(` `constant` `=` $value `)` attr-dict `:` type($result)"
)
traits = traits_def(Pure())
traits = traits_def(Pure(), ConstantLike())

arguments = var_operand_def(T)
result = result_def(T)
value = prop_def()
min_cost_index = opt_attr_def(IntAttr)

def get_constant_value(self):
return self.value
def fold(self) -> tuple[Attribute]:
return (self.value,)

def __init__(self, const_arg: OpResult):
if (trait := const_arg.owner.get_trait(ConstantLike)) is None:
if (value := ConstantLike.get_constant_value(const_arg)) is None:
raise DiagnosticException(
"The argument of a ConstantClass must be a constant-like operation."
"The argument of a ConstantClass must be a `ConstantLike` operation implementing `HasFolderInterface`."
)
value = trait.get_constant_value(const_arg.owner)

super().__init__(
operands=[const_arg],
result_types=[const_arg.type],
Expand Down
11 changes: 6 additions & 5 deletions xdsl/dialects/rv32.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
print_immediate_value,
)
from xdsl.dialects.riscv.ops import LiOpHasCanonicalizationPatternTrait
from xdsl.interfaces import ConstantLikeInterface
from xdsl.interfaces import HasFolderInterface
from xdsl.ir import (
Attribute,
Dialect,
Expand All @@ -37,12 +37,13 @@
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import (
ConstantLike,
Pure,
)


@irdl_op_definition
class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, ABC):
class LiOp(RISCVCustomFormatOperation, RISCVInstruction, HasFolderInterface, ABC):
"""
Loads a 32-bit immediate into rd.

Expand All @@ -56,7 +57,7 @@ class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface,
rd = result_def(IntRegisterType)
immediate = attr_def(IntegerAttr[I32] | LabelAttr)

traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait())
traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait(), ConstantLike())

def __init__(
self,
Expand All @@ -83,8 +84,8 @@ def __init__(
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.immediate

def get_constant_value(self):
return self.immediate
def fold(self) -> tuple[IntegerAttr[I32] | LabelAttr]:
return (self.immediate,)

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
Expand Down
11 changes: 6 additions & 5 deletions xdsl/dialects/rv64.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
print_immediate_value,
)
from xdsl.dialects.riscv.ops import LiOpHasCanonicalizationPatternTrait
from xdsl.interfaces import ConstantLikeInterface
from xdsl.interfaces import HasFolderInterface
from xdsl.ir import (
Attribute,
Dialect,
Expand All @@ -37,12 +37,13 @@
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import (
ConstantLike,
Pure,
)


@irdl_op_definition
class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, ABC):
class LiOp(RISCVCustomFormatOperation, RISCVInstruction, HasFolderInterface, ABC):
"""
Loads a 64-bit immediate into rd.

Expand All @@ -56,7 +57,7 @@ class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface,
rd = result_def(IntRegisterType)
immediate = attr_def(IntegerAttr[I64] | LabelAttr)

traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait())
traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait(), ConstantLike())

def __init__(
self,
Expand All @@ -83,8 +84,8 @@ def __init__(
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.immediate

def get_constant_value(self):
return self.immediate
def fold(self) -> tuple[IntegerAttr[I64] | LabelAttr]:
return (self.immediate,)

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
Expand Down
Loading