diff --git a/tests/dialects/test_arith.py b/tests/dialects/test_arith.py index 49bb494b71..3f978a2764 100644 --- a/tests/dialects/test_arith.py +++ b/tests/dialects/test_arith.py @@ -130,26 +130,21 @@ def test_Cmpi_from_mnemonic(self, input: str): def test_constant_construction(): - c1 = ConstantOp(IntegerAttr(1, i32)) + attr1 = IntegerAttr(1, i32) + c1 = ConstantOp(attr1) assert c1.value.type == i32 - constantlike1 = c1.get_trait(ConstantLike) - assert constantlike1 is not None - assert constantlike1.get_constant_value(c1) == IntegerAttr(1, i32) + assert ConstantLike.get_constant_value(c1.result) == attr1 - c3 = ConstantOp(FloatAttr(1.0, f32)) - assert c3.value.type == f32 - constantlike3 = c3.get_trait(ConstantLike) - assert constantlike3 is not None - assert constantlike3.get_constant_value(c3) == FloatAttr(1.0, f32) + attr2 = FloatAttr(1.0, f32) + c2 = ConstantOp(attr2) + assert c2.value.type == f32 + assert ConstantLike.get_constant_value(c2.result) == attr2 value_type = TensorType(i32, [2, 2]) - c5 = ConstantOp(DenseIntOrFPElementsAttr.from_list(value_type, [1, 2, 3, 4])) - assert c5.value.type == value_type - constantlike5 = c5.get_trait(ConstantLike) - assert constantlike5 is not None - assert constantlike5.get_constant_value(c5) == DenseIntOrFPElementsAttr.from_list( - value_type, [1, 2, 3, 4] - ) + attr3 = DenseIntOrFPElementsAttr.from_list(value_type, [1, 2, 3, 4]) + c3 = ConstantOp(attr3) + assert c3.value.type == value_type + assert ConstantLike.get_constant_value(c3.result) == attr3 @pytest.mark.parametrize( diff --git a/tests/dialects/test_complex.py b/tests/dialects/test_complex.py index 9d4efd757a..717d4fb0a3 100644 --- a/tests/dialects/test_complex.py +++ b/tests/dialects/test_complex.py @@ -22,13 +22,13 @@ def test_constant_construction(): + value = ArrayAttr([IntAttr(42), IntAttr(43)]) c1 = complex.ConstantOp( - value=ArrayAttr([IntAttr(42), IntAttr(43)]), + value=value, result_type=complex.ComplexType(i32), ) - constantlike = c1.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(c1) == ArrayAttr([IntAttr(42), IntAttr(43)]) + + assert ConstantLike.get_constant_value(c1.complex) == value class Test_constant_op_helper_constructors: diff --git a/tests/dialects/test_equivalence.py b/tests/dialects/test_equivalence.py index bc35857c16..37471f6177 100644 --- a/tests/dialects/test_equivalence.py +++ b/tests/dialects/test_equivalence.py @@ -7,18 +7,17 @@ def test_const_class_construction(): - constant_op = arith.ConstantOp(IntegerAttr.from_int_and_width(42, 64)) + value = IntegerAttr.from_int_and_width(42, 64) + constant_op = arith.ConstantOp(value) const_class = equivalence.ConstantClassOp(constant_op.result) trait = const_class.get_trait(ConstantLike) assert trait is not None - assert trait.get_constant_value(const_class) == IntegerAttr.from_int_and_width( - 42, 64 - ) + assert ConstantLike.get_constant_value(constant_op.result) == value non_constant_op = test.TestOp(result_types=(i32,)) with pytest.raises( DiagnosticException, - match="The argument of a ConstantClass must be a constant-like operation.", + match="The argument of a ConstantClass must be a `ConstantLike` operation implementing `HasFolderInterface`.", ): equivalence.ConstantClassOp(non_constant_op.results[0]) diff --git a/tests/dialects/test_riscv.py b/tests/dialects/test_riscv.py index 09d148ef63..311e2e1100 100644 --- a/tests/dialects/test_riscv.py +++ b/tests/dialects/test_riscv.py @@ -276,25 +276,18 @@ def test_asm_section(): def test_get_constant_value(): # Test 32-bit LiOp - li_op = rv32.LiOp(1) - li_val = get_constant_value(li_op.rd) - assert li_val == IntegerAttr.from_int_and_width(1, 32) - # LiOp implements ConstantLikeInterface so it also has a get_constant_value method: - constantlike = li_op.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(li_op) == IntegerAttr.from_int_and_width( - 1, 32 - ) + one_32 = IntegerAttr.from_int_and_width(1, 32) + li_op_32 = rv32.LiOp(1) + li_val_32 = get_constant_value(li_op_32.rd) + assert li_val_32 == one_32 + assert ConstantLike.get_constant_value(li_op_32.rd) == one_32 # Test 64-bit LiOp - li_op_64 = rv64.LiOp(1) - li_val_64 = get_constant_value(li_op_64.rd) - assert li_val_64 == IntegerAttr.from_int_and_width(1, 64) - constantlike = li_op_64.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(li_op_64) == IntegerAttr.from_int_and_width( - 1, 64 - ) + one_rv64 = IntegerAttr.from_int_and_width(1, 64) + li_op_rv64 = rv64.LiOp(1) + li_val_rv64 = get_constant_value(li_op_rv64.rd) + assert li_val_rv64 == one_rv64 + assert ConstantLike.get_constant_value(li_op_rv64.rd) == one_rv64 zero_op = riscv.GetRegisterOp(riscv.Registers.ZERO) zero_val = get_constant_value(zero_op.res) diff --git a/tests/dialects/test_smt.py b/tests/dialects/test_smt.py index 4207d20908..0ef7682b77 100644 --- a/tests/dialects/test_smt.py +++ b/tests/dialects/test_smt.py @@ -52,16 +52,12 @@ def test_constant_bool(): op = ConstantBoolOp(True) assert op.value is True assert op.value_attr == IntegerAttr(-1, 1) - constantlike = op.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(op) == IntegerAttr(-1, 1) + assert ConstantLike.get_constant_value(op.result) == IntegerAttr(-1, 1) op = ConstantBoolOp(False) assert op.value is False assert op.value_attr == IntegerAttr(0, 1) - constantlike = op.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(op) == IntegerAttr(0, 1) + assert ConstantLike.get_constant_value(op.result) == IntegerAttr(0, 1) def test_bv_type(): @@ -189,23 +185,17 @@ def test_bv_constant_op(): op = BvConstantOp(bv_attr) assert op.value == bv_attr assert op.result.type == BitVectorType(32) - constantlike = op.get_trait(ConstantLike) - assert constantlike is not None - assert constantlike.get_constant_value(op) == bv_attr + assert ConstantLike.get_constant_value(op.result) == bv_attr op2 = BvConstantOp(42, 32) assert op2.value == bv_attr assert op2.result.type == BitVectorType(32) - constantlike2 = op2.get_trait(ConstantLike) - assert constantlike2 is not None - assert constantlike2.get_constant_value(op2) == bv_attr + assert ConstantLike.get_constant_value(op2.result) == bv_attr op3 = BvConstantOp(42, BitVectorType(32)) assert op3.value == bv_attr assert op3.result.type == BitVectorType(32) - constantlike3 = op3.get_trait(ConstantLike) - assert constantlike3 is not None - assert constantlike3.get_constant_value(op3) == bv_attr + assert ConstantLike.get_constant_value(op3.result) == bv_attr @pytest.mark.parametrize("op_type", [BVNotOp, BVNegOp]) diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 1e79f40c4d..327946b25b 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -30,7 +30,7 @@ VectorType, ) from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag -from xdsl.interfaces import ConstantLikeInterface, HasFolderInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import ( Attribute, BitEnumAttribute, @@ -58,6 +58,7 @@ from xdsl.traits import ( Commutative, ConditionallySpeculatable, + ConstantLike, HasCanonicalizationPatternsTrait, NoMemoryEffect, Pure, @@ -137,7 +138,7 @@ def __init__(self, flags: None | Sequence[IntegerOverflowFlag] | Literal["none"] @irdl_op_definition -class ConstantOp(IRDLOperation, ConstantLikeInterface): +class ConstantOp(IRDLOperation, HasFolderInterface): name = "arith.constant" _T: ClassVar = VarConstraint("T", AnyAttr()) result = result_def(_T) @@ -148,7 +149,7 @@ class ConstantOp(IRDLOperation, ConstantLikeInterface): | ParamAttrConstraint(DenseResourceAttr, (AnyAttr(), _T)) ) - traits = traits_def(Pure()) + traits = traits_def(Pure(), ConstantLike()) assembly_format = "attr-dict $value" @@ -180,8 +181,8 @@ def from_int_and_width( }, ) - def get_constant_value(self) -> Attribute: - return self.value + def fold(self) -> Sequence[SSAValue | Attribute] | None: + return (self.value,) class SignlessIntegerBinaryOperation(IRDLOperation, HasFolderInterface, abc.ABC): @@ -228,8 +229,8 @@ def is_right_unit(attr: IntegerAttr) -> bool: return False def fold(self): - lhs = self.get_constant(self.lhs) - rhs = self.get_constant(self.rhs) + lhs = ConstantLike.get_constant_value(self.lhs) + rhs = ConstantLike.get_constant_value(self.rhs) if lhs is not None and rhs is not None: if isa(lhs, IntegerAttr) and isa(rhs, IntegerAttr): assert lhs.type == rhs.type diff --git a/xdsl/dialects/complex.py b/xdsl/dialects/complex.py index c38bab970e..daea196da8 100644 --- a/xdsl/dialects/complex.py +++ b/xdsl/dialects/complex.py @@ -19,7 +19,7 @@ IntegerAttr, IntegerType, ) -from xdsl.interfaces import ConstantLikeInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import ( Attribute, Dialect, @@ -47,7 +47,7 @@ ) from xdsl.parser import AttrParser from xdsl.printer import Printer -from xdsl.traits import Pure +from xdsl.traits import ConstantLike, Pure from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -292,7 +292,7 @@ class ConjOp(ComplexUnaryComplexResultOperation): @irdl_op_definition -class ConstantOp(IRDLOperation, ConstantLikeInterface): +class ConstantOp(IRDLOperation, HasFolderInterface): name = "complex.constant" T: ClassVar = VarConstraint("T", AnyFloatConstr | base(IntegerType)) value = prop_def( @@ -312,15 +312,15 @@ class ConstantOp(IRDLOperation, ConstantLikeInterface): # have any complex result type, not just floating point: complex = result_def(ComplexType.constr(T)) - traits = traits_def(Pure()) + traits = traits_def(Pure(), ConstantLike()) assembly_format = "$value attr-dict `:` type($complex)" def __init__(self, value: ArrayAttr, result_type: ComplexType): super().__init__(properties={"value": value}, result_types=[result_type]) - def get_constant_value(self) -> Attribute: - return self.value + def fold(self) -> tuple[Attribute]: + return (self.value,) @staticmethod def from_floats(value: tuple[float, float], type: AnyFloat) -> ConstantOp: diff --git a/xdsl/dialects/equivalence.py b/xdsl/dialects/equivalence.py index ca41478a52..c9c14d249a 100644 --- a/xdsl/dialects/equivalence.py +++ b/xdsl/dialects/equivalence.py @@ -11,7 +11,7 @@ from typing import ClassVar from xdsl.dialects.builtin import IntAttr -from xdsl.interfaces import ConstantLikeInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import Attribute, Dialect, OpResult, Region, SSAValue from xdsl.irdl import ( AnyAttr, @@ -43,7 +43,7 @@ @irdl_op_definition -class ConstantClassOp(IRDLOperation, ConstantLikeInterface): +class ConstantClassOp(IRDLOperation, HasFolderInterface): """An e-class representing a known constant value. For non-constant e-classes, use [ClassOp][xdsl.dialects.equivalence.ClassOp]. """ @@ -55,22 +55,22 @@ class ConstantClassOp(IRDLOperation, ConstantLikeInterface): assembly_format = ( "$arguments ` ` `(` `constant` `=` $value `)` attr-dict `:` type($result)" ) - traits = traits_def(Pure()) + traits = traits_def(Pure(), ConstantLike()) arguments = var_operand_def(T) result = result_def(T) value = prop_def() min_cost_index = opt_attr_def(IntAttr) - def get_constant_value(self): - return self.value + def fold(self) -> tuple[Attribute]: + return (self.value,) def __init__(self, const_arg: OpResult): - if (trait := const_arg.owner.get_trait(ConstantLike)) is None: + if (value := ConstantLike.get_constant_value(const_arg)) is None: raise DiagnosticException( - "The argument of a ConstantClass must be a constant-like operation." + "The argument of a ConstantClass must be a `ConstantLike` operation implementing `HasFolderInterface`." ) - value = trait.get_constant_value(const_arg.owner) + super().__init__( operands=[const_arg], result_types=[const_arg.type], diff --git a/xdsl/dialects/rv32.py b/xdsl/dialects/rv32.py index c84e00b126..a0445ae3be 100644 --- a/xdsl/dialects/rv32.py +++ b/xdsl/dialects/rv32.py @@ -23,7 +23,7 @@ print_immediate_value, ) from xdsl.dialects.riscv.ops import LiOpHasCanonicalizationPatternTrait -from xdsl.interfaces import ConstantLikeInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import ( Attribute, Dialect, @@ -37,12 +37,13 @@ from xdsl.parser import Parser from xdsl.printer import Printer from xdsl.traits import ( + ConstantLike, Pure, ) @irdl_op_definition -class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, ABC): +class LiOp(RISCVCustomFormatOperation, RISCVInstruction, HasFolderInterface, ABC): """ Loads a 32-bit immediate into rd. @@ -56,7 +57,7 @@ class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, rd = result_def(IntRegisterType) immediate = attr_def(IntegerAttr[I32] | LabelAttr) - traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait()) + traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait(), ConstantLike()) def __init__( self, @@ -83,8 +84,8 @@ def __init__( def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]: return self.rd, self.immediate - def get_constant_value(self): - return self.immediate + def fold(self) -> tuple[IntegerAttr[I32] | LabelAttr]: + return (self.immediate,) @classmethod def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]: diff --git a/xdsl/dialects/rv64.py b/xdsl/dialects/rv64.py index d724106d63..da406b0a12 100644 --- a/xdsl/dialects/rv64.py +++ b/xdsl/dialects/rv64.py @@ -23,7 +23,7 @@ print_immediate_value, ) from xdsl.dialects.riscv.ops import LiOpHasCanonicalizationPatternTrait -from xdsl.interfaces import ConstantLikeInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import ( Attribute, Dialect, @@ -37,12 +37,13 @@ from xdsl.parser import Parser from xdsl.printer import Printer from xdsl.traits import ( + ConstantLike, Pure, ) @irdl_op_definition -class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, ABC): +class LiOp(RISCVCustomFormatOperation, RISCVInstruction, HasFolderInterface, ABC): """ Loads a 64-bit immediate into rd. @@ -56,7 +57,7 @@ class LiOp(RISCVCustomFormatOperation, RISCVInstruction, ConstantLikeInterface, rd = result_def(IntRegisterType) immediate = attr_def(IntegerAttr[I64] | LabelAttr) - traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait()) + traits = traits_def(Pure(), LiOpHasCanonicalizationPatternTrait(), ConstantLike()) def __init__( self, @@ -83,8 +84,8 @@ def __init__( def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]: return self.rd, self.immediate - def get_constant_value(self): - return self.immediate + def fold(self) -> tuple[IntegerAttr[I64] | LabelAttr]: + return (self.immediate,) @classmethod def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]: diff --git a/xdsl/dialects/smt.py b/xdsl/dialects/smt.py index 899c3bf905..4d5896bbfc 100644 --- a/xdsl/dialects/smt.py +++ b/xdsl/dialects/smt.py @@ -7,7 +7,7 @@ from typing_extensions import Self from xdsl.dialects.builtin import ArrayAttr, BoolAttr, IntAttr, StringAttr -from xdsl.interfaces import ConstantLikeInterface +from xdsl.interfaces import HasFolderInterface from xdsl.ir import ( Attribute, Dialect, @@ -41,7 +41,7 @@ ) from xdsl.parser import AttrParser, Parser from xdsl.printer import Printer -from xdsl.traits import HasParent, IsTerminator, Pure +from xdsl.traits import ConstantLike, HasParent, IsTerminator, Pure from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -265,7 +265,7 @@ def __init__(self, func: SSAValue[FuncType], *args: SSAValue): @irdl_op_definition -class ConstantBoolOp(IRDLOperation, ConstantLikeInterface): +class ConstantBoolOp(IRDLOperation, HasFolderInterface): """ This operation represents a constant boolean value. The semantics are equivalent to the ‘true’ and ‘false’ keywords in the Core theory of the @@ -277,7 +277,7 @@ class ConstantBoolOp(IRDLOperation, ConstantLikeInterface): value_attr = prop_def(BoolAttr, prop_name="value") result = result_def(BoolType()) - traits = traits_def(Pure()) + traits = traits_def(Pure(), ConstantLike()) assembly_format = "qualified($value) attr-dict" @@ -289,8 +289,8 @@ def __init__(self, value: bool): def value(self) -> bool: return bool(self.value_attr) - def get_constant_value(self) -> Attribute: - return self.value_attr + def fold(self) -> tuple[BoolAttr]: + return (self.value_attr,) @irdl_op_definition @@ -604,7 +604,7 @@ def __init__(self, input: SSAValue): @irdl_op_definition -class BvConstantOp(IRDLOperation, ConstantLikeInterface): +class BvConstantOp(IRDLOperation, HasFolderInterface): """ This operation produces an SSA value equal to the bitvector constant specified by the ‘value’ attribute. @@ -619,7 +619,7 @@ class BvConstantOp(IRDLOperation, ConstantLikeInterface): assembly_format = "qualified($value) attr-dict" - traits = traits_def(Pure()) + traits = traits_def(Pure(), ConstantLike()) @overload def __init__(self, value: BitVectorAttr) -> None: ... @@ -640,8 +640,8 @@ def __init__( value = BitVectorAttr(value, type) super().__init__(properties={"value": value}, result_types=[value.type]) - def get_constant_value(self) -> Attribute: - return self.value + def fold(self) -> tuple[Attribute]: + return (self.value,) class UnaryBVOp(IRDLOperation, ABC): diff --git a/xdsl/interfaces.py b/xdsl/interfaces.py index 57fb312243..c4b053cb97 100644 --- a/xdsl/interfaces.py +++ b/xdsl/interfaces.py @@ -16,7 +16,6 @@ from xdsl.irdl import traits_def from xdsl.pattern_rewriter import RewritePattern from xdsl.traits import ( - ConstantLike, HasCanonicalizationPatternsTrait, HasFolder, ) @@ -56,35 +55,6 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: raise NotImplementedError() -class _ConstantLikeInterfaceTrait(ConstantLike): - """ - Gets the constant value from the operation's implementation - of `ConstantLikeInterface`. - """ - - def verify(self, op: Operation) -> None: - return - - @classmethod - def get_constant_value(cls, op: Operation) -> Attribute: - op = cast(ConstantLikeInterface, op) - return op.get_constant_value() - - -class ConstantLikeInterface(Operation, abc.ABC): - """ - An operation subclassing this interface must implement the - `get_constant_value` method, which returns the constant value of this operation. - Wraps `ConstantLikeTrait`. - """ - - traits = traits_def(_ConstantLikeInterfaceTrait()) - - @abc.abstractmethod - def get_constant_value(self) -> Attribute: - raise NotImplementedError() - - class _HasFolderInterfaceTrait(HasFolder): """ Gets the fold results from the operation's implementation @@ -109,13 +79,6 @@ class HasFolderInterface(Operation, abc.ABC): traits = traits_def(_HasFolderInterfaceTrait()) - def get_constant(self, operand: SSAValue) -> Attribute | None: - if ( - isinstance(operand_op := operand.owner, Operation) - and (t := operand_op.get_trait(ConstantLike)) is not None - ): - return t.get_constant_value(operand_op) - @abc.abstractmethod def fold(self) -> Sequence[SSAValue | Attribute] | None: """ diff --git a/xdsl/traits.py b/xdsl/traits.py index d5c453322f..ccd0d5fb74 100644 --- a/xdsl/traits.py +++ b/xdsl/traits.py @@ -37,19 +37,30 @@ class ConstantLike(OpTrait, abc.ABC): """ Operation known to be constant-like. + To participate in constant folding and other generic mechanisms implement + `HasFolder` or `HasFolderInterface` for your operation. + See external [documentation](https://mlir.llvm.org/doxygen/classmlir_1_1OpTrait_1_1ConstantLike.html). """ - @classmethod - @abc.abstractmethod - def get_constant_value(cls, op: Operation) -> Attribute: + @staticmethod + def get_constant_value(ssa_value: SSAValue) -> Attribute | None: """ - Get the constant value from this constant-like operation. - - Returns: - The constant value as an Attribute, or None if the value cannot be determined. + If the value is the result of a `ConstantLike` operation that implements + `HasFolderInterface`, return the attribute returned by `fold` corresponding to + the value's index in the list of results. """ - raise NotImplementedError() + from xdsl.ir import Attribute, OpResult + + if ( + isinstance(ssa_value, OpResult) + and (op := ssa_value.owner) + and op.has_trait(ConstantLike) + and (t := op.get_trait(HasFolder)) is not None + and (values := t.fold(op)) is not None + and isinstance(value := values[ssa_value.index], Attribute) + ): + return value class HasFolder(OpTrait): diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 0ddd9f47a2..fcf59459bd 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -1,9 +1,10 @@ -from typing import cast +from typing import Literal, cast from xdsl.dialects import riscv, riscv_snitch, rv32 -from xdsl.dialects.builtin import I32, I64, IntegerAttr, i32 +from xdsl.dialects.builtin import I32, I64, IntegerAttr, IntegerType, Signedness, i32 from xdsl.dialects.utils import FastMathFlag from xdsl.ir import OpResult, SSAValue +from xdsl.irdl import irdl_to_attr_constraint from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, @@ -645,7 +646,14 @@ def match_and_rewrite(self, op: rv32.LiOp, rewriter: PatternRewriter) -> None: ) -def get_constant_value(value: SSAValue) -> IntegerAttr[I32] | IntegerAttr[I64] | None: +_I32_I64_CONSTRAINT = irdl_to_attr_constraint( + IntegerAttr[IntegerType[Literal[32, 64], Literal[Signedness.SIGNLESS]]] +) + + +def get_constant_value( + value: SSAValue, +) -> IntegerAttr[I32] | IntegerAttr[I64] | None: if value.type == riscv.Registers.ZERO: return IntegerAttr(0, i32) @@ -655,8 +663,7 @@ def get_constant_value(value: SSAValue) -> IntegerAttr[I32] | IntegerAttr[I64] | if isinstance(value.op, riscv.MVOp): return get_constant_value(value.op.rs) - constant_like = value.op.get_trait(ConstantLike) - if constant_like is not None: - result = constant_like.get_constant_value(value.op) - if isinstance(result, IntegerAttr): - return cast(IntegerAttr[I32] | IntegerAttr[I64], result) + if ( + result := ConstantLike.get_constant_value(value) + ) is not None and _I32_I64_CONSTRAINT.verifies(result): + return cast(IntegerAttr[I32] | IntegerAttr[I64], result)