From 52833c5dc46031f5f7c6cea3014d62ba026cb72f Mon Sep 17 00:00:00 2001 From: Wojciech Sipak Date: Thu, 4 Dec 2025 15:07:33 +0100 Subject: [PATCH 1/2] add const match Internal-tag: #[88020] --- xls/dslx/frontend/ast.cc | 6 +- xls/dslx/frontend/ast.h | 8 +- xls/dslx/frontend/ast_cloner.cc | 2 +- xls/dslx/frontend/parser.cc | 35 +++++- xls/dslx/frontend/parser.h | 2 +- xls/dslx/frontend/token_parser.h | 12 +- xls/dslx/ir_convert/function_converter.cc | 104 ++++++++++++++++++ xls/dslx/ir_convert/function_converter.h | 4 + .../typecheck_module_v2_test.cc | 14 +++ 9 files changed, 169 insertions(+), 18 deletions(-) diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index da5656b94a..e6947bf6bd 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -1335,7 +1335,7 @@ std::vector Match::GetChildren(bool want_types) const { } std::string Match::ToStringInternal() const { - std::string result = absl::StrFormat("match %s {\n", matched_->ToString()); + std::string result = absl::StrFormat("%smatch %s {\n", IsConst() ? "const " : "", matched_->ToString()); for (MatchArm* arm : arms_) { absl::StrAppend(&result, Indent(absl::StrCat(arm->ToString(), ",\n"), kRustSpacesPerIndent)); @@ -2392,8 +2392,8 @@ Span MatchArm::GetPatternSpan() const { } Match::Match(Module* owner, Span span, Expr* matched, - std::vector arms, bool in_parens) - : Expr(owner, std::move(span), in_parens), + std::vector arms, bool in_parens, bool is_const) + : Expr(owner, std::move(span), in_parens, is_const), matched_(matched), arms_(std::move(arms)) {} diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index dc80fbcb42..854c113fd6 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -1151,8 +1151,8 @@ inline bool WeakerThan(Precedence x, Precedence y) { // (i.e. can produce runtime values). class Expr : public AstNode { public: - Expr(Module* owner, Span span, bool in_parens = false) - : AstNode(owner), span_(span), in_parens_(in_parens) {} + Expr(Module* owner, Span span, bool in_parens = false, bool is_const = false) + : AstNode(owner), span_(span), in_parens_(in_parens), is_const_(is_const) {} ~Expr() override; @@ -1202,6 +1202,7 @@ class Expr : public AstNode { // (x == y) == z bool in_parens() const { return in_parens_; } void set_in_parens(bool enabled) { in_parens_ = enabled; } + bool IsConst() const { return is_const_; } protected: virtual std::string ToStringInternal() const = 0; @@ -1211,6 +1212,7 @@ class Expr : public AstNode { private: Span span_; bool in_parens_ = false; + bool is_const_; }; // ChannelTypeAnnotation has to be placed after the definition of Expr, so it @@ -2626,7 +2628,7 @@ class MatchArm : public AstNode { class Match : public Expr { public: Match(Module* owner, Span span, Expr* matched, std::vector arms, - bool in_parens = false); + bool in_parens = false, bool is_const = false); ~Match() override; diff --git a/xls/dslx/frontend/ast_cloner.cc b/xls/dslx/frontend/ast_cloner.cc index c003ac9c49..b955a84707 100644 --- a/xls/dslx/frontend/ast_cloner.cc +++ b/xls/dslx/frontend/ast_cloner.cc @@ -553,7 +553,7 @@ class AstCloner : public AstNodeVisitor { old_to_new_[n] = module(n)->Make( n->span(), down_cast(old_to_new_.at(n->matched())), new_arms, - n->in_parens()); + n->in_parens(), n->IsConst()); return absl::OkStatus(); } diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index 11a42e538d..9a13f8bb81 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -1984,7 +1984,10 @@ absl::StatusOr Parser::ParsePattern(Bindings& bindings, absl::StrFormat("Expected pattern; got %s", peek->ToErrorString())); } -absl::StatusOr Parser::ParseMatch(Bindings& bindings) { +absl::StatusOr Parser::ParseMatch(Bindings& bindings, bool is_const) { + if (is_const) { + XLS_RETURN_IF_ERROR(DropKeywordOrError(Keyword::kConst)); + } XLS_ASSIGN_OR_RETURN(Token match, PopKeywordOrError(Keyword::kMatch)); XLS_ASSIGN_OR_RETURN(Expr * matched, ParseExpression(bindings)); XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrace)); @@ -2041,7 +2044,7 @@ absl::StatusOr Parser::ParseMatch(Bindings& bindings) { must_end = !dropped_comma; } Span span(match.span().start(), GetPos()); - return module_->Make(span, matched, std::move(arms)); + return module_->Make(span, matched, std::move(arms), false, is_const); } absl::StatusOr Parser::ParseUseTreeEntry(Bindings& bindings) { @@ -2409,12 +2412,22 @@ absl::StatusOr Parser::ParseTermLhs(Bindings& outer_bindings, XLS_ASSIGN_OR_RETURN( lhs, ParseParentheticalOrCastLhs(outer_bindings, start_pos)); } else if (peek->IsKeyword(Keyword::kMatch)) { // Match expression. - XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings)); + XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings, false)); } else if (peek->kind() == TokenKind::kOBrack) { // Array expression. XLS_ASSIGN_OR_RETURN(lhs, ParseArray(outer_bindings)); } else if (peek->IsKeyword(Keyword::kIf)) { // Conditional expression. XLS_ASSIGN_OR_RETURN(lhs, ParseRangeExpression(outer_bindings, kNoRestrictions)); + } else if (peek->IsKeyword(Keyword::kConst)) { + XLS_ASSIGN_OR_RETURN(const Token* peek_1, PeekToken(1)); + if (peek_1->IsKeyword(Keyword::kMatch)) { // constexpr match + XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings, true)); + } else { + return ParseErrorStatus( + peek_1->span(), + absl::StrFormat("Expected start of a const expression; got: %s", + peek_1->ToErrorString())); + } } else { return ParseErrorStatus( peek->span(), @@ -4126,8 +4139,7 @@ absl::StatusOr Parser::ParseBlockExpression( ParseTypeAlias(GetPos(), /*is_public=*/false, block_bindings)); stmts.push_back(module_->Make(alias)); last_expr_had_trailing_semi = true; - } else if (peek->IsKeyword(Keyword::kLet) || - peek->IsKeyword(Keyword::kConst)) { + } else if (peek->IsKeyword(Keyword::kLet)) { XLS_ASSIGN_OR_RETURN(Let * let, ParseLet(block_bindings)); stmts.push_back(module_->Make(let)); last_expr_had_trailing_semi = true; @@ -4137,6 +4149,19 @@ absl::StatusOr Parser::ParseBlockExpression( stmts.push_back(module_->Make(const_assert)); last_expr_had_trailing_semi = true; } else { + // const can be a constant or a modifier + if (peek->IsKeyword(Keyword::kConst)) { + XLS_ASSIGN_OR_RETURN(const Token* peek_1, PeekToken(1)); + // handle the case when const is a regular constant, otherwise + // it is a modifier and should be handled as expression with bindings + if (!peek_1->IsKeyword(Keyword::kMatch)) { + XLS_ASSIGN_OR_RETURN(Let * let, ParseLet(block_bindings)); + stmts.push_back(module_->Make(let)); + last_expr_had_trailing_semi = true; + continue; + } + } + VLOG(5) << "ParseBlockExpression; parsing expression with bindings: [" << absl::StrJoin(block_bindings.GetLocalBindings(), ", ") << "]"; XLS_ASSIGN_OR_RETURN(Expr * e, ParseExpression(block_bindings)); diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index 645dc66e89..c2efdece67 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -531,7 +531,7 @@ class Parser : public TokenParser { bool within_tuple_pattern); // Parses a match expression. - absl::StatusOr ParseMatch(Bindings& bindings); + absl::StatusOr ParseMatch(Bindings& bindings, bool is_const); // Parses a channel declaration. absl::StatusOr ParseChannelDecl( diff --git a/xls/dslx/frontend/token_parser.h b/xls/dslx/frontend/token_parser.h index f6cd3e7aed..c94b4e85b2 100644 --- a/xls/dslx/frontend/token_parser.h +++ b/xls/dslx/frontend/token_parser.h @@ -117,12 +117,14 @@ class TokenParser { // token is returned. // // Returns an error status in the case of things like scan errors. - absl::StatusOr PeekToken() { - if (index_ >= tokens_.size()) { - XLS_ASSIGN_OR_RETURN(Token token, scanner_->Pop()); - tokens_.push_back(std::make_unique(std::move(token))); + absl::StatusOr PeekToken(int skip_count = 0) { + if (index_ + skip_count >= tokens_.size()) { + for (int i = 0; i < skip_count + 1; ++i) { + XLS_ASSIGN_OR_RETURN(Token token, scanner_->Pop()); + tokens_.push_back(std::make_unique(std::move(token))); + } } - return tokens_[index_].get(); + return tokens_[index_ + skip_count].get(); } // Returns a token that has been popped destructively from the token stream. diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index f6726b8d66..9276b222f1 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -1288,6 +1288,106 @@ absl::Status FunctionConverter::HandleBuiltinWideningCast( return absl::OkStatus(); } +absl::StatusOr FunctionConverter::ConstMatchWhichArm(const Match* node) { + ParametricEnv bindings(parametric_env_map_); + std::vector construct_match_arms; + construct_match_arms.reserve(node->arms().size()); + + // Construct a new Match object, which has the same `matched` and `patterns` as the original. + // Create a new expression for each arm + // so that the whole match can be evaluated to know which arm is selected. + for (int64_t i = 0; i < node->arms().size(); ++i) { + MatchArm* arm = node->arms()[i]; + // create a new expression for this arm - a number with the index of the arm. + Number* expr = module_->Make( + Span::Fake(), + absl::StrFormat("%d", i), + NumberKind::kOther, + CreateU32Annotation(*module_, Span::Fake())); + + current_type_info_->SetItem(expr, BitsType::MakeU32()); + current_type_info_->NoteConstExpr(expr, InterpValue::MakeUBits(32, i)); + + construct_match_arms.push_back( + module_->Make(arm->span(), arm->patterns(), expr)); + } + Match *fake_match = module_->Make(node->span(), node->matched(), construct_match_arms); + + XLS_ASSIGN_OR_RETURN(InterpValue interp_match, + ConstexprEvaluator::EvaluateToValue( + import_data_, current_type_info_, + kNoWarningCollector, bindings, fake_match)); + + XLS_ASSIGN_OR_RETURN(uint64_t arm_id, interp_match.GetBitValueUnsigned()); + + return arm_id; +} + +bool pattern_has_namedef(const NameDefTree* pattern) { + if (pattern->is_leaf()) { + return absl::visit( + Visitor{ + [&](NameDef* name_def) -> bool { + return true; + }, + [&](AstNode* node) -> bool { + return false; + }, + }, + pattern->leaf()); + } else { + return std::any_of(pattern->nodes().begin(), pattern->nodes().end(), pattern_has_namedef); + } +} + +bool patterns_have_namedef(const std::vector& patterns) { + return std::any_of(patterns.begin(), patterns.end(), pattern_has_namedef); +} + +absl::Status FunctionConverter::HandleConstMatch(const Match* node) { + ParametricEnv bindings(parametric_env_map_); + std::optional matched_val; + + + XLS_RETURN_IF_ERROR(Visit(node->matched())); + XLS_ASSIGN_OR_RETURN(BValue matched, Use(node->matched())); + XLS_ASSIGN_OR_RETURN(InterpValue matched_const, + ConstexprEvaluator::EvaluateToValue( + import_data_, current_type_info_, + kNoWarningCollector, bindings, node->matched())); + + XLS_ASSIGN_OR_RETURN(matched_val, InterpValueToValue(matched_const)); + XLS_ASSIGN_OR_RETURN(std::unique_ptr matched_type, + ResolveType(node->matched())); + + XLS_ASSIGN_OR_RETURN(uint64_t arm_id, ConstMatchWhichArm(node)); + + const MatchArm* arm = node->arms()[arm_id]; + bool has_namedef = patterns_have_namedef(arm->patterns()); + if (!has_namedef) { // simple case when the arm's expression can be converted to IR + XLS_RETURN_IF_ERROR(Visit(node->arms()[arm_id]->expr())); + XLS_ASSIGN_OR_RETURN(BValue bval, Use(node->arms()[arm_id]->expr())); + SetNodeToIr(node, bval); + return absl::OkStatus(); + } else { } + + BValue final_val = function_builder_->Literal(matched_val.value()); + std::vector arm_selectors; + + for (NameDefTree* pattern : arm->patterns()) { + XLS_ASSIGN_OR_RETURN(BValue selector, + HandleMatcher(pattern, final_val, *matched_type)); + XLS_RET_CHECK(selector.valid()); + arm_selectors.push_back(selector); + } + + XLS_RETURN_IF_ERROR(Visit(arm->expr())); + XLS_ASSIGN_OR_RETURN(BValue arm_rhs_value, Use(arm->expr())); + SetNodeToIr(node, arm_rhs_value); + + return absl::OkStatus(); +} + absl::Status FunctionConverter::HandleMatch(const Match* node) { if (node->arms().empty()) { return IrConversionErrorStatus( @@ -1297,6 +1397,10 @@ absl::Status FunctionConverter::HandleMatch(const Match* node) { file_table()); } + if (node->IsConst()) { + return HandleConstMatch(node); + } + XLS_RETURN_IF_ERROR(Visit(node->matched())); XLS_ASSIGN_OR_RETURN(BValue matched, Use(node->matched())); XLS_ASSIGN_OR_RETURN(std::unique_ptr matched_type, diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h index 076a6a2d1e..5cf656073b 100644 --- a/xls/dslx/ir_convert/function_converter.h +++ b/xls/dslx/ir_convert/function_converter.h @@ -347,6 +347,9 @@ class FunctionConverter { ->GetInvocationCalleeBindings(invocation, key); } + // Helper to evaluate which arm of a const match should be used. + absl::StatusOr ConstMatchWhichArm(const Match* node); + // Helpers for HandleBinop(). absl::Status HandleConcat(const Binop* node, BValue lhs, BValue rhs); absl::Status HandleEq(const Binop* node, BValue lhs, BValue rhs); @@ -423,6 +426,7 @@ class FunctionConverter { absl::Status HandleLet(const Let* node); absl::Status HandleLetChannelDecl(const Let* node); absl::Status HandleMatch(const Match* node); + absl::Status HandleConstMatch(const Match* node); absl::Status HandleRange(const Range* node); absl::Status HandleSplatStructInstance(const SplatStructInstance* node); absl::Status HandleStatement(const Statement* node); diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index 728241bb49..bf8472686d 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -3339,6 +3339,20 @@ fn repro(x: u3) -> u2 { TypecheckSucceeds(HasNodeWithType("upper", "uN[2]"))); } +TEST(TypecheckV2Test, ConstMatch) { + EXPECT_THAT(R"( +fn main(a: u32, b: u32) -> u32 { + const A = true; + let retval = const match A { + true => a, + false => b + }; + retval +} + )", + TypecheckSucceeds(HasNodeWithType("retval", "uN[32]"))); + +} TEST(TypecheckV2Test, MatchMismatch) { EXPECT_THAT(R"( const X = u32:1; From 7a8fd8fe89e69b590d509544094a6985efdf807f Mon Sep 17 00:00:00 2001 From: Wojciech Sipak Date: Wed, 3 Dec 2025 18:12:23 +0100 Subject: [PATCH 2/2] add const match example Internal-tag: #[88020] --- xls/examples/BUILD | 27 +++++++++++ xls/examples/const_match.x | 99 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 xls/examples/const_match.x diff --git a/xls/examples/BUILD b/xls/examples/BUILD index 24613d69f4..e64486c428 100644 --- a/xls/examples/BUILD +++ b/xls/examples/BUILD @@ -1396,3 +1396,30 @@ build_test( name = "xls_pipeline_build_test", targets = [":xls_pipeline"], ) + +xls_dslx_library( + name = "const_match_dslx", + srcs = ["const_match.x"], +) + +xls_dslx_test( + name = "const_match_test", + size = "small", + srcs = ["const_match.x"], + dslx_test_args = {"compare": "jit"}, +) + +xls_dslx_ir( + name = "const_match_ir", + dslx_top = "main", + ir_conv_args = {"lower_to_proc_scoped_channels": "true"}, + ir_file = "const_match.ir", + library = ":const_match_dslx", +) + +xls_dslx_opt_ir( + name = "const_match_opt_ir", + srcs = ["const_match.x"], + ir_conv_args = {"lower_to_proc_scoped_channels": "true"}, + dslx_top = "main", +) diff --git a/xls/examples/const_match.x b/xls/examples/const_match.x new file mode 100644 index 0000000000..8955ff6c2c --- /dev/null +++ b/xls/examples/const_match.x @@ -0,0 +1,99 @@ +#![feature(type_inference_v2)] + +fn match_bool(a: u32, b: u32) -> u32 { + const A = true; + let result = const match A { + true => a, + false => b + }; + result +} + +fn match_bool_param(a: u32, b: u32) -> u32 { + const match A { + true => a, + false => b + } +} + +fn matcher_types() -> u32 { + const B = u32:9; + const match A { + u32:0..u32:3 => u32:0, + u32:4 => u32:1, + u32:5 | u32:6 | u32:7 => u32:2, + // TODO use colon reference + u32:8 => u32:3, + B => u32:800, + _ => u32:1000, + } +} + +type ARG = (u32, (u32, u32, u32)); +fn matcher_tuple() -> u32 { + const match A { + (u32:1, (u32:3, ..)) => u32:0, + (u32:2, (u32:2, _, _)) => u32:1, + (u32:1, ..) => u32:2, + (u32:3, (x, u32:1, u32:1)) => x, + _ => u32:4, + } +} + +#[test] +fn test_all_uniform_types() { + assert_eq(match_bool(u32:1, u32:2), u32:1); + assert_eq(match_bool_param(u32:1, u32:2), u32:1); + assert_eq(match_bool_param(u32:1, u32:2), u32:2); + assert_eq(matcher_types(), u32:0); + assert_eq(matcher_types(), u32:1); + assert_eq(matcher_types(), u32:2); + assert_eq(matcher_types(), u32:3); + assert_eq(matcher_types(), u32:800); + assert_eq(matcher_types(), u32:1000); + let tuple1 = (u32:1, (u32:3, u32:0, u32:0)); + let tuple2 = (u32:2, (u32:2, u32:0, u32:0)); + let tuple3 = (u32:1, (u32:4, u32:0, u32:0)); + let tuple4 = (u32:3, (u32:5, u32:1, u32:1)); + let tuple5 = (u32:7, (u32:5, u32:1, u32:1)); + assert_eq(matcher_tuple(), u32:0); + assert_eq(matcher_tuple(), u32:1); + assert_eq(matcher_tuple(), u32:2); + assert_eq(matcher_tuple(), u32:5); + assert_eq(matcher_tuple(), u32:4); +} + +fn main() -> (u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32) { + let tuple1 = (u32:1, (u32:3, u32:0, u32:0)); + let tuple2 = (u32:2, (u32:2, u32:0, u32:0)); + let tuple3 = (u32:1, (u32:4, u32:0, u32:0)); + let tuple4 = (u32:3, (u32:5, u32:1, u32:1)); + let tuple5 = (u32:7, (u32:5, u32:1, u32:1)); + ( + match_bool(u32:1, u32:2), + match_bool_param(u32:1, u32:2), + match_bool_param(u32:1, u32:2), + matcher_types(), + matcher_types(), + matcher_types(), + matcher_types(), + matcher_types(), + matcher_types(), + matcher_tuple(), + matcher_tuple(), + matcher_tuple(), + matcher_tuple(), + matcher_tuple(), + ) +} + +// step 2: do not typecheck the unused branch +// fn foo() -> u32 { +// const B1 = u32:1; +// const B2 = u16:2; +// let result = const match A { +// true => B1, +// _ => B2 +// }; +// result as u32 +// }