diff --git a/Cargo.lock b/Cargo.lock index 2235689..856c810 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -918,7 +918,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0b1fab2ae45819af2d0731d60f2afe17227ebb1a1538a236da84c93e9a60162" dependencies = [ "dispatch2", - "nix 0.31.1", + "nix 0.31.2", "windows-sys 0.61.2", ] @@ -1601,9 +1601,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "js-sys" -version = "0.3.90" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14dc6f6450b3f6d4ed5b16327f38fed626d375a886159ca555bd7822c0c3a5a6" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -1885,9 +1885,9 @@ checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" [[package]] name = "libc" -version = "0.2.180" +version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" [[package]] name = "libm" @@ -1897,11 +1897,10 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.12" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" dependencies = [ - "bitflags 2.11.0", "libc", ] @@ -1927,9 +1926,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "lock_api" @@ -2101,9 +2100,9 @@ dependencies = [ [[package]] name = "nix" -version = "0.31.1" +version = "0.31.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" dependencies = [ "bitflags 2.11.0", "cfg-if", @@ -2356,9 +2355,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -2368,9 +2367,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "piper" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c8c490f422ef9a4efd2cb5b42b76c8613d7e7dfc1caf667b8a3350a5acc066" +checksum = "c835479a4443ded371d6c535cbfd8d31ad92c5d23ae9770a61bc155e4992a3c1" dependencies = [ "atomic-waker", "fastrand", @@ -2781,9 +2780,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ "bitflags 2.11.0", "errno", @@ -3448,9 +3447,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.113" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60722a937f594b7fde9adb894d7c092fc1bb6612897c46368d18e7a20208eff2" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -3463,9 +3462,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.63" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a89f4650b770e4521aa6573724e2aed4704372151bd0de9d16a3bbabb87441a" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ "cfg-if", "futures-util", @@ -3477,9 +3476,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.113" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac8c6395094b6b91c4af293f4c79371c163f9a6f56184d2c9a85f5a95f3950" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3487,9 +3486,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.113" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3fabce6159dc20728033842636887e4877688ae94382766e00b180abac9d60" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ "bumpalo", "proc-macro2", @@ -3500,9 +3499,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.113" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de0e091bdb824da87dc01d967388880d017a0a9bc4f3bdc0d86ee9f9336e3bb5" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] @@ -3543,9 +3542,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.90" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705eceb4ce901230f8625bd1d665128056ccbe4b7408faa625eec1ba80f59a97" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/docs/DECL.md b/docs/DECL.md new file mode 100644 index 0000000..62f1a7d --- /dev/null +++ b/docs/DECL.md @@ -0,0 +1,385 @@ +``` +Title: Covenant Declarations +Status: Draft +Created: 2026-02-23 +``` + +# Covenant Declarations + +## Summary + +This document describes a minimal declaration API for covenant patterns, where users declare policy functions and the compiler generates covenant entrypoints/wrappers. + +Context: today these patterns are written manually with `OpAuth*`/`OpCov*` plus `readInputState`/`validateOutputState`. The goal here is to standardize the pattern and remove user boilerplate. + +Scope: syntax + semantics only. This is not claiming implementation is finalized. + +1. Dev writes only a transition/verification policy function and annotates it with a covenant macro. +2. Entrypoint(s) are inferred by the compiler from that function’s shape. +3. State is treated as one implicit `State` struct synthesized from all contract fields: + * `1:1` uses `State prev_state` / `State new_state` + * `1:N` uses `State prev_state` / `State[] new_states` + * `N:M` uses `State[] prev_states` / `State[] new_states` +4. `1:N` auth always binds to `this.activeInputIndex`; `N:M` cov id is always `OpInputCovenantId(this.activeInputIndex)`. + +## Macro surface + +Only policy functions are annotated. + +Canonical form: + +```js +#[covenant(binding = auth|cov, from = X, to = Y, mode = verification|transition, groups = multiple|single, termination = disallowed|allowed)] +``` + +Minimal common form (defaults inferred): + +```js +#[covenant(from = X, to = Y)] +``` + +Sugar (aliases over `from/to`): + +```js +#[covenant.singleton] // == #[covenant(from = 1, to = 1)] +#[covenant.fanout(to = Y)] // == #[covenant(from = 1, to = Y)] +``` + +Rules: + +1. `binding = auth` means auth-context lowering (`OpAuth*`). +2. `binding = cov` means shared covenant-context lowering (`OpCov*`). +3. `groups` applies to both bindings. +4. Defaults: `auth -> groups = multiple`, `cov -> groups = single`. +5. If `binding` is omitted: `from == 1 -> auth`, otherwise `cov`. +6. If `mode` is omitted: no returns -> `verification`, has returns -> `transition`. +7. `binding = auth` with `from > 1` is compile error. +8. `binding = cov` with `groups = multiple` is compile error in v1. +9. `termination` is valid only for singleton transition (`from = 1, to = 1, mode = transition`); there it defaults to `disallowed`, and using it elsewhere is a compile error. + +### 1:N verification + +```js +#[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] +function split(State prev_state, State[] new_states, sig[] approvals) { + // require(...) rules +} +``` + +```js +#[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] +function split_single_group(State prev_state, State[] new_states, sig[] approvals) { + // require(...) rules +} +``` + +### N:M verification + +```js +contract C(int max_ins, int max_outs) { + int amount; + byte[32] owner; + int round; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function transition_ok( + State[] prev_states, + State[] new_states, + sig leader_sig + ) { + // require(...) rules + } +} +``` + +### N:M transition + +```js +#[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] +function transition(State[] prev_states, int fee) : (State[] new_states) { + // compute and return new_states +} +``` + +### 1:1 transition + +```js +#[covenant(binding = auth, from = 1, to = 1, mode = transition)] +function roll(State prev_state, byte[32] block_hash) : (State new_state) { + // compute and return next state +} +``` + +## Semantics + +### Verification mode + +Verification mode is the default convenience mode. + +1. Generated entrypoint args are `new_states` plus optional extra call args. +2. Wrapper reads prior state from tx context (`prev_state` or `prev_states`) and calls the policy verification with `(prev_state(s), new_states, call_args...)`. +3. Wrapper validates each output with `validateOutputState(...)` against `new_states`. + +Current compiler shape (`mode = verification`, both bindings): + +1. Policy params must begin with prior-state parameters: + `binding = auth` -> `State prev_state` + `binding = cov` -> `State[] prev_states` +2. Then comes `State[] new_states`. +3. Remaining params are optional extra call args. +4. Generated entrypoint exposes only `new_states` + extra args (not prior-state params). +5. Wrapper reconstructs/injects prior state from tx context: + `auth` from current input state, `cov` from covenant input set via `readInputState(...)`. + +### Transition mode + +Transition mode allows extra call args (`fee` above, etc.) and the policy computes `new_states`. + +Security note (both modes): extra call args (beyond state values validated on outputs) are not directly committed by tx structure. Compiler/runtime must enforce a commitment story and determinism for them. + +Current compiler shape (`mode = transition`, both bindings): + +1. Policy params must begin with prior-state parameters: + `binding = auth` -> `State prev_state` + `binding = cov` -> `State[] prev_states` +2. Remaining params are optional extra call args. +3. Compiler enforces this prefix exactly; invalid prior-state parameter types are compile errors. +4. Wrapper sources prior state from tx context according to binding. +5. Current ABI behavior: + `auth` entrypoint exposes only extra call args. + `cov` leader entrypoint exposes `new_states` or extra call args according to mode, while wrapper also enforces covenant structure checks. + +Cardinality in transition mode: + +1. Single-state return shape -> exact one continuation (`out_count == 1`) with direct `validateOutputState(...)` (no loop). +2. `State[]` return shape -> exact cardinality by returned length (`out_count == returned_len`) and per-output validation in a loop. +3. For singleton (`from=1,to=1`), `State[]` returns are rejected by default. +4. Singleton `State[]` returns are allowed only with `termination = allowed`; this enables explicit zero-or-one continuation. + +### Singleton termination opt-in + +Default singleton transition is strict continuation: + +```js +#[covenant.singleton(mode = transition)] +function bump(State prev_state, int delta) : (State) { + return({ value: prev_state.value + delta }); +} +``` + +Termination-enabled singleton transition: + +```js +#[covenant.singleton(mode = transition, termination = allowed)] +function bump_or_terminate(State prev_state, State[] next_states) : (State[]) { + // [] => terminate + // [x] => continue with one successor + return(next_states); +} +``` + +### `groups` + +`binding = auth, groups = multiple` (default): no global uniqueness check across the tx. + +`binding = auth, groups = single`: enforce that current covenant id has a single continuation auth group in this tx: + +```js +byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); +require(OpCovOutCount(cov_id) == OpAuthOutputCount(this.activeInputIndex)); +``` + +No explicit `cov_id != false` check is needed; `OpCovOutCount(cov_id)` fails if `cov_id` is not valid covenant-id data. + +`binding = cov`: `groups = single` only (v1). `groups = multiple` is rejected. + +## Inferred entrypoints + +Given policy function `f`: + +1. `1:N` generates one entrypoint: + + * `__f` +2. `N:M` generates two entrypoints: + + * `__leader_f` + * `__delegate_f` + +`__delegate_f` does not call policy. It enforces delegation-path invariants only. + +## Complex example + +### Source (user writes this only) + +```js +pragma silverscript ^0.1.0; + +contract VaultNM( + int max_ins, + int max_outs, + int init_amount, + byte[32] init_owner, + int init_round +) { + int amount = init_amount; + byte[32] owner = init_owner; + int round = init_round; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function conserve_and_bump(State[] prev_states, State[] new_states, sig leader_sig) { + require(new_states.length > 0); + + int in_sum = 0; + for(i, 0, max_ins) { + if (i < prev_states.length) { + in_sum = in_sum + prev_states[i].amount; + } + } + + int out_sum = 0; + for(i, 0, max_outs) { + if (i < new_states.length) { + out_sum = out_sum + new_states[i].amount; + + // all outputs keep same owner as leader input + require(new_states[i].owner == prev_states[0].owner); + + // round must advance exactly by 1 + require(new_states[i].round == prev_states[0].round + 1); + } + } + + require(in_sum >= out_sum); + } +} +``` + +### Generated code (conceptual; policy body unchanged) + +```js +pragma silverscript ^0.1.0; + +contract VaultNM( + int max_ins, + int max_outs, + int init_amount, + byte[32] init_owner, + int init_round +) { + int amount = init_amount; + byte[32] owner = init_owner; + int round = init_round; + + // Compiler-lowered policy function (renamed to avoid collision with generated entrypoints) + // same body as source: + function __covenant_policy_conserve_and_bump(State[] prev_states, State[] new_states, sig leader_sig) { ... } + + // Generated for N:M leader path + entrypoint function __leader_conserve_and_bump(State[] new_states, sig leader_sig) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + int in_count = OpCovInputCount(cov_id); + int out_count = OpCovOutCount(cov_id); + require(out_count == new_states.length); + + // k=0 must execute leader path + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + State[] prev_states = []; + for(k, 0, max_ins) { + if (k < in_count) { + int in_idx = OpCovInputIdx(cov_id, k); + { + amount: int p_amount, + owner: byte[32] p_owner, + round: int p_round + } = readInputState(in_idx); + + prev_states.push({ + amount: p_amount, + owner: p_owner, + round: p_round + }); + } + } + + __covenant_policy_conserve_and_bump(prev_states, new_states, leader_sig); + + for(k, 0, max_outs) { + if (k < out_count) { + int out_idx = OpCovOutputIdx(cov_id, k); + validateOutputState(out_idx, { + amount: new_states[k].amount, + owner: new_states[k].owner, + round: new_states[k].round + }); + } + } + } + + // Generated for N:M delegate path + entrypoint function __delegate_conserve_and_bump() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + // delegate path must not be leader + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } +} +``` + +## Additional example: 1:1 transition with `OpChainblockSeqCommit` + +State is `seqcommit`; call arg is `block_hash`. + +### Source (user writes this only) + +```js +pragma silverscript ^0.1.0; + +contract SeqCommitMirror(byte[32] init_seqcommit) { + byte[32] seqcommit = init_seqcommit; + + #[covenant(binding = auth, from = 1, to = 1, mode = transition)] + function roll_seqcommit(State prev_state, byte[32] block_hash) : (State new_state) { + byte[32] new_seqcommit = OpChainblockSeqCommit(block_hash); + return { + seqcommit: new_seqcommit + }; + } +} +``` + +### Generated code (conceptual; policy body unchanged) + +```js +pragma silverscript ^0.1.0; + +contract SeqCommitMirror(byte[32] init_seqcommit) { + byte[32] seqcommit = init_seqcommit; + + // Compiler-lowered policy function (renamed to avoid entrypoint name collision) + // same body as source: + function __covenant_policy_roll_seqcommit(State prev_state, byte[32] block_hash) : (State new_state) { ... } + + // Generated 1:1 covenant entrypoint + entrypoint function __roll_seqcommit(byte[32] block_hash) { + State prev_state = { + seqcommit: seqcommit + }; + + (State new_state) = __covenant_policy_roll_seqcommit(prev_state, block_hash); + + require(OpAuthOutputCount(this.activeInputIndex) == 1); + int out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(out_idx, { + seqcommit: new_state.seqcommit + }); + } +} +``` + +## Implementation notes + +1. `State` is an implicit compiler type synthesized from contract fields. +2. Internally the compiler can lower `State`/`State[]` into any representation; this doc only fixes the user-facing API. +3. Existing `readInputState`/`validateOutputState` remain the codegen backbone. +4. v1 keeps one `N:M` transition group per tx. diff --git a/TUTORIAL.md b/docs/TUTORIAL.md similarity index 100% rename from TUTORIAL.md rename to docs/TUTORIAL.md diff --git a/silverscript-lang/src/ast.rs b/silverscript-lang/src/ast.rs index 504f7d6..da65831 100644 --- a/silverscript-lang/src/ast.rs +++ b/silverscript-lang/src/ast.rs @@ -8,13 +8,15 @@ use crate::errors::CompilerError; use crate::parser::{Rule, parse_source_file, parse_type_name as parse_type_name_rule}; pub use crate::span::{Span, SpanUtils}; +pub mod visit; + #[derive(Debug, Clone)] struct Identifier<'i> { name: String, span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ContractAst<'i> { pub name: String, pub params: Vec>, @@ -30,7 +32,7 @@ pub struct ContractAst<'i> { pub name_span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct StructAst<'i> { pub name: String, pub fields: Vec>, @@ -40,7 +42,7 @@ pub struct StructAst<'i> { pub name_span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct StructFieldAst<'i> { pub type_ref: TypeRef, pub name: String, @@ -65,7 +67,7 @@ pub fn format_contract_ast(contract: &ContractAst<'_>) -> String { formatter.finish() } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ContractFieldAst<'i> { pub type_ref: TypeRef, pub name: String, @@ -78,9 +80,11 @@ pub struct ContractFieldAst<'i> { pub name_span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct FunctionAst<'i> { pub name: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub attributes: Vec>, pub params: Vec>, pub entrypoint: bool, #[serde(default)] @@ -96,7 +100,28 @@ pub struct FunctionAst<'i> { pub body_span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct FunctionAttributeAst<'i> { + pub path: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub args: Vec>, + #[serde(skip_deserializing)] + pub span: Span<'i>, + #[serde(skip_deserializing)] + pub path_spans: Vec>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct FunctionAttributeArgAst<'i> { + pub name: String, + pub expr: Expr<'i>, + #[serde(skip_deserializing)] + pub span: Span<'i>, + #[serde(skip_deserializing)] + pub name_span: Span<'i>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ParamAst<'i> { pub type_ref: TypeRef, pub name: String, @@ -108,7 +133,7 @@ pub struct ParamAst<'i> { pub name_span: Span<'i>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct StateBindingAst<'i> { pub field_name: String, pub type_ref: TypeRef, @@ -224,7 +249,7 @@ impl TypeRef { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "kind", content = "data", rename_all = "snake_case")] pub enum Statement<'i> { VariableDefinition { @@ -388,14 +413,14 @@ impl<'i> Statement<'i> { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "kind", content = "data", rename_all = "snake_case")] pub enum ConsoleArg<'i> { Identifier(String, #[serde(skip_deserializing)] Span<'i>), Literal(Expr<'i>), } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TimeVar { ThisAge, @@ -651,7 +676,7 @@ pub enum IntrospectionKind { OutputScriptPubKey, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ConstantAst<'i> { pub type_ref: TypeRef, pub name: String, @@ -1214,6 +1239,9 @@ fn parse_struct_definition<'i>(pair: Pair<'i, Rule>) -> Result, Co let mut inner = pair.into_inner(); let name_pair = inner.next().ok_or_else(|| CompilerError::Unsupported("missing struct name".to_string()))?; let Identifier { name, span: name_span } = parse_identifier(name_pair)?; + if name == "State" { + return Err(CompilerError::Unsupported("'State' is a reserved struct name".to_string()).with_span(&span)); + } let mut fields = Vec::new(); for field_pair in inner { if field_pair.as_rule() == Rule::struct_field_definition { @@ -1237,6 +1265,15 @@ fn parse_struct_field_definition<'i>(pair: Pair<'i, Rule>) -> Result(pair: Pair<'i, Rule>) -> Result, CompilerError> { let span = Span::from(pair.as_span()); let mut inner = pair.into_inner(); + let mut attributes = Vec::new(); + + while let Some(next) = inner.peek() { + if next.as_rule() != Rule::function_attribute { + break; + } + let attr_pair = inner.next().expect("checked"); + attributes.push(parse_function_attribute(attr_pair)?); + } let first = inner.next().ok_or_else(|| CompilerError::Unsupported("missing function name".to_string()))?; let (entrypoint, name_pair) = if first.as_rule() == Rule::entrypoint { @@ -1275,7 +1312,71 @@ fn parse_function_definition<'i>(pair: Pair<'i, Rule>) -> Result } let body_span = body_span.unwrap_or(span); - Ok(FunctionAst { name, entrypoint, params, return_types, return_type_spans, body, span, name_span, body_span }) + Ok(FunctionAst { name, attributes, entrypoint, params, return_types, return_type_spans, body, span, name_span, body_span }) +} + +fn parse_function_attribute<'i>(pair: Pair<'i, Rule>) -> Result, CompilerError> { + let span = Span::from(pair.as_span()); + let mut inner = pair.into_inner(); + + let path_pair = inner.next().ok_or_else(|| CompilerError::Unsupported("missing attribute path".to_string()))?; + let (path, path_spans) = parse_attribute_path(path_pair)?; + + let mut args = Vec::new(); + if let Some(args_pair) = inner.next() { + args = parse_attribute_args(args_pair)?; + } + + Ok(FunctionAttributeAst { path, args, span, path_spans }) +} + +fn parse_attribute_path<'i>(pair: Pair<'i, Rule>) -> Result<(Vec, Vec>), CompilerError> { + if pair.as_rule() != Rule::attribute_path { + return Err(CompilerError::Unsupported("expected attribute path".to_string())); + } + let mut path = Vec::new(); + let mut spans = Vec::new(); + for inner in pair.into_inner() { + if inner.as_rule() != Rule::Identifier { + continue; + } + path.push(inner.as_str().to_string()); + spans.push(Span::from(inner.as_span())); + } + if path.is_empty() { + return Err(CompilerError::Unsupported("attribute path must not be empty".to_string())); + } + Ok((path, spans)) +} + +fn parse_attribute_args<'i>(pair: Pair<'i, Rule>) -> Result>, CompilerError> { + if pair.as_rule() != Rule::attribute_args { + return Err(CompilerError::Unsupported("expected attribute arguments".to_string())); + } + let mut out = Vec::new(); + for inner in pair.into_inner() { + if inner.as_rule() != Rule::attribute_arg { + continue; + } + out.push(parse_attribute_arg(inner)?); + } + Ok(out) +} + +fn parse_attribute_arg<'i>(pair: Pair<'i, Rule>) -> Result, CompilerError> { + let span = Span::from(pair.as_span()); + if pair.as_rule() != Rule::attribute_arg { + return Err(CompilerError::Unsupported("expected attribute argument".to_string())); + } + let mut inner = pair.into_inner(); + let name_pair = inner.next().ok_or_else(|| CompilerError::Unsupported("missing attribute argument name".to_string()))?; + let expr_pair = inner.next().ok_or_else(|| CompilerError::Unsupported("missing attribute argument value".to_string()))?; + + let name = name_pair.as_str().to_string(); + let name_span = Span::from(name_pair.as_span()); + let expr = parse_expression(expr_pair)?; + + Ok(FunctionAttributeArgAst { name, expr, span, name_span }) } fn parse_constant_definition<'i>(pair: Pair<'i, Rule>) -> Result, CompilerError> { diff --git a/silverscript-lang/src/ast/visit.rs b/silverscript-lang/src/ast/visit.rs new file mode 100644 index 0000000..3bb939b --- /dev/null +++ b/silverscript-lang/src/ast/visit.rs @@ -0,0 +1,394 @@ +use super::{ + ConsoleArg, ConstantAst, ContractAst, ContractFieldAst, Expr, ExprKind, FunctionAst, FunctionAttributeArgAst, + FunctionAttributeAst, ParamAst, StateBindingAst, Statement, +}; +use crate::span::Span; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NameKind { + Contract, + ContractField, + Constant, + Function, + Parameter, + AttributePathSegment, + AttributeArg, + LocalBinding, + AssignmentTarget, + LoopBinding, + StateField, + StateBinding, + CallTarget, + IdentifierExpr, + ConsoleIdentifier, +} + +pub trait AstVisitorMut<'i> { + fn visit_name(&mut self, _name: &mut String, _kind: NameKind) {} + fn visit_span(&mut self, _span: &mut Span<'i>) {} + + fn visit_contract(&mut self, contract: &mut ContractAst<'i>) { + walk_contract_mut(self, contract); + } + + fn visit_contract_field(&mut self, field: &mut ContractFieldAst<'i>) { + walk_contract_field_mut(self, field); + } + + fn visit_constant(&mut self, constant: &mut ConstantAst<'i>) { + walk_constant_mut(self, constant); + } + + fn visit_function(&mut self, function: &mut FunctionAst<'i>) { + walk_function_mut(self, function); + } + + fn visit_function_attribute(&mut self, attribute: &mut FunctionAttributeAst<'i>) { + walk_function_attribute_mut(self, attribute); + } + + fn visit_function_attribute_arg(&mut self, arg: &mut FunctionAttributeArgAst<'i>) { + walk_function_attribute_arg_mut(self, arg); + } + + fn visit_param(&mut self, param: &mut ParamAst<'i>) { + walk_param_mut(self, param); + } + + fn visit_state_binding(&mut self, binding: &mut StateBindingAst<'i>) { + walk_state_binding_mut(self, binding); + } + + fn visit_statement(&mut self, statement: &mut Statement<'i>) { + walk_statement_mut(self, statement); + } + + fn visit_console_arg(&mut self, arg: &mut ConsoleArg<'i>) { + walk_console_arg_mut(self, arg); + } + + fn visit_expr(&mut self, expr: &mut Expr<'i>) { + walk_expr_mut(self, expr); + } +} + +pub fn visit_contract_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, contract: &mut ContractAst<'i>) { + visitor.visit_contract(contract); +} + +pub fn walk_contract_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, contract: &mut ContractAst<'i>) { + visitor.visit_name(&mut contract.name, NameKind::Contract); + visitor.visit_span(&mut contract.span); + visitor.visit_span(&mut contract.name_span); + for param in &mut contract.params { + visitor.visit_param(param); + } + for field in &mut contract.fields { + visitor.visit_contract_field(field); + } + for constant in &mut contract.constants { + visitor.visit_constant(constant); + } + for function in &mut contract.functions { + visitor.visit_function(function); + } +} + +pub fn walk_contract_field_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, field: &mut ContractFieldAst<'i>) { + visitor.visit_name(&mut field.name, NameKind::ContractField); + visitor.visit_span(&mut field.span); + visitor.visit_span(&mut field.type_span); + visitor.visit_span(&mut field.name_span); + visitor.visit_expr(&mut field.expr); +} + +pub fn walk_constant_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, constant: &mut ConstantAst<'i>) { + visitor.visit_name(&mut constant.name, NameKind::Constant); + visitor.visit_span(&mut constant.span); + visitor.visit_span(&mut constant.type_span); + visitor.visit_span(&mut constant.name_span); + visitor.visit_expr(&mut constant.expr); +} + +pub fn walk_function_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, function: &mut FunctionAst<'i>) { + visitor.visit_name(&mut function.name, NameKind::Function); + visitor.visit_span(&mut function.span); + visitor.visit_span(&mut function.name_span); + visitor.visit_span(&mut function.body_span); + for span in &mut function.return_type_spans { + visitor.visit_span(span); + } + for attribute in &mut function.attributes { + visitor.visit_function_attribute(attribute); + } + for param in &mut function.params { + visitor.visit_param(param); + } + for statement in &mut function.body { + visitor.visit_statement(statement); + } +} + +pub fn walk_function_attribute_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, attribute: &mut FunctionAttributeAst<'i>) { + visitor.visit_span(&mut attribute.span); + for span in &mut attribute.path_spans { + visitor.visit_span(span); + } + for segment in &mut attribute.path { + visitor.visit_name(segment, NameKind::AttributePathSegment); + } + for arg in &mut attribute.args { + visitor.visit_function_attribute_arg(arg); + } +} + +pub fn walk_function_attribute_arg_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, arg: &mut FunctionAttributeArgAst<'i>) { + visitor.visit_name(&mut arg.name, NameKind::AttributeArg); + visitor.visit_span(&mut arg.span); + visitor.visit_span(&mut arg.name_span); + visitor.visit_expr(&mut arg.expr); +} + +pub fn walk_param_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, param: &mut ParamAst<'i>) { + visitor.visit_name(&mut param.name, NameKind::Parameter); + visitor.visit_span(&mut param.span); + visitor.visit_span(&mut param.type_span); + visitor.visit_span(&mut param.name_span); +} + +pub fn walk_state_binding_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, binding: &mut StateBindingAst<'i>) { + visitor.visit_name(&mut binding.field_name, NameKind::StateField); + visitor.visit_name(&mut binding.name, NameKind::StateBinding); + visitor.visit_span(&mut binding.span); + visitor.visit_span(&mut binding.field_span); + visitor.visit_span(&mut binding.type_span); + visitor.visit_span(&mut binding.name_span); +} + +pub fn walk_statement_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, statement: &mut Statement<'i>) { + match statement { + Statement::VariableDefinition { name, expr, span, type_span, modifier_spans, name_span, .. } => { + visitor.visit_span(span); + visitor.visit_span(type_span); + for span in modifier_spans { + visitor.visit_span(span); + } + visitor.visit_span(name_span); + visitor.visit_name(name, NameKind::LocalBinding); + if let Some(expr) = expr { + visitor.visit_expr(expr); + } + } + Statement::TupleAssignment { + left_name, + right_name, + expr, + span, + left_type_span, + left_name_span, + right_type_span, + right_name_span, + .. + } => { + visitor.visit_span(span); + visitor.visit_span(left_type_span); + visitor.visit_span(left_name_span); + visitor.visit_span(right_type_span); + visitor.visit_span(right_name_span); + visitor.visit_name(left_name, NameKind::AssignmentTarget); + visitor.visit_name(right_name, NameKind::AssignmentTarget); + visitor.visit_expr(expr); + } + Statement::ArrayPush { name, expr, span, name_span } => { + visitor.visit_span(span); + visitor.visit_span(name_span); + visitor.visit_name(name, NameKind::AssignmentTarget); + visitor.visit_expr(expr); + } + Statement::FunctionCall { name, args, span, name_span } => { + visitor.visit_span(span); + visitor.visit_span(name_span); + visitor.visit_name(name, NameKind::CallTarget); + for arg in args { + visitor.visit_expr(arg); + } + } + Statement::FunctionCallAssign { bindings, name, args, span, name_span } => { + visitor.visit_span(span); + visitor.visit_span(name_span); + for binding in bindings { + visitor.visit_param(binding); + } + visitor.visit_name(name, NameKind::CallTarget); + for arg in args { + visitor.visit_expr(arg); + } + } + Statement::StateFunctionCallAssign { bindings, name, args, span, name_span } => { + visitor.visit_span(span); + visitor.visit_span(name_span); + for binding in bindings { + visitor.visit_state_binding(binding); + } + visitor.visit_name(name, NameKind::CallTarget); + for arg in args { + visitor.visit_expr(arg); + } + } + Statement::StructDestructure { bindings, expr, span } => { + visitor.visit_span(span); + for binding in bindings { + visitor.visit_state_binding(binding); + } + visitor.visit_expr(expr); + } + Statement::Assign { name, expr, span, name_span } => { + visitor.visit_span(span); + visitor.visit_span(name_span); + visitor.visit_name(name, NameKind::AssignmentTarget); + visitor.visit_expr(expr); + } + Statement::TimeOp { expr, span, tx_var_span, message_span, .. } => { + visitor.visit_span(span); + visitor.visit_span(tx_var_span); + if let Some(span) = message_span { + visitor.visit_span(span); + } + visitor.visit_expr(expr); + } + Statement::Require { expr, span, message_span, .. } => { + visitor.visit_span(span); + if let Some(span) = message_span { + visitor.visit_span(span); + } + visitor.visit_expr(expr); + } + Statement::Yield { expr, span } => { + visitor.visit_span(span); + visitor.visit_expr(expr); + } + Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => { + visitor.visit_span(span); + visitor.visit_span(then_span); + if let Some(span) = else_span { + visitor.visit_span(span); + } + visitor.visit_expr(condition); + for statement in then_branch { + visitor.visit_statement(statement); + } + if let Some(else_branch) = else_branch { + for statement in else_branch { + visitor.visit_statement(statement); + } + } + } + Statement::For { ident, start, end, body, span, ident_span, body_span } => { + visitor.visit_span(span); + visitor.visit_span(ident_span); + visitor.visit_span(body_span); + visitor.visit_name(ident, NameKind::LoopBinding); + visitor.visit_expr(start); + visitor.visit_expr(end); + for statement in body { + visitor.visit_statement(statement); + } + } + Statement::Return { exprs, span } => { + visitor.visit_span(span); + for expr in exprs { + visitor.visit_expr(expr); + } + } + Statement::Console { args, span } => { + visitor.visit_span(span); + for arg in args { + visitor.visit_console_arg(arg); + } + } + } +} + +pub fn walk_console_arg_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, arg: &mut ConsoleArg<'i>) { + match arg { + ConsoleArg::Identifier(name, span) => { + visitor.visit_name(name, NameKind::ConsoleIdentifier); + visitor.visit_span(span); + } + ConsoleArg::Literal(expr) => visitor.visit_expr(expr), + } +} + +pub fn walk_expr_mut<'i, V: AstVisitorMut<'i> + ?Sized>(visitor: &mut V, expr: &mut Expr<'i>) { + visitor.visit_span(&mut expr.span); + match &mut expr.kind { + ExprKind::Identifier(name) => visitor.visit_name(name, NameKind::IdentifierExpr), + ExprKind::Array(items) => { + for item in items { + visitor.visit_expr(item); + } + } + ExprKind::Call { name, args, name_span } | ExprKind::New { name, args, name_span } => { + visitor.visit_span(name_span); + visitor.visit_name(name, NameKind::CallTarget); + for arg in args { + visitor.visit_expr(arg); + } + } + ExprKind::Split { source, index, span, .. } => { + visitor.visit_span(span); + visitor.visit_expr(source); + visitor.visit_expr(index); + } + ExprKind::ArrayIndex { source, index } => { + visitor.visit_expr(source); + visitor.visit_expr(index); + } + ExprKind::Slice { source, start, end, span } => { + visitor.visit_span(span); + visitor.visit_expr(source); + visitor.visit_expr(start); + visitor.visit_expr(end); + } + ExprKind::Unary { expr, .. } => { + visitor.visit_expr(expr); + } + ExprKind::UnarySuffix { source, span, .. } => { + visitor.visit_span(span); + visitor.visit_expr(source); + } + ExprKind::Binary { left, right, .. } => { + visitor.visit_expr(left); + visitor.visit_expr(right); + } + ExprKind::IfElse { condition, then_expr, else_expr } => { + visitor.visit_expr(condition); + visitor.visit_expr(then_expr); + visitor.visit_expr(else_expr); + } + ExprKind::Introspection { index, field_span, .. } => { + visitor.visit_span(field_span); + visitor.visit_expr(index); + } + ExprKind::StateObject(fields) => { + for field in fields { + visitor.visit_name(&mut field.name, NameKind::StateField); + visitor.visit_span(&mut field.span); + visitor.visit_span(&mut field.name_span); + visitor.visit_expr(&mut field.expr); + } + } + ExprKind::FieldAccess { source, field, field_span } => { + visitor.visit_expr(source); + visitor.visit_name(field, NameKind::StateField); + visitor.visit_span(field_span); + } + ExprKind::Int(_) + | ExprKind::Bool(_) + | ExprKind::Byte(_) + | ExprKind::String(_) + | ExprKind::DateLiteral(_) + | ExprKind::Nullary(_) + | ExprKind::NumberWithUnit { .. } => {} + } +} diff --git a/silverscript-lang/src/compiler.rs b/silverscript-lang/src/compiler.rs index 561c438..7455c56 100644 --- a/silverscript-lang/src/compiler.rs +++ b/silverscript-lang/src/compiler.rs @@ -13,12 +13,36 @@ use crate::ast::{ use crate::debug_info::{DebugInfo, SourceSpan}; pub use crate::errors::{CompilerError, ErrorSpan}; use crate::span; +mod covenant_declarations; +use covenant_declarations::lower_covenant_declarations; mod debug_recording; use debug_recording::DebugRecorder; /// Prefix used for synthetic argument bindings during inline function expansion. pub const SYNTHETIC_ARG_PREFIX: &str = "__arg"; +const COVENANT_POLICY_PREFIX: &str = "__covenant_policy"; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct CovenantDeclCallOptions { + pub is_leader: bool, +} + +fn generated_covenant_policy_name(function_name: &str) -> String { + format!("{COVENANT_POLICY_PREFIX}_{function_name}") +} + +fn generated_covenant_entrypoint_name(function_name: &str) -> String { + format!("__{function_name}") +} + +fn generated_covenant_leader_entrypoint_name(function_name: &str) -> String { + format!("__leader_{function_name}") +} + +fn generated_covenant_delegate_entrypoint_name(function_name: &str) -> String { + format!("__delegate_{function_name}") +} #[derive(Debug, Clone, Copy, Default)] pub struct CompileOptions { @@ -66,7 +90,6 @@ struct StructSpec { } type StructRegistry = HashMap; -type FunctionRegistry<'a, 'i> = HashMap>; pub fn compile_contract<'i>( source: &'i str, @@ -150,6 +173,11 @@ fn struct_name_from_type_ref<'a>(type_ref: &'a TypeRef, structs: &'a StructRegis } } +fn struct_array_name_from_type_ref(type_ref: &TypeRef, structs: &StructRegistry) -> Option { + let element_type = type_ref.element_type()?; + struct_name_from_type_ref(&element_type, structs).map(ToOwned::to_owned) +} + fn ensure_known_or_builtin_type(type_ref: &TypeRef, structs: &StructRegistry, context: &str) -> Result<(), CompilerError> { if type_ref.array_dims.is_empty() { match &type_ref.base { @@ -261,6 +289,35 @@ fn lower_expr<'i>(expr: &Expr<'i>, scope: &LoweringScope, structs: &StructRegist let span = expr.span; match &expr.kind { ExprKind::FieldAccess { .. } => { + if let ExprKind::FieldAccess { source, field, .. } = &expr.kind { + if let ExprKind::ArrayIndex { source: array_source, index } = &source.as_ref().kind { + let (base, mut path, array_type) = resolve_struct_access(array_source, scope, structs)?; + let struct_name = struct_array_name_from_type_ref(&array_type, structs) + .ok_or_else(|| CompilerError::Unsupported("field access requires a struct value".to_string()))?; + let item = structs + .get(&struct_name) + .ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{struct_name}'")))?; + let field_type = item + .fields + .iter() + .find(|candidate| candidate.name == *field) + .map(|candidate| candidate.type_ref.clone()) + .ok_or_else(|| CompilerError::Unsupported(format!("struct '{}' has no field '{}'", struct_name, field)))?; + if struct_name_from_type_ref(&field_type, structs).is_some() + || struct_array_name_from_type_ref(&field_type, structs).is_some() + { + return Err(CompilerError::Unsupported("nested struct array field access is not supported".to_string())); + } + path.push(field.clone()); + return Ok(Expr::new( + ExprKind::ArrayIndex { + source: Box::new(Expr::identifier(flattened_struct_name(&base, &path))), + index: Box::new(lower_expr(index, scope, structs)?), + }, + span, + )); + } + } let (base, path, type_ref) = resolve_struct_access(expr, scope, structs)?; if struct_name_from_type_ref(&type_ref, structs).is_some() { return Err(CompilerError::Unsupported("struct value must be used in a struct-typed position".to_string())); @@ -338,10 +395,30 @@ fn lower_expr<'i>(expr: &Expr<'i>, scope: &LoweringScope, structs: &StructRegist ExprKind::Introspection { kind: *kind, index: Box::new(lower_expr(index, scope, structs)?), field_span: *field_span }, span, )), - ExprKind::UnarySuffix { source, kind, span: suffix_span } => Ok(Expr::new( - ExprKind::UnarySuffix { source: Box::new(lower_expr(source, scope, structs)?), kind: *kind, span: *suffix_span }, - span, - )), + ExprKind::UnarySuffix { source, kind, span: suffix_span } => { + if matches!(kind, UnarySuffixKind::Length) + && let ExprKind::Identifier(name) = &source.kind + && let Some(type_ref) = scope.vars.get(name) + && struct_array_name_from_type_ref(type_ref, structs).is_some() + { + let first_leaf = flatten_type_ref_leaves(type_ref, structs)? + .into_iter() + .next() + .ok_or_else(|| CompilerError::Unsupported("struct array must contain fields".to_string()))?; + return Ok(Expr::new( + ExprKind::UnarySuffix { + source: Box::new(Expr::identifier(flattened_struct_name(name, &first_leaf.0))), + kind: *kind, + span: *suffix_span, + }, + span, + )); + } + Ok(Expr::new( + ExprKind::UnarySuffix { source: Box::new(lower_expr(source, scope, structs)?), kind: *kind, span: *suffix_span }, + span, + )) + } _ => Ok(expr.clone()), } } @@ -473,6 +550,8 @@ fn lower_struct_value_expr<'i>( let item = structs .get(expected_struct_name) .ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{expected_struct_name}'")))?; + let scope_types = + scope.vars.iter().map(|(name, type_ref)| (name.clone(), type_name_from_ref(type_ref))).collect::>(); let mut provided = HashMap::new(); for entry in entries { if provided.insert(entry.name.clone(), &entry.expr).is_some() { @@ -494,7 +573,15 @@ fn lower_struct_value_expr<'i>( contract_constants, )?); } else { - lowered.push(lower_expr(field_expr, scope, structs)?); + let lowered_expr = lower_expr(field_expr, scope, structs)?; + if !expr_matches_return_type_ref(&lowered_expr, &field.type_ref, &scope_types, contract_constants) { + return Err(CompilerError::Unsupported(format!( + "struct field '{}' expects {}", + field.name, + field.type_ref.type_name() + ))); + } + lowered.push(lowered_expr); } } if let Some(extra) = provided.keys().next() { @@ -506,30 +593,6 @@ fn lower_struct_value_expr<'i>( } } -fn lower_function_call_args<'i>( - name: &str, - args: &[Expr<'i>], - scope: &LoweringScope, - structs: &StructRegistry, - functions: &FunctionRegistry<'_, 'i>, - contract_fields: &[ContractFieldAst<'i>], - contract_constants: &HashMap>, -) -> Result>, CompilerError> { - let function = functions.get(name).ok_or_else(|| CompilerError::Unsupported(format!("function '{}' not found", name)))?; - if function.params.len() != args.len() { - return Err(CompilerError::Unsupported(format!("function '{}' expects {} arguments", name, function.params.len()))); - } - let mut lowered = Vec::new(); - for (param, arg) in function.params.iter().zip(args.iter()) { - if struct_name_from_type_ref(¶m.type_ref, structs).is_some() { - lowered.extend(lower_struct_value_expr(arg, ¶m.type_ref, scope, structs, contract_fields, contract_constants)?); - } else { - lowered.push(lower_expr(arg, scope, structs)?); - } - } - Ok(lowered) -} - fn infer_struct_expr_type<'i>( expr: &Expr<'i>, scope: &LoweringScope, @@ -669,368 +732,18 @@ fn lower_struct_destructure_statement<'i>( Ok(lowered) } -fn lower_statement<'i>( - stmt: &Statement<'i>, - scope: &mut LoweringScope, - structs: &StructRegistry, - functions: &FunctionRegistry<'_, 'i>, - contract_fields: &[ContractFieldAst<'i>], - contract_constants: &HashMap>, -) -> Result>, CompilerError> { - match stmt { - Statement::VariableDefinition { type_ref, modifiers, name, expr, span, type_span, modifier_spans, name_span } => { - ensure_known_or_builtin_type(type_ref, structs, "variable definition")?; - if struct_name_from_type_ref(type_ref, structs).is_some() { - let expr = - expr.clone().ok_or_else(|| CompilerError::Unsupported("variable definition requires initializer".to_string()))?; - let lowered_values = lower_struct_value_expr(&expr, type_ref, scope, structs, contract_fields, contract_constants)?; - let mut paths = Vec::new(); - flatten_struct_fields(type_ref, structs, &mut Vec::new(), &mut paths)?; - scope.vars.insert(name.clone(), type_ref.clone()); - Ok(paths - .into_iter() - .zip(lowered_values) - .map(|((path, field_type), field_expr)| Statement::VariableDefinition { - type_ref: field_type, - modifiers: modifiers.clone(), - name: flattened_struct_name(name, &path), - expr: Some(field_expr), - span: *span, - type_span: *type_span, - modifier_spans: modifier_spans.clone(), - name_span: *name_span, - }) - .collect()) - } else { - let lowered_expr = expr.as_ref().map(|expr| lower_expr(expr, scope, structs)).transpose()?; - scope.vars.insert(name.clone(), type_ref.clone()); - Ok(vec![Statement::VariableDefinition { - type_ref: type_ref.clone(), - modifiers: modifiers.clone(), - name: name.clone(), - expr: lowered_expr, - span: *span, - type_span: *type_span, - modifier_spans: modifier_spans.clone(), - name_span: *name_span, - }]) - } - } - Statement::FunctionCall { name, args, span, name_span } => { - let args = if name == "validateOutputState" { - args.iter() - .enumerate() - .map(|(index, arg)| { - if index == 1 { - match &arg.kind { - ExprKind::StateObject(fields) => Ok(Expr::new( - ExprKind::StateObject( - fields - .iter() - .map(|field| { - Ok(StateFieldExpr { - name: field.name.clone(), - expr: lower_expr(&field.expr, scope, structs)?, - span: field.span, - name_span: field.name_span, - }) - }) - .collect::, CompilerError>>()?, - ), - arg.span, - )), - _ => { - let state_type = TypeRef { base: TypeBase::Custom("State".to_string()), array_dims: Vec::new() }; - lower_struct_value_to_state_object_expr( - arg, - &state_type, - scope, - structs, - contract_fields, - contract_constants, - ) - } - } - } else { - lower_expr(arg, scope, structs) - } - }) - .collect::, _>>()? - } else { - lower_function_call_args(name, args, scope, structs, functions, contract_fields, contract_constants)? - }; - Ok(vec![Statement::FunctionCall { name: name.clone(), args, span: *span, name_span: *name_span }]) - } - Statement::FunctionCallAssign { bindings, name, args, span, name_span } => { - for binding in bindings { - ensure_known_or_builtin_type(&binding.type_ref, structs, "function call assignment")?; - if struct_name_from_type_ref(&binding.type_ref, structs).is_some() { - return Err(CompilerError::Unsupported( - "struct bindings are not supported in function call assignment".to_string(), - )); - } - } - let lowered_args = lower_function_call_args(name, args, scope, structs, functions, contract_fields, contract_constants)?; - for binding in bindings { - scope.vars.insert(binding.name.clone(), binding.type_ref.clone()); - } - Ok(vec![Statement::FunctionCallAssign { - bindings: bindings.clone(), - name: name.clone(), - args: lowered_args, - span: *span, - name_span: *name_span, - }]) - } - Statement::StateFunctionCallAssign { bindings, name, args, span, .. } => { - if name != "readInputState" { - return Err(CompilerError::Unsupported(format!("unsupported state function '{name}'"))); - } - if args.len() != 1 { - return Err(CompilerError::Unsupported("readInputState(input_idx) expects 1 argument".to_string())); - } - let lowered_expr = Expr::call(name, vec![lower_expr(&args[0], scope, structs)?]); - lower_struct_destructure_statement(bindings, &lowered_expr, *span, scope, structs, contract_fields, contract_constants) - } - Statement::StructDestructure { bindings, expr, span } => { - lower_struct_destructure_statement(bindings, expr, *span, scope, structs, contract_fields, contract_constants) - } - Statement::Assign { name, expr, span, name_span } => { - let target_type = scope.vars.get(name).cloned(); - if let Some(target_type) = target_type { - if struct_name_from_type_ref(&target_type, structs).is_some() { - let lowered_values = - lower_struct_value_expr(expr, &target_type, scope, structs, contract_fields, contract_constants)?; - let mut paths = Vec::new(); - flatten_struct_fields(&target_type, structs, &mut Vec::new(), &mut paths)?; - return Ok(paths - .into_iter() - .zip(lowered_values) - .map(|((path, _), field_expr)| Statement::Assign { - name: flattened_struct_name(name, &path), - expr: field_expr, - span: *span, - name_span: *name_span, - }) - .collect()); - } - } - Ok(vec![Statement::Assign { - name: name.clone(), - expr: lower_expr(expr, scope, structs)?, - span: *span, - name_span: *name_span, - }]) - } - Statement::ArrayPush { name, expr, span, name_span } => Ok(vec![Statement::ArrayPush { - name: name.clone(), - expr: lower_expr(expr, scope, structs)?, - span: *span, - name_span: *name_span, - }]), - Statement::TupleAssignment { - left_type_ref, - left_name, - right_type_ref, - right_name, - expr, - span, - left_type_span, - left_name_span, - right_type_span, - right_name_span, - } => { - ensure_known_or_builtin_type(left_type_ref, structs, "tuple assignment")?; - ensure_known_or_builtin_type(right_type_ref, structs, "tuple assignment")?; - if struct_name_from_type_ref(left_type_ref, structs).is_some() - || struct_name_from_type_ref(right_type_ref, structs).is_some() - { - return Err(CompilerError::Unsupported("tuple assignment does not support struct types".to_string())); - } - let lowered_expr = lower_expr(expr, scope, structs)?; - scope.vars.insert(left_name.clone(), left_type_ref.clone()); - scope.vars.insert(right_name.clone(), right_type_ref.clone()); - Ok(vec![Statement::TupleAssignment { - left_type_ref: left_type_ref.clone(), - left_name: left_name.clone(), - right_type_ref: right_type_ref.clone(), - right_name: right_name.clone(), - expr: lowered_expr, - span: *span, - left_type_span: *left_type_span, - left_name_span: *left_name_span, - right_type_span: *right_type_span, - right_name_span: *right_name_span, - }]) - } - Statement::Require { expr, message, span, message_span } => Ok(vec![Statement::Require { - expr: lower_expr(expr, scope, structs)?, - message: message.clone(), - span: *span, - message_span: *message_span, - }]), - Statement::TimeOp { tx_var, expr, message, span, tx_var_span, message_span } => Ok(vec![Statement::TimeOp { - tx_var: *tx_var, - expr: lower_expr(expr, scope, structs)?, - message: message.clone(), - span: *span, - tx_var_span: *tx_var_span, - message_span: *message_span, - }]), - Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => { - let mut then_scope = scope.clone(); - let lowered_then = lower_block(then_branch, &mut then_scope, structs, functions, contract_fields, contract_constants)?; - let lowered_else = if let Some(else_branch) = else_branch { - let mut else_scope = scope.clone(); - Some(lower_block(else_branch, &mut else_scope, structs, functions, contract_fields, contract_constants)?) - } else { - None - }; - Ok(vec![Statement::If { - condition: lower_expr(condition, scope, structs)?, - then_branch: lowered_then, - else_branch: lowered_else, - span: *span, - then_span: *then_span, - else_span: *else_span, - }]) - } - Statement::For { ident, start, end, body, span, ident_span, body_span } => { - let mut body_scope = scope.clone(); - body_scope.vars.insert(ident.clone(), TypeRef { base: TypeBase::Int, array_dims: Vec::new() }); - let lowered_body = lower_block(body, &mut body_scope, structs, functions, contract_fields, contract_constants)?; - Ok(vec![Statement::For { - ident: ident.clone(), - start: lower_expr(start, scope, structs)?, - end: lower_expr(end, scope, structs)?, - body: lowered_body, - span: *span, - ident_span: *ident_span, - body_span: *body_span, - }]) - } - Statement::Yield { expr, span } => Ok(vec![Statement::Yield { expr: lower_expr(expr, scope, structs)?, span: *span }]), - Statement::Return { exprs, span } => Ok(vec![Statement::Return { - exprs: exprs.iter().map(|expr| lower_expr(expr, scope, structs)).collect::, _>>()?, - span: *span, - }]), - Statement::Console { args, span } => Ok(vec![Statement::Console { - args: args - .iter() - .map(|arg| match arg { - crate::ast::ConsoleArg::Identifier(name, span) => Ok(crate::ast::ConsoleArg::Identifier(name.clone(), *span)), - crate::ast::ConsoleArg::Literal(expr) => Ok(crate::ast::ConsoleArg::Literal(lower_expr(expr, scope, structs)?)), - }) - .collect::, CompilerError>>()?, - span: *span, - }]), - } -} - -fn lower_block<'i>( - statements: &[Statement<'i>], - scope: &mut LoweringScope, - structs: &StructRegistry, - functions: &FunctionRegistry<'_, 'i>, - contract_fields: &[ContractFieldAst<'i>], - contract_constants: &HashMap>, -) -> Result>, CompilerError> { - let mut lowered = Vec::new(); - for stmt in statements { - lowered.extend(lower_statement(stmt, scope, structs, functions, contract_fields, contract_constants)?); - } - Ok(lowered) -} - -fn lower_params<'i>( - params: &[crate::ast::ParamAst<'i>], - scope: &mut LoweringScope, - structs: &StructRegistry, -) -> Result>, CompilerError> { - let mut lowered = Vec::new(); - for param in params { - ensure_known_or_builtin_type(¶m.type_ref, structs, "function parameter")?; - scope.vars.insert(param.name.clone(), param.type_ref.clone()); - if struct_name_from_type_ref(¶m.type_ref, structs).is_some() { - let mut leaves = Vec::new(); - flatten_struct_fields(¶m.type_ref, structs, &mut Vec::new(), &mut leaves)?; - for (path, field_type) in leaves { - lowered.push(crate::ast::ParamAst { - type_ref: field_type, - name: flattened_struct_name(¶m.name, &path), - span: param.span, - type_span: param.type_span, - name_span: param.name_span, - }); - } - } else { - lowered.push(param.clone()); - } - } - Ok(lowered) -} - -fn lower_contract<'i>( - contract: &ContractAst<'i>, - contract_constants: &HashMap>, -) -> Result, CompilerError> { - let structs = build_struct_registry(contract)?; - validate_struct_graph(&structs)?; - +fn validate_contract_struct_usage<'i>(contract: &ContractAst<'i>, structs: &StructRegistry) -> Result<(), CompilerError> { for param in &contract.params { - if struct_name_from_type_ref(¶m.type_ref, &structs).is_some() { - return Err(CompilerError::Unsupported("struct contract parameters are not supported".to_string())); - } - ensure_known_or_builtin_type(¶m.type_ref, &structs, "contract parameter")?; + ensure_known_or_builtin_type(¶m.type_ref, structs, "contract parameter")?; } for field in &contract.fields { - if struct_name_from_type_ref(&field.type_ref, &structs).is_some() { - return Err(CompilerError::Unsupported("struct contract fields are not supported".to_string())); - } - ensure_known_or_builtin_type(&field.type_ref, &structs, "contract field")?; + ensure_known_or_builtin_type(&field.type_ref, structs, "contract field")?; } for constant in &contract.constants { - if struct_name_from_type_ref(&constant.type_ref, &structs).is_some() { - return Err(CompilerError::Unsupported("struct constants are not supported".to_string())); - } - ensure_known_or_builtin_type(&constant.type_ref, &structs, "constant")?; - } - - let functions = contract.functions.iter().map(|function| (function.name.clone(), function)).collect::>(); - let mut lowered_functions = Vec::with_capacity(contract.functions.len()); - for function in &contract.functions { - let mut scope = LoweringScope::default(); - let lowered_params = lower_params(&function.params, &mut scope, &structs)?; - if function.return_types.iter().any(|type_ref| struct_name_from_type_ref(type_ref, &structs).is_some()) { - return Err(CompilerError::Unsupported("struct return types are not supported".to_string())); - } - for type_ref in &function.return_types { - ensure_known_or_builtin_type(type_ref, &structs, "function return type")?; - } - let lowered_body = lower_block(&function.body, &mut scope, &structs, &functions, &contract.fields, contract_constants)?; - lowered_functions.push(FunctionAst { - name: function.name.clone(), - params: lowered_params, - entrypoint: function.entrypoint, - return_types: function.return_types.clone(), - body: lowered_body, - return_type_spans: function.return_type_spans.clone(), - span: function.span, - name_span: function.name_span, - body_span: function.body_span, - }); - } - - Ok(ContractAst { - name: contract.name.clone(), - params: contract.params.clone(), - structs: Vec::new(), - fields: contract.fields.clone(), - constants: contract.constants.clone(), - functions: lowered_functions, - span: contract.span, - name_span: contract.name_span, - }) + ensure_known_or_builtin_type(&constant.type_ref, structs, "constant")?; + } + + Ok(()) } fn compile_contract_impl<'i>( @@ -1042,56 +755,71 @@ fn compile_contract_impl<'i>( if contract.functions.is_empty() { return Err(CompilerError::Unsupported("contract has no functions".to_string())); } - - let entrypoint_functions: Vec<&FunctionAst<'i>> = contract.functions.iter().filter(|func| func.entrypoint).collect(); - if entrypoint_functions.is_empty() { - return Err(CompilerError::Unsupported("contract has no entrypoint functions".to_string())); - } - if contract.params.len() != constructor_args.len() { return Err(CompilerError::Unsupported("constructor argument count mismatch".to_string())); } + let structs = build_struct_registry(contract)?; + validate_struct_graph(&structs)?; + for (param, value) in contract.params.iter().zip(constructor_args.iter()) { let param_type_name = type_name_from_ref(¶m.type_ref); - if !expr_matches_type(value, ¶m_type_name) { + if !expr_matches_declared_type_ref(value, ¶m.type_ref, &structs) { return Err(CompilerError::Unsupported(format!("constructor argument '{}' expects {}", param.name, param_type_name))); } } - let without_selector = entrypoint_functions.len() == 1; - let mut constants: HashMap> = contract.constants.iter().map(|constant| (constant.name.clone(), constant.expr.clone())).collect(); for (param, value) in contract.params.iter().zip(constructor_args.iter()) { constants.insert(param.name.clone(), value.clone()); } - let lowered_contract = lower_contract(contract, &constants)?; + // Preserve struct-typed covenant policy signatures in the user-facing AST and ABI. + // This must be `true` because callers should still see `State` / `State[]` rather than flattened field lists. + let abi_contract = lower_covenant_declarations(contract, &constants, true)?; + // Desugar covenant policy signatures for code generation before struct lowering. + // This must be `false` because the backend and wrapper generation operate on flattened per-field parameters and returns. + let codegen_contract = lower_covenant_declarations(contract, &constants, false)?; + let structs = build_struct_registry(&codegen_contract)?; + validate_struct_graph(&structs)?; + validate_contract_struct_usage(&codegen_contract, &structs)?; + + let entrypoint_functions: Vec<&FunctionAst<'i>> = codegen_contract.functions.iter().filter(|func| func.entrypoint).collect(); + if entrypoint_functions.is_empty() { + return Err(CompilerError::Unsupported("contract has no entrypoint functions".to_string())); + } + + let without_selector = entrypoint_functions.len() == 1; - let functions_map = lowered_contract.functions.iter().cloned().map(|func| (func.name.clone(), func)).collect::>(); + let functions_map = codegen_contract.functions.iter().cloned().map(|func| (func.name.clone(), func)).collect::>(); let function_order = - lowered_contract.functions.iter().enumerate().map(|(index, func)| (func.name.clone(), index)).collect::>(); - let function_abi_entries = build_function_abi_entries(contract); - let uses_script_size = contract_uses_script_size(&lowered_contract); + codegen_contract.functions.iter().enumerate().map(|(index, func)| (func.name.clone(), index)).collect::>(); + let function_abi_entries = build_function_abi_entries(&abi_contract); + let uses_script_size = contract_uses_script_size(&codegen_contract, &structs, &constants); let mut script_size = if uses_script_size { Some(100i64) } else { None }; for _ in 0..32 { let (_contract_fields, field_prolog_script) = - compile_contract_fields(&lowered_contract.fields, &constants, options, script_size)?; + compile_contract_fields(&codegen_contract.fields, &constants, options, script_size, &structs)?; let mut compiled_entrypoints = Vec::new(); let mut recorder = DebugRecorder::new(options.record_debug_infos); recorder.record_constructor_constants(&contract.params, constructor_args); - for (index, func) in lowered_contract.functions.iter().enumerate() { + for (index, func) in codegen_contract.functions.iter().enumerate() { if func.entrypoint { + let mut contract_field_prefix_len = field_prolog_script.len(); + if !without_selector && function_branch_index(&codegen_contract, &func.name)? == 0 { + contract_field_prefix_len += selector_dispatch_branch0_prefix_len()?; + } compiled_entrypoints.push(compile_entrypoint_function( func, index, - &lowered_contract.fields, - field_prolog_script.len(), + &codegen_contract.fields, + contract_field_prefix_len, &constants, options, + &structs, &functions_map, &function_order, script_size, @@ -1100,13 +828,17 @@ fn compile_contract_impl<'i>( } } - let entrypoint_script = if without_selector { - let (name, script) = compiled_entrypoints + let script = if without_selector { + let (name, entrypoint_script) = compiled_entrypoints .first() .ok_or_else(|| CompilerError::Unsupported("contract has no entrypoint functions".to_string()))?; recorder.set_entrypoint_start(name, field_prolog_script.len()); - script.clone() + let mut script = field_prolog_script.clone(); + script.extend(entrypoint_script); + script } else { + // Dispatch on selector first; each selected branch then executes + // the shared contract-field prolog before branch body. let mut builder = ScriptBuilder::new(); let total = compiled_entrypoints.len(); for (index, (name, script)) in compiled_entrypoints.iter().enumerate() { @@ -1115,7 +847,8 @@ fn compile_contract_impl<'i>( builder.add_op(OpNumEqual)?; builder.add_op(OpIf)?; builder.add_op(OpDrop)?; - let start = field_prolog_script.len() + builder.script().len(); + builder.add_ops(&field_prolog_script)?; + let start = builder.script().len(); recorder.set_entrypoint_start(name, start); builder.add_ops(script)?; builder.add_op(OpElse)?; @@ -1133,15 +866,12 @@ fn compile_contract_impl<'i>( builder.drain() }; - let mut script = field_prolog_script.clone(); - script.extend(entrypoint_script); let debug_info = recorder.into_debug_info(source.unwrap_or_default().to_string()); - if !uses_script_size { return Ok(CompiledContract { - contract_name: contract.name.clone(), + contract_name: abi_contract.name.clone(), script, - ast: contract.clone(), + ast: abi_contract.clone(), abi: function_abi_entries, without_selector, debug_info, @@ -1151,9 +881,9 @@ fn compile_contract_impl<'i>( let actual_size = script.len() as i64; if Some(actual_size) == script_size { return Ok(CompiledContract { - contract_name: contract.name.clone(), + contract_name: abi_contract.name.clone(), script, - ast: contract.clone(), + ast: abi_contract.clone(), abi: function_abi_entries, without_selector, debug_info, @@ -1165,7 +895,11 @@ fn compile_contract_impl<'i>( Err(CompilerError::Unsupported("script size did not stabilize".to_string())) } -fn contract_uses_script_size<'i>(contract: &ContractAst<'i>) -> bool { +fn contract_uses_script_size<'i>( + contract: &ContractAst<'i>, + _structs: &StructRegistry, + _contract_constants: &HashMap>, +) -> bool { if contract.constants.iter().any(|constant| expr_uses_script_size(&constant.expr)) { return true; } @@ -1175,11 +909,97 @@ fn contract_uses_script_size<'i>(contract: &ContractAst<'i>) -> bool { contract.functions.iter().any(|func| func.body.iter().any(statement_uses_script_size)) } +fn expr_matches_declared_type_ref<'i>(expr: &Expr<'i>, type_ref: &TypeRef, structs: &StructRegistry) -> bool { + if let Some(struct_name) = struct_name_from_type_ref(type_ref, structs) { + let Some(item) = structs.get(struct_name) else { + return false; + }; + let ExprKind::StateObject(fields) = &expr.kind else { + return false; + }; + if fields.len() != item.fields.len() { + return false; + } + for field in &item.fields { + let Some(value) = fields.iter().find(|entry| entry.name == field.name).map(|entry| &entry.expr) else { + return false; + }; + if !expr_matches_declared_type_ref(value, &field.type_ref, structs) { + return false; + } + } + return true; + } + + if let Some(element_type) = array_element_type_ref(type_ref) { + if struct_name_from_type_ref(&element_type, structs).is_some() { + return matches!(&expr.kind, ExprKind::Array(values) if values.iter().all(|value| expr_matches_declared_type_ref(value, &element_type, structs))); + } + } + + expr_matches_type_ref(expr, type_ref) +} + +fn encode_struct_value<'i>(expr: &Expr<'i>, type_ref: &TypeRef, structs: &StructRegistry) -> Result, CompilerError> { + let struct_name = struct_name_from_type_ref(type_ref, structs) + .ok_or_else(|| CompilerError::Unsupported(format!("expected struct type '{}'", type_ref.type_name())))?; + let item = structs.get(struct_name).ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{struct_name}'")))?; + let ExprKind::StateObject(fields) = &expr.kind else { + return Err(CompilerError::Unsupported(format!("expression expects struct {}", type_ref.type_name()))); + }; + + let mut out = Vec::new(); + for field in &item.fields { + let value = fields + .iter() + .find(|entry| entry.name == field.name) + .map(|entry| &entry.expr) + .ok_or_else(|| CompilerError::Unsupported(format!("struct field '{}' must be initialized", field.name)))?; + if struct_name_from_type_ref(&field.type_ref, structs).is_some() { + out.extend(encode_struct_value(value, &field.type_ref, structs)?); + } else { + let field_type_name = type_name_from_ref(&field.type_ref); + if field.type_ref.array_dims.is_empty() && field.type_ref.base == TypeBase::Int { + let ExprKind::Int(number) = &value.kind else { + return Err(CompilerError::Unsupported(format!("struct field '{}' expects int", field.name))); + }; + let serialized = serialize_i64(*number, Some(8usize)) + .map_err(|err| CompilerError::Unsupported(format!("failed to serialize int literal {}: {err}", number)))?; + out.extend_from_slice(&data_prefix(serialized.len())); + out.extend(serialized); + } else if is_array_type(&field_type_name) + || matches!(value.kind, ExprKind::Array(_) | ExprKind::String(_) | ExprKind::Byte(_)) + { + let encoded = match &value.kind { + ExprKind::Array(values) => { + if is_byte_array(value) { + values.iter().filter_map(|v| if let ExprKind::Byte(byte) = &v.kind { Some(*byte) } else { None }).collect() + } else { + encode_array_literal(values, &field_type_name)? + } + } + ExprKind::String(string) => string.as_bytes().to_vec(), + ExprKind::Byte(byte) => vec![*byte], + _ => return Err(CompilerError::Unsupported(format!("struct field '{}' expects {}", field.name, field_type_name))), + }; + out.extend_from_slice(&data_prefix(encoded.len())); + out.extend(encoded); + } else { + let encoded = encode_fixed_size_value(value, &field_type_name)?; + out.extend_from_slice(&data_prefix(encoded.len())); + out.extend(encoded); + } + } + } + Ok(out) +} + fn compile_contract_fields<'i>( fields: &[ContractFieldAst<'i>], base_constants: &HashMap>, options: CompileOptions, script_size: Option, + structs: &StructRegistry, ) -> Result<(HashMap>, Vec), CompilerError> { let mut env = base_constants.clone(); let mut field_values = HashMap::new(); @@ -1199,13 +1019,16 @@ fn compile_contract_fields<'i>( let mut resolve_visiting = HashSet::new(); let resolved = resolve_expr(field.expr.clone(), &env, &mut resolve_visiting)?; - if !expr_matches_type_ref(&resolved, &field.type_ref) { + if !expr_matches_declared_type_ref(&resolved, &field.type_ref, structs) { return Err(CompilerError::Unsupported(format!("contract field '{}' expects {}", field.name, type_name))); } let mut compile_visiting = HashSet::new(); let mut stack_depth = 0i64; - if field.type_ref.array_dims.is_empty() && field.type_ref.base == TypeBase::Int { + if struct_name_from_type_ref(&field.type_ref, structs).is_some() { + let encoded = encode_struct_value(&resolved, &field.type_ref, structs)?; + builder.add_data(&encoded)?; + } else if field.type_ref.array_dims.is_empty() && field.type_ref.base == TypeBase::Int { let ExprKind::Int(value) = &resolved.kind else { return Err(CompilerError::Unsupported(format!("contract field '{}' expects compile-time int value", field.name))); }; @@ -1275,7 +1098,7 @@ fn expr_uses_script_size<'i>(expr: &Expr<'i>) -> bool { } ExprKind::Array(values) => values.iter().any(expr_uses_script_size), ExprKind::StateObject(fields) => fields.iter().any(|field| expr_uses_script_size(&field.expr)), - ExprKind::Call { args, .. } => args.iter().any(expr_uses_script_size), + ExprKind::Call { name, args, .. } => name == "readInputState" || args.iter().any(expr_uses_script_size), ExprKind::New { args, .. } => args.iter().any(expr_uses_script_size), ExprKind::Split { source, index, .. } => expr_uses_script_size(source) || expr_uses_script_size(index), ExprKind::Slice { source, start, end, .. } => { @@ -1391,6 +1214,131 @@ fn build_function_abi_entries<'i>(contract: &ContractAst<'i>) -> Vec Result, TypeRef)>, CompilerError> { + if let Some(struct_name) = struct_array_name_from_type_ref(type_ref, structs) { + let outer_dims = type_ref.array_dims.clone(); + let item = structs.get(&struct_name).ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{struct_name}'")))?; + let mut leaves = Vec::new(); + for field in &item.fields { + let mut field_type = field.type_ref.clone(); + field_type.array_dims.extend(outer_dims.iter().cloned()); + for (mut path, leaf_type) in flatten_type_ref_leaves(&field_type, structs)? { + path.insert(0, field.name.clone()); + leaves.push((path, leaf_type)); + } + } + return Ok(leaves); + } + + let mut leaves = Vec::new(); + flatten_struct_fields(type_ref, structs, &mut Vec::new(), &mut leaves)?; + Ok(leaves) +} + +fn lowering_scope_from_types(types: &HashMap) -> Result { + let mut scope = LoweringScope::default(); + for (name, type_name) in types { + scope.vars.insert(name.clone(), parse_type_ref(type_name)?); + } + Ok(scope) +} + +fn lower_runtime_expr<'i>( + expr: &Expr<'i>, + types: &HashMap, + structs: &StructRegistry, +) -> Result, CompilerError> { + let scope = lowering_scope_from_types(types)?; + lower_expr(expr, &scope, structs) +} + +fn lower_runtime_struct_expr<'i>( + expr: &Expr<'i>, + expected_type: &TypeRef, + types: &HashMap, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], + contract_constants: &HashMap>, +) -> Result>, CompilerError> { + let scope = lowering_scope_from_types(types)?; + lower_struct_value_expr(expr, expected_type, &scope, structs, contract_fields, contract_constants) +} + +fn flatten_runtime_value_expr<'i>( + expr: &Expr<'i>, + types: &HashMap, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], + contract_constants: &HashMap>, +) -> Result>, CompilerError> { + let scope = lowering_scope_from_types(types)?; + if let Ok(type_ref) = infer_struct_expr_type(expr, &scope, structs, contract_fields) { + if struct_name_from_type_ref(&type_ref, structs).is_some() { + return lower_struct_value_expr(expr, &type_ref, &scope, structs, contract_fields, contract_constants); + } + } + Ok(vec![lower_expr(expr, &scope, structs)?]) +} + +fn flatten_runtime_return_exprs<'i>( + exprs: &[Expr<'i>], + return_types: &[TypeRef], + types: &HashMap, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], + contract_constants: &HashMap>, +) -> Result>, CompilerError> { + let mut flattened = Vec::new(); + for (expr, return_type) in exprs.iter().zip(return_types.iter()) { + if struct_name_from_type_ref(return_type, structs).is_some() { + flattened.extend(lower_runtime_struct_expr(expr, return_type, types, structs, contract_fields, contract_constants)?); + } else { + flattened.push(lower_runtime_expr(expr, types, structs)?); + } + } + Ok(flattened) +} + +fn store_struct_binding<'i>( + name: &str, + type_ref: &TypeRef, + expr: &Expr<'i>, + env: &mut HashMap>, + types: &mut HashMap, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], + contract_constants: &HashMap>, + is_assignment: bool, +) -> Result<(), CompilerError> { + let lowered_values = lower_runtime_struct_expr(expr, type_ref, types, structs, contract_fields, contract_constants)?; + let leaf_bindings = flatten_type_ref_leaves(type_ref, structs)?; + let original_env = env.clone(); + let mut pending = Vec::with_capacity(leaf_bindings.len()); + + for ((path, field_type), lowered_expr) in leaf_bindings.into_iter().zip(lowered_values.into_iter()) { + let leaf_name = flattened_struct_name(name, &path); + let stored_expr = if is_assignment { + let updated = if let Some(previous) = original_env.get(&leaf_name) { + replace_identifier(&lowered_expr, &leaf_name, previous) + } else { + lowered_expr + }; + resolve_expr_for_runtime(updated, &original_env, types, &mut HashSet::new())? + } else { + lowered_expr + }; + pending.push((leaf_name, type_name_from_ref(&field_type), stored_expr)); + } + + types.insert(name.to_string(), type_name_from_ref(type_ref)); + for (leaf_name, field_type_name, stored_expr) in pending { + types.insert(leaf_name.clone(), field_type_name); + env.insert(leaf_name, stored_expr); + } + + Ok(()) +} + fn type_name_from_ref(type_ref: &TypeRef) -> String { type_ref.type_name() } @@ -1480,6 +1428,8 @@ fn validate_return_types<'i>( exprs: &[Expr<'i>], return_types: &[TypeRef], types: &HashMap, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], constants: &HashMap>, ) -> Result<(), CompilerError> { if return_types.is_empty() { @@ -1489,7 +1439,13 @@ fn validate_return_types<'i>( return Err(CompilerError::Unsupported("return values count must match function return types".to_string())); } for (expr, return_type) in exprs.iter().zip(return_types.iter()) { - if !expr_matches_return_type_ref(expr, return_type, types, constants) { + let matches = if struct_name_from_type_ref(return_type, structs).is_some() { + lower_runtime_struct_expr(expr, return_type, types, structs, contract_fields, constants).is_ok() + } else { + expr_matches_return_type_ref(expr, return_type, types, constants) + }; + + if !matches { let type_name = type_name_from_ref(return_type); return Err(CompilerError::Unsupported(format!("return value expects {type_name}"))); } @@ -1596,10 +1552,6 @@ fn infer_fixed_array_type_from_initializer_ref<'i>( } } -fn expr_matches_type<'i>(expr: &Expr<'i>, type_name: &str) -> bool { - parse_type_ref(type_name).is_ok_and(|type_ref| expr_matches_type_ref(expr, &type_ref)) -} - fn array_literal_matches_type_with_env<'i>( values: &[Expr<'i>], type_name: &str, @@ -1688,8 +1640,9 @@ impl<'i> CompiledContract<'i> { let mut builder = ScriptBuilder::new(); for (input, arg) in function.inputs.iter().zip(args) { let type_ref = parse_type_ref(&input.type_name)?; - push_typed_sigscript_arg(&mut builder, arg, &type_ref, &structs) - .map_err(|_| CompilerError::Unsupported(format!("function argument '{}' expects {}", input.name, input.type_name)))?; + push_typed_sigscript_arg(&mut builder, arg, &type_ref, &structs).map_err(|err| { + CompilerError::Unsupported(format!("function argument '{}' expects {} ({err})", input.name, input.type_name)) + })?; } if !self.without_selector { let selector = function_branch_index(&self.ast, function_name)?; @@ -1697,6 +1650,30 @@ impl<'i> CompiledContract<'i> { } Ok(builder.drain()) } + + pub fn build_sig_script_for_covenant_decl( + &self, + function_name: &str, + args: Vec>, + options: CovenantDeclCallOptions, + ) -> Result, CompilerError> { + let auth_entrypoint = generated_covenant_entrypoint_name(function_name); + if self.abi.iter().any(|entry| entry.name == auth_entrypoint) { + return self.build_sig_script(&auth_entrypoint, args); + } + + let entrypoint = if options.is_leader { + generated_covenant_leader_entrypoint_name(function_name) + } else { + generated_covenant_delegate_entrypoint_name(function_name) + }; + + if self.abi.iter().any(|entry| entry.name == entrypoint) { + return self.build_sig_script(&entrypoint, args); + } + + Err(CompilerError::Unsupported(format!("covenant declaration '{}' not found", function_name))) + } } fn push_typed_sigscript_arg<'i>( @@ -1705,6 +1682,56 @@ fn push_typed_sigscript_arg<'i>( type_ref: &TypeRef, structs: &StructRegistry, ) -> Result<(), CompilerError> { + if let Some(element_type) = type_ref.element_type() { + if let Some(struct_name) = struct_name_from_type_ref(&element_type, structs) { + let item = + structs.get(struct_name).ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{struct_name}'")))?; + let ExprKind::Array(values) = arg.kind else { + return Err(CompilerError::Unsupported("signature script struct array arguments must be array literals".to_string())); + }; + + for field in &item.fields { + let mut field_values = Vec::with_capacity(values.len()); + for value in &values { + let ExprKind::StateObject(entries) = &value.kind else { + return Err(CompilerError::Unsupported( + "signature script struct array arguments must contain object literals".to_string(), + )); + }; + + let mut matched = None; + for entry in entries { + if entry.name == field.name { + if matched.is_some() { + return Err(CompilerError::Unsupported(format!("duplicate struct field '{}'", field.name))); + } + matched = Some(entry.expr.clone()); + } + } + + field_values + .push(matched.ok_or_else(|| { + CompilerError::Unsupported(format!("struct field '{}' must be initialized", field.name)) + })?); + + if let Some(extra) = entries.iter().find(|entry| item.fields.iter().all(|field| field.name != entry.name)) { + return Err(CompilerError::Unsupported(format!("unknown struct field '{}'", extra.name))); + } + } + + let mut field_type = field.type_ref.clone(); + field_type.array_dims.push(ArrayDim::Dynamic); + push_typed_sigscript_arg( + builder, + Expr::new(ExprKind::Array(field_values), span::Span::default()), + &field_type, + structs, + )?; + } + return Ok(()); + } + } + if let Some(struct_name) = struct_name_from_type_ref(type_ref, structs) { let item = structs.get(struct_name).ok_or_else(|| CompilerError::Unsupported(format!("unknown struct '{struct_name}'")))?; let ExprKind::StateObject(fields) = arg.kind else { @@ -1913,6 +1940,17 @@ pub fn function_branch_index<'i>(contract: &ContractAst<'i>, function_name: &str .ok_or_else(|| CompilerError::Unsupported(format!("function '{function_name}' not found"))) } +fn selector_dispatch_branch0_prefix_len() -> Result { + let mut builder = ScriptBuilder::new(); + builder.add_op(OpDup)?; + builder.add_i64(0)?; + builder.add_op(OpNumEqual)?; + builder.add_op(OpIf)?; + builder.add_op(OpDrop)?; + Ok(builder.drain().len()) +} + +#[allow(clippy::too_many_arguments)] fn compile_entrypoint_function<'i>( function: &FunctionAst<'i>, function_index: usize, @@ -1920,39 +1958,60 @@ fn compile_entrypoint_function<'i>( contract_field_prefix_len: usize, constants: &HashMap>, options: CompileOptions, + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, script_size: Option, recorder: &mut DebugRecorder<'i>, ) -> Result<(String, Vec), CompilerError> { let contract_field_count = contract_fields.len(); - let param_count = function.params.len(); - let mut params = function - .params + let mut flattened_param_names = Vec::new(); + let mut types = HashMap::new(); + for param in &function.params { + let param_type_name = type_name_from_ref(¶m.type_ref); + types.insert(param.name.clone(), param_type_name.clone()); + if struct_name_from_type_ref(¶m.type_ref, structs).is_some() + || struct_array_name_from_type_ref(¶m.type_ref, structs).is_some() + { + for (path, field_type) in flatten_type_ref_leaves(¶m.type_ref, structs)? { + let leaf_name = flattened_struct_name(¶m.name, &path); + types.insert(leaf_name.clone(), type_name_from_ref(&field_type)); + flattened_param_names.push(leaf_name); + } + } else { + flattened_param_names.push(param.name.clone()); + } + } + + let param_count = flattened_param_names.len(); + let mut params = flattened_param_names .iter() - .map(|param| param.name.clone()) .enumerate() - .map(|(index, name)| (name, (contract_field_count + (param_count - 1 - index)) as i64)) + .map(|(index, name)| (name.clone(), (contract_field_count + (param_count - 1 - index)) as i64)) .collect::>(); for (index, field) in contract_fields.iter().enumerate() { params.insert(field.name.clone(), (contract_field_count - 1 - index) as i64); } - let mut types = - function.params.iter().map(|param| (param.name.clone(), type_name_from_ref(¶m.type_ref))).collect::>(); for field in contract_fields { types.insert(field.name.clone(), type_name_from_ref(&field.type_ref)); } for param in &function.params { let param_type_name = type_name_from_ref(¶m.type_ref); - if is_array_type(¶m_type_name) && array_element_size(¶m_type_name).is_none() { + if is_array_type(¶m_type_name) + && array_element_size(¶m_type_name).is_none() + && struct_array_name_from_type_ref(¶m.type_ref, structs).is_none() + { return Err(CompilerError::Unsupported(format!("array element type must have known size: {}", param_type_name))); } } for return_type in &function.return_types { let return_type_name = type_name_from_ref(return_type); - if is_array_type(&return_type_name) && array_element_size(&return_type_name).is_none() { + if is_array_type(&return_type_name) + && array_element_size(&return_type_name).is_none() + && struct_array_name_from_type_ref(return_type, structs).is_none() + { return Err(CompilerError::Unsupported(format!("array element type must have known size: {return_type_name}"))); } } @@ -1960,6 +2019,13 @@ fn compile_entrypoint_function<'i>( // Remove any constructor/constant names that collide with function param names (prioritizing function parameters on name collision). for param in &function.params { env.remove(¶m.name); + if struct_name_from_type_ref(¶m.type_ref, structs).is_some() + || struct_array_name_from_type_ref(¶m.type_ref, structs).is_some() + { + for (path, _) in flatten_type_ref_leaves(¶m.type_ref, structs)? { + env.remove(&flattened_struct_name(¶m.name, &path)); + } + } } let mut builder = ScriptBuilder::new(); let mut yields: Vec = Vec::new(); @@ -1997,9 +2063,10 @@ fn compile_entrypoint_function<'i>( if index != body_len - 1 { return Err(CompilerError::Unsupported("return statement must be the last statement".to_string())); } - validate_return_types(exprs, &function.return_types, &types, constants)?; + validate_return_types(exprs, &function.return_types, &types, structs, contract_fields, constants)?; for expr in exprs { - let resolved = resolve_expr(expr.clone(), &env, &mut HashSet::new()).map_err(|err| err.with_span(&expr.span))?; + let resolved = resolve_expr_for_runtime(expr.clone(), &env, &types, &mut HashSet::new()) + .map_err(|err| err.with_span(&expr.span))?; yields.push(resolved); } } else { @@ -2013,6 +2080,7 @@ fn compile_entrypoint_function<'i>( contract_fields, contract_field_prefix_len, constants, + structs, functions, function_order, function_index, @@ -2025,7 +2093,17 @@ fn compile_entrypoint_function<'i>( recorder.finish_statement_at(stmt, builder.script().len(), &env, &types)?; } - let yield_count = yields.len(); + let flattened_yields = if has_return { + flatten_runtime_return_exprs(&yields, &function.return_types, &types, structs, contract_fields, constants)? + } else { + let mut flattened = Vec::new(); + for expr in &yields { + flattened.extend(flatten_runtime_value_expr(expr, &types, structs, contract_fields, constants)?); + } + flattened + }; + + let yield_count = flattened_yields.len(); if yield_count == 0 { for _ in 0..param_count { builder.add_op(OpDrop)?; @@ -2036,7 +2114,7 @@ fn compile_entrypoint_function<'i>( builder.add_op(OpTrue)?; } else { let mut stack_depth = 0i64; - for expr in &yields { + for expr in &flattened_yields { compile_expr( expr, &env, @@ -2077,6 +2155,7 @@ fn compile_statement<'i>( contract_fields: &[ContractFieldAst<'i>], contract_field_prefix_len: usize, contract_constants: &HashMap>, + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, function_index: usize, @@ -2086,6 +2165,12 @@ fn compile_statement<'i>( ) -> Result<(), CompilerError> { match stmt { Statement::VariableDefinition { type_ref, name, expr, .. } => { + if struct_name_from_type_ref(type_ref, structs).is_some() { + let expr = + expr.as_ref().ok_or_else(|| CompilerError::Unsupported("variable definition requires initializer".to_string()))?; + return store_struct_binding(name, type_ref, expr, env, types, structs, contract_fields, contract_constants, false); + } + let type_name = type_name_from_ref(type_ref); let effective_type_name = if is_array_type(&type_name) && array_size_with_constants(&type_name, contract_constants).is_none() { @@ -2121,13 +2206,17 @@ fn compile_statement<'i>( }, Some(e) if is_byte_array_type => { // byte[] can be initialized from any bytes expression - e.clone() + lower_runtime_expr(e, types, structs)? } Some(e @ Expr { kind: ExprKind::Array(values), .. }) => { if !array_literal_matches_type_with_env(values, &effective_type_name, types, contract_constants) { return Err(CompilerError::Unsupported("array initializer must be another array".to_string())); } - resolve_expr(Expr::new(ExprKind::Array(values.clone()), e.span), env, &mut HashSet::new())? + resolve_expr( + lower_runtime_expr(&Expr::new(ExprKind::Array(values.clone()), e.span), types, structs)?, + env, + &mut HashSet::new(), + )? } Some(_) => return Err(CompilerError::Unsupported("array initializer must be another array".to_string())), None => Expr::new(ExprKind::Array(Vec::new()), span::Span::default()), @@ -2139,6 +2228,7 @@ fn compile_statement<'i>( // Fixed-size arrays like byte[N] can be initialized from expressions let expr = expr.clone().ok_or_else(|| CompilerError::Unsupported("variable definition requires initializer".to_string()))?; + let expr = lower_runtime_expr(&expr, types, structs)?; // For array literals, validate that the size matches the declared type if let ExprKind::Array(values) = &expr.kind { @@ -2170,6 +2260,7 @@ fn compile_statement<'i>( } else { let expr = expr.clone().ok_or_else(|| CompilerError::Unsupported("variable definition requires initializer".to_string()))?; + let expr = lower_runtime_expr(&expr, types, structs)?; let expected_type_ref = parse_type_ref(&effective_type_name)?; if !expr_matches_return_type_ref(&expr, &expected_type_ref, types, contract_constants) { return Err(CompilerError::Unsupported(format!("variable '{}' expects {}", name, effective_type_name))); @@ -2245,9 +2336,10 @@ fn compile_statement<'i>( Ok(()) } Statement::Require { expr, .. } => { + let expr = lower_runtime_expr(expr, types, structs)?; let mut stack_depth = 0i64; compile_expr( - expr, + &expr, env, params, types, @@ -2262,7 +2354,8 @@ fn compile_statement<'i>( Ok(()) } Statement::TimeOp { tx_var, expr, .. } => { - compile_time_op_statement(tx_var, expr, env, params, types, builder, options, script_size, contract_constants) + let expr = lower_runtime_expr(expr, types, structs)?; + compile_time_op_statement(tx_var, &expr, env, params, types, builder, options, script_size, contract_constants) } Statement::If { condition, then_branch, else_branch, .. } => compile_if_statement( condition, @@ -2276,6 +2369,7 @@ fn compile_statement<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -2297,6 +2391,7 @@ fn compile_statement<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -2329,8 +2424,29 @@ fn compile_statement<'i>( }, Statement::FunctionCall { name, args, .. } => { if name == "validateOutputState" { + let lowered_args = if let Some(state_arg) = args.get(1) { + match &state_arg.kind { + ExprKind::StateObject(_) => args.to_vec(), + _ => { + let state_type = TypeRef { base: TypeBase::Custom("State".to_string()), array_dims: Vec::new() }; + let scope = lowering_scope_from_types(types)?; + let mut lowered = args.to_vec(); + lowered[1] = lower_struct_value_to_state_object_expr( + state_arg, + &state_type, + &scope, + structs, + contract_fields, + contract_constants, + )?; + lowered + } + } + } else { + args.to_vec() + }; return compile_validate_output_state_statement( - args, + &lowered_args, env, params, types, @@ -2342,6 +2458,7 @@ fn compile_statement<'i>( contract_constants, ); } + let function = functions.get(name).ok_or_else(|| CompilerError::Unsupported(format!("function '{}' not found", name)))?; let returns = compile_inline_call( name, args, @@ -2352,6 +2469,8 @@ fn compile_statement<'i>( builder, options, contract_constants, + contract_fields, + structs, functions, function_order, function_index, @@ -2359,8 +2478,16 @@ fn compile_statement<'i>( recorder, )?; if !returns.is_empty() { + let flattened_returns = flatten_runtime_return_exprs( + &returns, + &function.return_types, + types, + structs, + contract_fields, + contract_constants, + )?; let mut stack_depth = 0i64; - for expr in returns { + for expr in flattened_returns { compile_expr( &expr, env, @@ -2387,6 +2514,7 @@ fn compile_statement<'i>( env, types, contract_fields, + contract_field_prefix_len, script_size, contract_constants, ); @@ -2397,7 +2525,36 @@ fn compile_statement<'i>( ))) } Statement::StructDestructure { .. } => { - Err(CompilerError::Unsupported("struct destructuring should be lowered before compilation".to_string())) + let Statement::StructDestructure { bindings, expr, span } = stmt else { unreachable!() }; + for binding in bindings { + if struct_name_from_type_ref(&binding.type_ref, structs).is_some() { + types.insert(binding.name.clone(), type_name_from_ref(&binding.type_ref)); + } + } + let mut scope = lowering_scope_from_types(types)?; + for lowered_stmt in + lower_struct_destructure_statement(bindings, expr, *span, &mut scope, structs, contract_fields, contract_constants)? + { + compile_statement( + &lowered_stmt, + env, + params, + types, + builder, + options, + contract_fields, + contract_field_prefix_len, + contract_constants, + structs, + functions, + function_order, + function_index, + yields, + script_size, + recorder, + )?; + } + Ok(()) } Statement::FunctionCallAssign { bindings, name, args, .. } => { let function = functions.get(name).ok_or_else(|| CompilerError::Unsupported(format!("function '{}' not found", name)))?; @@ -2424,6 +2581,8 @@ fn compile_statement<'i>( builder, options, contract_constants, + contract_fields, + structs, functions, function_order, function_index, @@ -2434,13 +2593,42 @@ fn compile_statement<'i>( return Err(CompilerError::Unsupported("return values count must match function return types".to_string())); } for (binding, expr) in bindings.iter().zip(returns.into_iter()) { - env.insert(binding.name.clone(), expr); - types.insert(binding.name.clone(), type_name_from_ref(&binding.type_ref)); + if struct_name_from_type_ref(&binding.type_ref, structs).is_some() { + store_struct_binding( + &binding.name, + &binding.type_ref, + &expr, + env, + types, + structs, + contract_fields, + contract_constants, + false, + )?; + } else { + let lowered = lower_runtime_expr(&expr, types, structs)?; + env.insert(binding.name.clone(), lowered); + types.insert(binding.name.clone(), type_name_from_ref(&binding.type_ref)); + } } Ok(()) } Statement::Assign { name, expr, .. } => { if let Some(type_name) = types.get(name) { + let expected_type_ref = parse_type_ref(type_name)?; + if struct_name_from_type_ref(&expected_type_ref, structs).is_some() { + return store_struct_binding( + name, + &expected_type_ref, + expr, + env, + types, + structs, + contract_fields, + contract_constants, + true, + ); + } if is_array_type(type_name) { match &expr.kind { ExprKind::Identifier(other) => match types.get(other) { @@ -2460,13 +2648,20 @@ fn compile_statement<'i>( } } } - let expected_type_ref = parse_type_ref(type_name)?; - if !expr_matches_return_type_ref(expr, &expected_type_ref, types, contract_constants) { + let lowered_expr = lower_runtime_expr(expr, types, structs)?; + if !expr_matches_return_type_ref(&lowered_expr, &expected_type_ref, types, contract_constants) { return Err(CompilerError::Unsupported(format!("variable '{}' expects {}", name, type_name))); } + let updated = + if let Some(previous) = env.get(name) { replace_identifier(&lowered_expr, name, previous) } else { lowered_expr }; + let resolved = resolve_expr_for_runtime(updated, env, types, &mut HashSet::new())?; + env.insert(name.clone(), resolved); + return Ok(()); } - let updated = if let Some(previous) = env.get(name) { replace_identifier(expr, name, previous) } else { expr.clone() }; - let resolved = resolve_expr(updated, env, &mut HashSet::new())?; + let lowered_expr = lower_runtime_expr(expr, types, structs)?; + let updated = + if let Some(previous) = env.get(name) { replace_identifier(&lowered_expr, name, previous) } else { lowered_expr }; + let resolved = resolve_expr_for_runtime(updated, env, types, &mut HashSet::new())?; env.insert(name.clone(), resolved); Ok(()) } @@ -2502,16 +2697,24 @@ fn encoded_field_chunk_size<'i>( Ok(data_prefix(payload_size).len() + payload_size) } +fn encoded_state_len<'i>( + contract_fields: &[ContractFieldAst<'i>], + contract_constants: &HashMap>, +) -> Result { + contract_fields.iter().try_fold(0usize, |acc, field| Ok(acc + encoded_field_chunk_size(field, contract_constants)?)) +} + fn read_input_state_binding_expr<'i>( input_idx: &Expr<'i>, field: &ContractFieldAst<'i>, + state_start_offset: usize, field_chunk_offset: usize, script_size_value: i64, contract_constants: &HashMap>, ) -> Result, CompilerError> { let (field_payload_offset, field_payload_len, decode_int) = if field.type_ref.array_dims.is_empty() && field.type_ref.base == TypeBase::Int { - (field_chunk_offset + 1, 8usize, true) + (state_start_offset + field_chunk_offset + 1, 8usize, true) } else if field.type_ref.base == TypeBase::Byte { let payload_len = if field.type_ref.array_dims.is_empty() { 1usize @@ -2523,7 +2726,7 @@ fn read_input_state_binding_expr<'i>( )) })? }; - (field_chunk_offset + data_prefix(payload_len).len(), payload_len, false) + (state_start_offset + field_chunk_offset + data_prefix(payload_len).len(), payload_len, false) } else { return Err(CompilerError::Unsupported(format!( "readInputState does not support field type {}", @@ -2558,6 +2761,7 @@ fn compile_read_input_state_statement<'i>( env: &mut HashMap>, types: &mut HashMap, contract_fields: &[ContractFieldAst<'i>], + contract_field_prefix_len: usize, script_size: Option, contract_constants: &HashMap>, ) -> Result<(), CompilerError> { @@ -2580,6 +2784,11 @@ fn compile_read_input_state_statement<'i>( return Err(CompilerError::Unsupported("readInputState bindings must include all contract fields exactly once".to_string())); } + let total_state_len = encoded_state_len(contract_fields, contract_constants)?; + let state_start_offset = contract_field_prefix_len + .checked_sub(total_state_len) + .ok_or_else(|| CompilerError::Unsupported("readInputState state offset underflow".to_string()))?; + let input_idx = args[0].clone(); let mut field_chunk_offset = 0usize; @@ -2594,8 +2803,14 @@ fn compile_read_input_state_statement<'i>( return Err(CompilerError::Unsupported(format!("readInputState binding '{}' expects {}", binding.name, field_type))); } - let binding_expr = - read_input_state_binding_expr(&input_idx, field, field_chunk_offset, script_size_value, contract_constants)?; + let binding_expr = read_input_state_binding_expr( + &input_idx, + field, + state_start_offset, + field_chunk_offset, + script_size_value, + contract_constants, + )?; env.insert(binding.name.clone(), binding_expr); types.insert(binding.name.clone(), binding_type); @@ -2640,6 +2855,11 @@ fn compile_validate_output_state_statement( return Err(CompilerError::Unsupported("new_state must include all contract fields exactly once".to_string())); } + let total_state_len = encoded_state_len(contract_fields, contract_constants)?; + let state_start_offset = contract_field_prefix_len + .checked_sub(total_state_len) + .ok_or_else(|| CompilerError::Unsupported("validateOutputState state offset underflow".to_string()))?; + let mut stack_depth = 0i64; for field in contract_fields { let Some(new_value) = provided.remove(field.name.as_str()) else { @@ -2708,9 +2928,41 @@ fn compile_validate_output_state_statement( stack_depth -= 1; } + for _ in 1..contract_fields.len() { + builder.add_op(OpCat)?; + stack_depth -= 1; + } + let script_size_value = script_size.ok_or_else(|| CompilerError::Unsupported("validateOutputState requires this.scriptSize".to_string()))?; + // Build: prefix || encoded_new_state || suffix where fields sit at [state_start_offset, contract_field_prefix_len). + if state_start_offset > 0 { + builder.add_op(OpTxInputIndex)?; + stack_depth += 1; + builder.add_op(OpDup)?; + stack_depth += 1; + builder.add_op(OpTxInputScriptSigLen)?; + builder.add_op(OpDup)?; + stack_depth += 1; + builder.add_i64(script_size_value)?; + stack_depth += 1; + builder.add_op(OpSub)?; + stack_depth -= 1; + builder.add_i64(state_start_offset as i64)?; + stack_depth += 1; + builder.add_op(OpAdd)?; + stack_depth -= 1; + builder.add_op(OpSwap)?; + builder.add_op(OpTxInputScriptSigSubstr)?; + stack_depth -= 2; + + // Prefix || encoded_new_state + builder.add_op(OpSwap)?; + builder.add_op(OpCat)?; + stack_depth -= 1; + } + builder.add_op(OpTxInputIndex)?; stack_depth += 1; builder.add_op(OpDup)?; @@ -2730,10 +2982,9 @@ fn compile_validate_output_state_statement( builder.add_op(OpTxInputScriptSigSubstr)?; stack_depth -= 2; - for _ in 0..contract_fields.len() { - builder.add_op(OpCat)?; - stack_depth -= 1; - } + // Prefix || encoded_new_state || suffix + builder.add_op(OpCat)?; + stack_depth -= 1; builder.add_op(OpBlake2b)?; builder.add_data(&[0x00, 0x00])?; @@ -2773,17 +3024,124 @@ fn compile_validate_output_state_statement( Ok(()) } +#[derive(Debug)] +struct InlineCallBindings<'i> { + env: HashMap>, + debug_env: HashMap>, + types: HashMap, + compile_params: HashMap, + yield_rewrites: Vec<(String, Expr<'i>)>, +} + +fn prepare_inline_call_bindings<'i>( + callee_name: &str, + function: &FunctionAst<'i>, + args: &[Expr<'i>], + caller_params: &HashMap, + caller_types: &HashMap, + caller_env: &HashMap>, + contract_constants: &HashMap>, + structs: &StructRegistry, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + let mut types = caller_types.clone(); + let mut env: HashMap> = contract_constants.clone(); + env.extend(caller_env.clone()); + let mut yield_rewrites = Vec::new(); + let caller_scope = lowering_scope_from_types(caller_types)?; + for (param, arg) in function.params.iter().zip(args.iter()) { + let resolved = resolve_expr(arg.clone(), caller_env, &mut HashSet::new())?; + let param_type_name = type_name_from_ref(¶m.type_ref); + + types.insert(param.name.clone(), param_type_name.clone()); + if struct_name_from_type_ref(¶m.type_ref, structs).is_some() { + yield_rewrites.push((param.name.clone(), resolved.clone())); + if !matches!(&resolved.kind, ExprKind::Identifier(identifier) if identifier == ¶m.name && caller_params.contains_key(identifier)) + { + env.insert(param.name.clone(), resolved.clone()); + } + for ((path, field_type), lowered_expr) in flatten_type_ref_leaves(¶m.type_ref, structs)? + .into_iter() + .zip(lower_struct_value_expr(&resolved, ¶m.type_ref, &caller_scope, structs, contract_fields, contract_constants)?) + { + let leaf_name = flattened_struct_name(¶m.name, &path); + let lowered_expr = resolve_expr(lowered_expr, caller_env, &mut HashSet::new())?; + types.insert(leaf_name.clone(), type_name_from_ref(&field_type)); + if !matches!(&lowered_expr.kind, ExprKind::Identifier(identifier) if identifier == &leaf_name && caller_params.contains_key(identifier)) + { + env.insert(leaf_name, lowered_expr); + } + } + } else { + let (lowered, rewrite_expr) = if is_array_type(¶m_type_name) { + match arg { + Expr { kind: ExprKind::Identifier(identifier), .. } + if caller_types + .get(identifier) + .is_some_and(|other_type| is_type_assignable(other_type, ¶m_type_name, contract_constants)) => + { + ( + caller_env + .get(identifier) + .cloned() + .unwrap_or_else(|| Expr::new(ExprKind::Identifier(identifier.clone()), span::Span::default())), + Expr::new(ExprKind::Identifier(identifier.clone()), span::Span::default()), + ) + } + _ => { + let lowered = lower_runtime_expr(&resolved, caller_types, structs)?; + (lowered.clone(), lowered) + } + } + } else { + let lowered = lower_runtime_expr(&resolved, caller_types, structs)?; + (lowered.clone(), lowered) + }; + yield_rewrites.push((param.name.clone(), rewrite_expr)); + if !matches!(&lowered.kind, ExprKind::Identifier(identifier) if identifier == ¶m.name && caller_params.contains_key(identifier)) + { + env.insert(param.name.clone(), lowered); + } + } + } + + let debug_env = env.clone(); + let compile_params = caller_params.clone(); + + let _ = callee_name; + + Ok(InlineCallBindings { env, debug_env, types, compile_params, yield_rewrites }) +} + +fn rewrite_inline_yields<'i>(yields: Vec>, rewrites: &[(String, Expr<'i>)]) -> Vec> { + if rewrites.is_empty() { + return yields; + } + yields + .into_iter() + .map(|expr| { + let mut current = expr; + for (temp_name, replacement) in rewrites { + current = replace_identifier(¤t, temp_name, replacement); + } + current + }) + .collect() +} + #[allow(clippy::too_many_arguments)] fn compile_inline_call<'i>( name: &str, args: &[Expr<'i>], call_span: SourceSpan, - params: &HashMap, + caller_params: &HashMap, caller_types: &mut HashMap, caller_env: &mut HashMap>, builder: &mut ScriptBuilder, options: CompileOptions, contract_constants: &HashMap>, + contract_fields: &[ContractFieldAst<'i>], + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, caller_index: usize, @@ -2800,39 +3158,50 @@ fn compile_inline_call<'i>( if function.params.len() != args.len() { return Err(CompilerError::Unsupported(format!("function '{}' expects {} arguments", name, function.params.len()))); } - for (param, arg) in function.params.iter().zip(args.iter()) { - let param_type_name = type_name_from_ref(¶m.type_ref); - if !expr_matches_type_with_env(arg, ¶m_type_name, caller_types, contract_constants) { - return Err(CompilerError::Unsupported(format!("function argument '{}' expects {}", param.name, param_type_name))); + + if args.len() == function.params.len() { + for (param, arg) in function.params.iter().zip(args.iter()) { + let param_type_name = type_name_from_ref(¶m.type_ref); + let matches = if struct_name_from_type_ref(¶m.type_ref, structs).is_some() { + lower_runtime_struct_expr(arg, ¶m.type_ref, caller_types, structs, contract_fields, contract_constants).is_ok() + } else if struct_array_name_from_type_ref(¶m.type_ref, structs).is_some() { + match &arg.kind { + ExprKind::Identifier(name) => caller_types + .get(name) + .and_then(|type_name| parse_type_ref(type_name).ok()) + .is_some_and(|type_ref| is_type_assignable_ref(&type_ref, ¶m.type_ref, contract_constants)), + _ => expr_matches_declared_type_ref(arg, ¶m.type_ref, structs), + } + } else { + expr_matches_type_with_env(arg, ¶m_type_name, caller_types, contract_constants) + }; + if !matches { + return Err(CompilerError::Unsupported(format!("function argument '{}' expects {}", param.name, param_type_name))); + } } } - let mut types = - function.params.iter().map(|param| (param.name.clone(), type_name_from_ref(¶m.type_ref))).collect::>(); for param in &function.params { let param_type_name = type_name_from_ref(¶m.type_ref); - if is_array_type(¶m_type_name) && array_element_size(¶m_type_name).is_none() { + if is_array_type(¶m_type_name) + && array_element_size(¶m_type_name).is_none() + && struct_array_name_from_type_ref(¶m.type_ref, structs).is_none() + { return Err(CompilerError::Unsupported(format!("array element type must have known size: {}", param_type_name))); } } - let mut env: HashMap> = contract_constants.clone(); - // Copy the caller's __arg_ (function param) bindings into the new inline call's env, allowing nested synthetic argument chain. - for (name, value) in caller_env.iter() { - if name.starts_with(SYNTHETIC_ARG_PREFIX) { - env.insert(name.clone(), value.clone()); - } - } - for (index, (param, arg)) in function.params.iter().zip(args.iter()).enumerate() { - let resolved = resolve_expr(arg.clone(), caller_env, &mut HashSet::new())?; - let temp_name = format!("{SYNTHETIC_ARG_PREFIX}_{name}_{index}"); - let param_type_name = type_name_from_ref(¶m.type_ref); - env.insert(temp_name.clone(), resolved.clone()); - types.insert(temp_name.clone(), param_type_name.clone()); - env.insert(param.name.clone(), Expr::new(ExprKind::Identifier(temp_name.clone()), span::Span::default())); - caller_env.insert(temp_name.clone(), resolved); - caller_types.insert(temp_name, param_type_name); - } + let mut bindings = prepare_inline_call_bindings( + name, + function, + args, + caller_params, + caller_types, + caller_env, + contract_constants, + structs, + contract_fields, + )?; if !options.allow_yield && function.body.iter().any(contains_yield) { return Err(CompilerError::Unsupported("yield requires allow_yield=true".to_string())); @@ -2856,33 +3225,35 @@ fn compile_inline_call<'i>( } let call_start = builder.script().len(); - recorder.begin_inline_call(call_span, call_start, function, &env)?; + recorder.begin_inline_call(call_span, call_start, function, &bindings.debug_env)?; let mut yields: Vec> = Vec::new(); let body_len = function.body.len(); for (index, stmt) in function.body.iter().enumerate() { - recorder.begin_statement_at(builder.script().len(), &env); + recorder.begin_statement_at(builder.script().len(), &bindings.env); if let Statement::Return { exprs, .. } = stmt { if index != body_len - 1 { return Err(CompilerError::Unsupported("return statement must be the last statement".to_string())); } - validate_return_types(exprs, &function.return_types, &types, contract_constants) + validate_return_types(exprs, &function.return_types, &bindings.types, structs, contract_fields, contract_constants) .map_err(|err| err.with_span(&stmt.span()))?; for expr in exprs { - let resolved = resolve_expr(expr.clone(), &env, &mut HashSet::new()).map_err(|err| err.with_span(&expr.span))?; + let resolved = resolve_expr_for_runtime(expr.clone(), &bindings.env, &bindings.types, &mut HashSet::new()) + .map_err(|err| err.with_span(&expr.span))?; yields.push(resolved); } } else { compile_statement( stmt, - &mut env, - params, - &mut types, + &mut bindings.env, + &bindings.compile_params, + &mut bindings.types, builder, options, - &[], + contract_fields, 0, contract_constants, + structs, functions, function_order, callee_index, @@ -2892,21 +3263,12 @@ fn compile_inline_call<'i>( ) .map_err(|err| err.with_span(&stmt.span()))?; } - recorder.finish_statement_at(stmt, builder.script().len(), &env, &types)?; + recorder.finish_statement_at(stmt, builder.script().len(), &bindings.env, &bindings.types)?; } let call_end = builder.script().len(); recorder.finish_inline_call(call_span, call_end, name); - for (name, value) in env.iter() { - if name.starts_with(SYNTHETIC_ARG_PREFIX) { - if let Some(type_name) = types.get(name) { - caller_types.entry(name.clone()).or_insert_with(|| type_name.clone()); - } - caller_env.entry(name.clone()).or_insert_with(|| value.clone()); - } - } - - Ok(yields) + Ok(rewrite_inline_yields(yields, &bindings.yield_rewrites)) } #[allow(clippy::too_many_arguments)] @@ -2922,6 +3284,7 @@ fn compile_if_statement<'i>( contract_fields: &[ContractFieldAst<'i>], contract_field_prefix_len: usize, contract_constants: &HashMap>, + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, function_index: usize, @@ -2929,9 +3292,10 @@ fn compile_if_statement<'i>( script_size: Option, recorder: &mut DebugRecorder<'i>, ) -> Result<(), CompilerError> { + let condition = lower_runtime_expr(condition, types, structs)?; let mut stack_depth = 0i64; compile_expr( - condition, + &condition, env, params, types, @@ -2957,6 +3321,7 @@ fn compile_if_statement<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -2979,6 +3344,7 @@ fn compile_if_statement<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -2990,7 +3356,7 @@ fn compile_if_statement<'i>( builder.add_op(OpEndIf)?; - let resolved_condition = resolve_expr(condition.clone(), &original_env, &mut HashSet::new())?; + let resolved_condition = resolve_expr_for_runtime(condition, &original_env, types, &mut HashSet::new())?; merge_env_after_if(env, &original_env, &then_env, &else_env, &resolved_condition); Ok(()) } @@ -3061,6 +3427,7 @@ fn compile_block<'i>( contract_fields: &[ContractFieldAst<'i>], contract_field_prefix_len: usize, contract_constants: &HashMap>, + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, function_index: usize, @@ -3080,6 +3447,7 @@ fn compile_block<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -3108,6 +3476,7 @@ fn compile_for_statement<'i>( contract_fields: &[ContractFieldAst<'i>], contract_field_prefix_len: usize, contract_constants: &HashMap>, + structs: &StructRegistry, functions: &HashMap>, function_order: &HashMap, function_index: usize, @@ -3137,6 +3506,7 @@ fn compile_for_statement<'i>( contract_fields, contract_field_prefix_len, contract_constants, + structs, functions, function_order, function_index, @@ -3307,6 +3677,140 @@ fn resolve_expr<'i>( } } +fn resolve_expr_for_runtime<'i>( + expr: Expr<'i>, + env: &HashMap>, + types: &HashMap, + visiting: &mut HashSet, +) -> Result, CompilerError> { + let Expr { kind, span } = expr; + match kind { + ExprKind::Identifier(name) => { + if name.starts_with(SYNTHETIC_ARG_PREFIX) || types.get(&name).is_some_and(|type_name| is_array_type(type_name)) { + return Ok(Expr::new(ExprKind::Identifier(name), span)); + } + if let Some(value) = env.get(&name) { + if !visiting.insert(name.clone()) { + return Err(CompilerError::CyclicIdentifier(name)); + } + let resolved = resolve_expr_for_runtime(value.clone(), env, types, visiting)?; + visiting.remove(&name); + Ok(resolved) + } else { + Ok(Expr::new(ExprKind::Identifier(name), span)) + } + } + ExprKind::Unary { op, expr } => { + Ok(Expr::new(ExprKind::Unary { op, expr: Box::new(resolve_expr_for_runtime(*expr, env, types, visiting)?) }, span)) + } + ExprKind::Binary { op, left, right } => Ok(Expr::new( + ExprKind::Binary { + op, + left: Box::new(resolve_expr_for_runtime(*left, env, types, visiting)?), + right: Box::new(resolve_expr_for_runtime(*right, env, types, visiting)?), + }, + span, + )), + ExprKind::IfElse { condition, then_expr, else_expr } => Ok(Expr::new( + ExprKind::IfElse { + condition: Box::new(resolve_expr_for_runtime(*condition, env, types, visiting)?), + then_expr: Box::new(resolve_expr_for_runtime(*then_expr, env, types, visiting)?), + else_expr: Box::new(resolve_expr_for_runtime(*else_expr, env, types, visiting)?), + }, + span, + )), + ExprKind::Array(values) => Ok(Expr::new( + ExprKind::Array( + values + .into_iter() + .map(|value| resolve_expr_for_runtime(value, env, types, visiting)) + .collect::, _>>()?, + ), + span, + )), + ExprKind::StateObject(fields) => Ok(Expr::new( + ExprKind::StateObject( + fields + .into_iter() + .map(|field| { + Ok(StateFieldExpr { + name: field.name, + expr: resolve_expr_for_runtime(field.expr, env, types, visiting)?, + span: field.span, + name_span: field.name_span, + }) + }) + .collect::, CompilerError>>()?, + ), + span, + )), + ExprKind::FieldAccess { source, field, field_span } => Ok(Expr::new( + ExprKind::FieldAccess { source: Box::new(resolve_expr_for_runtime(*source, env, types, visiting)?), field, field_span }, + span, + )), + ExprKind::Call { name, args, name_span } => Ok(Expr::new( + ExprKind::Call { + name, + args: args + .into_iter() + .map(|arg| resolve_expr_for_runtime(arg, env, types, visiting)) + .collect::, _>>()?, + name_span, + }, + span, + )), + ExprKind::New { name, args, name_span } => Ok(Expr::new( + ExprKind::New { + name, + args: args + .into_iter() + .map(|arg| resolve_expr_for_runtime(arg, env, types, visiting)) + .collect::, _>>()?, + name_span, + }, + span, + )), + ExprKind::Split { source, index, part, span: split_span } => Ok(Expr::new( + ExprKind::Split { + source: Box::new(resolve_expr_for_runtime(*source, env, types, visiting)?), + index: Box::new(resolve_expr_for_runtime(*index, env, types, visiting)?), + part, + span: split_span, + }, + span, + )), + ExprKind::ArrayIndex { source, index } => Ok(Expr::new( + ExprKind::ArrayIndex { + source: Box::new(resolve_expr_for_runtime(*source, env, types, visiting)?), + index: Box::new(resolve_expr_for_runtime(*index, env, types, visiting)?), + }, + span, + )), + ExprKind::Introspection { kind, index, field_span } => Ok(Expr::new( + ExprKind::Introspection { kind, index: Box::new(resolve_expr_for_runtime(*index, env, types, visiting)?), field_span }, + span, + )), + ExprKind::UnarySuffix { source, kind, span: suffix_span } => Ok(Expr::new( + ExprKind::UnarySuffix { + source: Box::new(resolve_expr_for_runtime(*source, env, types, visiting)?), + kind, + span: suffix_span, + }, + span, + )), + ExprKind::Slice { source, start, end, span: slice_span } => Ok(Expr::new( + ExprKind::Slice { + source: Box::new(resolve_expr_for_runtime(*source, env, types, visiting)?), + start: Box::new(resolve_expr_for_runtime(*start, env, types, visiting)?), + end: Box::new(resolve_expr_for_runtime(*end, env, types, visiting)?), + span: slice_span, + }, + span, + )), + other => Ok(Expr::new(other, span)), + } +} + /// Replace `target` identifiers in `expr` with `replacement`. /// /// Example: for `x = x + 1`, this rewrites the right side to diff --git a/silverscript-lang/src/compiler/covenant_declarations.rs b/silverscript-lang/src/compiler/covenant_declarations.rs new file mode 100644 index 0000000..d2b3a14 --- /dev/null +++ b/silverscript-lang/src/compiler/covenant_declarations.rs @@ -0,0 +1,1909 @@ +use super::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CovenantBinding { + Auth, + Cov, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CovenantMode { + Verification, + Transition, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CovenantGroups { + Single, + Multiple, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CovenantTermination { + Disallowed, + Allowed, +} + +#[derive(Debug, Clone)] +struct CovenantDeclaration<'i> { + binding: CovenantBinding, + mode: CovenantMode, + groups: CovenantGroups, + singleton: bool, + termination: CovenantTermination, + from_expr: Expr<'i>, + to_expr: Expr<'i>, +} + +#[derive(Debug, Clone)] +enum OutputStateSource<'i> { + Single(Expr<'i>), + PerOutputArrays { + // field_name -> array_binding_name + field_arrays: Vec<(String, String)>, + length_expr: Expr<'i>, + }, +} + +#[derive(Debug, Clone)] +struct VerificationShape<'i> { + prev_field_values: Vec<(String, String)>, + new_field_arrays: Vec<(String, String)>, + entrypoint_params: Vec>, + call_args: Vec>, +} + +#[derive(Debug, Clone)] +struct TransitionShape<'i> { + entrypoint_params: Vec>, + call_args: Vec>, +} + +pub(super) fn lower_covenant_declarations<'i>( + contract: &ContractAst<'i>, + constants: &HashMap>, + preserve_policy_structs: bool, +) -> Result, CompilerError> { + let mut lowered = Vec::new(); + + for function in &contract.functions { + if function.attributes.is_empty() { + lowered.push(function.clone()); + continue; + } + + let declaration = parse_covenant_declaration(function, constants)?; + let desugared_policy = desugar_covenant_policy_state_syntax(function, &declaration, &contract.fields)?; + + let policy_name = generated_covenant_policy_name(&function.name); + + let mut wrapper_policy = desugared_policy; + wrapper_policy.name = policy_name.clone(); + wrapper_policy.entrypoint = false; + wrapper_policy.attributes.clear(); + + let mut policy = if preserve_policy_structs { function.clone() } else { wrapper_policy.clone() }; + policy.name = policy_name.clone(); + policy.entrypoint = false; + policy.attributes.clear(); + lowered.push(policy); + + match declaration.binding { + CovenantBinding::Auth => { + let entrypoint_name = generated_covenant_entrypoint_name(&function.name); + let mut wrapper = + build_auth_wrapper(&wrapper_policy, &policy_name, declaration.clone(), entrypoint_name, &contract.fields)?; + if preserve_policy_structs { + wrapper.params = preserved_entrypoint_params(function, declaration, true, &contract.fields); + } + lowered.push(wrapper); + } + CovenantBinding::Cov => { + let leader_name = generated_covenant_leader_entrypoint_name(&function.name); + let mut leader_wrapper = + build_cov_wrapper(&wrapper_policy, &policy_name, declaration.clone(), leader_name, true, &contract.fields)?; + if preserve_policy_structs { + leader_wrapper.params = preserved_entrypoint_params(function, declaration.clone(), true, &contract.fields); + } + lowered.push(leader_wrapper); + + let delegate_name = generated_covenant_delegate_entrypoint_name(&function.name); + let mut delegate_wrapper = + build_cov_wrapper(&wrapper_policy, &policy_name, declaration.clone(), delegate_name, false, &contract.fields)?; + if preserve_policy_structs { + delegate_wrapper.params = preserved_entrypoint_params(function, declaration, false, &contract.fields); + } + lowered.push(delegate_wrapper); + } + } + } + + let mut lowered_contract = contract.clone(); + lowered_contract.functions = lowered; + Ok(lowered_contract) +} + +fn parse_covenant_declaration<'i>( + function: &FunctionAst<'i>, + constants: &HashMap>, +) -> Result, CompilerError> { + #[derive(Clone, Copy, PartialEq, Eq)] + enum CovenantSyntax { + Canonical, + Singleton, + Fanout, + } + + if function.entrypoint { + return Err(CompilerError::Unsupported( + "#[covenant(...)] must be applied to a policy function, not an entrypoint".to_string(), + )); + } + + if function.attributes.len() != 1 { + return Err(CompilerError::Unsupported("covenant declarations support exactly one #[covenant(...)] attribute".to_string())); + } + + let attribute = &function.attributes[0]; + let syntax = match attribute.path.as_slice() { + [head] if head == "covenant" => CovenantSyntax::Canonical, + [head, tail] if head == "covenant" && tail == "singleton" => CovenantSyntax::Singleton, + [head, tail] if head == "covenant" && tail == "fanout" => CovenantSyntax::Fanout, + _ => { + return Err(CompilerError::Unsupported(format!( + "unsupported function attribute #[{}]; expected #[covenant(...)], #[covenant.singleton], or #[covenant.fanout(...)]", + attribute.path.join(".") + ))); + } + }; + + let mut args_by_name: HashMap<&str, &Expr<'i>> = HashMap::new(); + for arg in &attribute.args { + if args_by_name.insert(arg.name.as_str(), &arg.expr).is_some() { + return Err(CompilerError::Unsupported(format!("duplicate covenant attribute argument '{}'", arg.name))); + } + } + + let allowed_keys: HashSet<&str> = ["binding", "from", "to", "mode", "groups", "termination"].into_iter().collect(); + for arg in &attribute.args { + if !allowed_keys.contains(arg.name.as_str()) { + return Err(CompilerError::Unsupported(format!("unknown covenant attribute argument '{}'", arg.name))); + } + } + + let (from_expr, to_expr) = match syntax { + CovenantSyntax::Canonical => { + let from_expr = args_by_name + .get("from") + .copied() + .ok_or_else(|| CompilerError::Unsupported("missing covenant attribute argument 'from'".to_string()))? + .clone(); + let to_expr = args_by_name + .get("to") + .copied() + .ok_or_else(|| CompilerError::Unsupported("missing covenant attribute argument 'to'".to_string()))? + .clone(); + (from_expr, to_expr) + } + CovenantSyntax::Singleton => { + if args_by_name.contains_key("from") || args_by_name.contains_key("to") { + return Err(CompilerError::Unsupported( + "covenant.singleton is sugar and does not accept 'from' or 'to' arguments".to_string(), + )); + } + (Expr::int(1), Expr::int(1)) + } + CovenantSyntax::Fanout => { + if args_by_name.contains_key("from") { + return Err(CompilerError::Unsupported( + "covenant.fanout is sugar and does not accept a 'from' argument (it is always 1)".to_string(), + )); + } + let to_expr = args_by_name + .get("to") + .copied() + .ok_or_else(|| CompilerError::Unsupported("missing covenant attribute argument 'to'".to_string()))? + .clone(); + (Expr::int(1), to_expr) + } + }; + + let from_value = eval_const_int(&from_expr, constants) + .map_err(|_| CompilerError::Unsupported("covenant 'from' must be a compile-time integer".to_string()))?; + let to_value = eval_const_int(&to_expr, constants) + .map_err(|_| CompilerError::Unsupported("covenant 'to' must be a compile-time integer".to_string()))?; + if from_value < 1 { + return Err(CompilerError::Unsupported("covenant 'from' must be >= 1".to_string())); + } + if to_value < 1 { + return Err(CompilerError::Unsupported("covenant 'to' must be >= 1".to_string())); + } + + let default_binding = if from_value == 1 { CovenantBinding::Auth } else { CovenantBinding::Cov }; + let binding = match args_by_name.get("binding").copied() { + Some(expr) => { + let binding_name = parse_attr_ident_arg("binding", Some(expr))?; + match binding_name.as_str() { + "auth" => CovenantBinding::Auth, + "cov" => CovenantBinding::Cov, + other => { + return Err(CompilerError::Unsupported(format!("covenant binding must be auth|cov, got '{}'", other))); + } + } + } + None => default_binding, + }; + + let mode = match args_by_name.get("mode").copied() { + Some(expr) => { + let mode_name = parse_attr_ident_arg("mode", Some(expr))?; + match mode_name.as_str() { + "verification" => CovenantMode::Verification, + "transition" => CovenantMode::Transition, + other => { + return Err(CompilerError::Unsupported(format!("covenant mode must be verification|transition, got '{}'", other))); + } + } + } + None => { + if function.return_types.is_empty() { + CovenantMode::Verification + } else { + CovenantMode::Transition + } + } + }; + + let groups = match args_by_name.get("groups").copied() { + Some(expr) => { + let groups_name = parse_attr_ident_arg("groups", Some(expr))?; + match groups_name.as_str() { + "single" => CovenantGroups::Single, + "multiple" => CovenantGroups::Multiple, + other => { + return Err(CompilerError::Unsupported(format!("covenant groups must be single|multiple, got '{}'", other))); + } + } + } + None => match binding { + CovenantBinding::Auth => CovenantGroups::Multiple, + CovenantBinding::Cov => CovenantGroups::Single, + }, + }; + + let termination = match args_by_name.get("termination").copied() { + Some(expr) => { + let termination_name = parse_attr_ident_arg("termination", Some(expr))?; + match termination_name.as_str() { + "disallowed" => CovenantTermination::Disallowed, + "allowed" => CovenantTermination::Allowed, + other => { + return Err(CompilerError::Unsupported(format!( + "covenant termination must be disallowed|allowed, got '{}'", + other + ))); + } + } + } + None => CovenantTermination::Disallowed, + }; + + if binding == CovenantBinding::Auth && from_value != 1 { + return Err(CompilerError::Unsupported("binding=auth requires from = 1".to_string())); + } + if binding == CovenantBinding::Cov && from_value == 1 && args_by_name.contains_key("binding") { + eprintln!( + "warning: #[covenant(...)] on function '{}' uses binding=cov with from=1; binding=auth is usually a better default", + function.name + ); + } + if binding == CovenantBinding::Cov && groups == CovenantGroups::Multiple { + return Err(CompilerError::Unsupported("binding=cov with groups=multiple is not supported yet".to_string())); + } + + if args_by_name.contains_key("termination") && mode != CovenantMode::Transition { + return Err(CompilerError::Unsupported("termination is only supported in mode=transition".to_string())); + } + if args_by_name.contains_key("termination") && !(from_value == 1 && to_value == 1) { + return Err(CompilerError::Unsupported("termination is only supported for singleton covenants (from=1, to=1)".to_string())); + } + + if mode == CovenantMode::Verification && !function.return_types.is_empty() { + return Err(CompilerError::Unsupported("verification mode policy functions must not declare return values".to_string())); + } + if mode == CovenantMode::Transition && function.return_types.is_empty() { + return Err(CompilerError::Unsupported("transition mode policy functions must declare return values".to_string())); + } + + Ok(CovenantDeclaration { + binding, + mode, + groups, + singleton: from_value == 1 && to_value == 1, + termination, + from_expr: from_expr.clone(), + to_expr: to_expr.clone(), + }) +} + +fn parse_attr_ident_arg<'i>(name: &str, value: Option<&Expr<'i>>) -> Result { + let value = value.ok_or_else(|| CompilerError::Unsupported(format!("missing covenant attribute argument '{}'", name)))?; + match &value.kind { + ExprKind::Identifier(identifier) => Ok(identifier.clone()), + _ => Err(CompilerError::Unsupported(format!("covenant attribute argument '{}' must be an identifier", name))), + } +} + +fn preserved_entrypoint_params<'i>( + function: &FunctionAst<'i>, + declaration: CovenantDeclaration<'i>, + leader: bool, + contract_fields: &[ContractFieldAst<'i>], +) -> Vec> { + if contract_fields.is_empty() { + return match (declaration.binding, leader) { + (CovenantBinding::Cov, false) => Vec::new(), + _ => function.params.clone(), + }; + } + + match (declaration.binding, declaration.mode, leader) { + (CovenantBinding::Auth, _, _) => function.params.iter().skip(1).cloned().collect(), + (CovenantBinding::Cov, CovenantMode::Verification, true) => function.params.iter().skip(1).cloned().collect(), + (CovenantBinding::Cov, CovenantMode::Transition, true) => function.params.clone(), + (CovenantBinding::Cov, _, false) => Vec::new(), + } +} + +fn build_auth_wrapper<'i>( + policy: &FunctionAst<'i>, + policy_name: &str, + declaration: CovenantDeclaration<'i>, + entrypoint_name: String, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + let mut body = Vec::new(); + let mut entrypoint_params = policy.params.clone(); + + let active_input = active_input_index_expr(); + let out_count_name = "__cov_out_count"; + body.push(var_def_statement(int_type_ref(), out_count_name, Expr::call("OpAuthOutputCount", vec![active_input.clone()]))); + + if declaration.groups == CovenantGroups::Single { + let cov_id_name = "__cov_id"; + body.push(var_def_statement(bytes32_type_ref(), cov_id_name, Expr::call("OpInputCovenantId", vec![active_input.clone()]))); + let cov_out_count_name = "__cov_shared_out_count"; + body.push(var_def_statement( + int_type_ref(), + cov_out_count_name, + Expr::call("OpCovOutCount", vec![identifier_expr(cov_id_name)]), + )); + body.push(require_statement(binary_expr(BinaryOp::Eq, identifier_expr(cov_out_count_name), identifier_expr(out_count_name)))); + } + + if declaration.mode == CovenantMode::Verification && !contract_fields.is_empty() { + let shape = parse_verification_shape(policy, contract_fields, CovenantBinding::Auth)?; + entrypoint_params = shape.entrypoint_params.clone(); + body.push(call_statement(policy_name, shape.call_args)); + body.push(require_statement(binary_expr(BinaryOp::Le, identifier_expr(out_count_name), declaration.to_expr.clone()))); + append_auth_output_array_state_checks( + &mut body, + &active_input, + out_count_name, + declaration.to_expr.clone(), + shape.new_field_arrays, + contract_fields, + ); + } else { + let mut call_args: Vec> = policy.params.iter().map(|param| identifier_expr(¶m.name)).collect(); + if declaration.mode == CovenantMode::Transition && !contract_fields.is_empty() { + let shape = parse_transition_shape(policy, contract_fields, CovenantBinding::Auth)?; + entrypoint_params = shape.entrypoint_params; + call_args = shape.call_args; + } + let state_source = append_policy_call_and_capture_next_state( + &mut body, + policy, + policy_name, + declaration.mode, + declaration.singleton, + declaration.termination, + contract_fields, + call_args, + )?; + if !contract_fields.is_empty() { + match state_source { + OutputStateSource::Single(next_state_expr) => { + if declaration.mode == CovenantMode::Transition || declaration.singleton { + body.push(require_statement(binary_expr(BinaryOp::Eq, identifier_expr(out_count_name), Expr::int(1)))); + let out_idx_name = "__cov_out_idx"; + body.push(var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpAuthOutputIdx", vec![active_input.clone(), Expr::int(0)]), + )); + body.push(call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr])); + } else { + body.push(require_statement(binary_expr( + BinaryOp::Le, + identifier_expr(out_count_name), + declaration.to_expr.clone(), + ))); + append_auth_output_state_checks( + &mut body, + &active_input, + out_count_name, + declaration.to_expr.clone(), + next_state_expr, + ); + } + } + OutputStateSource::PerOutputArrays { field_arrays, length_expr } => { + body.push(require_statement(binary_expr( + BinaryOp::Le, + identifier_expr(out_count_name), + declaration.to_expr.clone(), + ))); + body.push(require_statement(binary_expr(BinaryOp::Eq, identifier_expr(out_count_name), length_expr.clone()))); + append_auth_output_array_state_checks( + &mut body, + &active_input, + out_count_name, + declaration.to_expr.clone(), + field_arrays, + contract_fields, + ); + } + } + } else { + body.push(require_statement(binary_expr(BinaryOp::Le, identifier_expr(out_count_name), declaration.to_expr.clone()))); + } + } + + Ok(generated_entrypoint(policy, entrypoint_name, entrypoint_params, body)) +} + +fn build_cov_wrapper<'i>( + policy: &FunctionAst<'i>, + policy_name: &str, + declaration: CovenantDeclaration<'i>, + entrypoint_name: String, + leader: bool, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + let mut body = Vec::new(); + let mut leader_params = policy.params.clone(); + + let active_input = active_input_index_expr(); + let cov_id_name = "__cov_id"; + body.push(var_def_statement(bytes32_type_ref(), cov_id_name, Expr::call("OpInputCovenantId", vec![active_input.clone()]))); + + let leader_idx_expr = Expr::call("OpCovInputIdx", vec![identifier_expr(cov_id_name), Expr::int(0)]); + body.push(require_statement(binary_expr(if leader { BinaryOp::Eq } else { BinaryOp::Ne }, leader_idx_expr, active_input))); + + if leader { + let in_count_name = "__cov_in_count"; + body.push(var_def_statement(int_type_ref(), in_count_name, Expr::call("OpCovInputCount", vec![identifier_expr(cov_id_name)]))); + body.push(require_statement(binary_expr(BinaryOp::Le, identifier_expr(in_count_name), declaration.from_expr.clone()))); + + let out_count_name = "__cov_out_count"; + body.push(var_def_statement(int_type_ref(), out_count_name, Expr::call("OpCovOutCount", vec![identifier_expr(cov_id_name)]))); + + if declaration.mode == CovenantMode::Verification && !contract_fields.is_empty() { + let shape = parse_verification_shape(policy, contract_fields, CovenantBinding::Cov)?; + leader_params = shape.entrypoint_params.clone(); + + append_cov_input_state_reads_into_policy_prev_arrays( + &mut body, + cov_id_name, + in_count_name, + declaration.from_expr.clone(), + contract_fields, + &shape.prev_field_values, + ); + body.push(call_statement(policy_name, shape.call_args)); + body.push(require_statement(binary_expr(BinaryOp::Le, identifier_expr(out_count_name), declaration.to_expr.clone()))); + append_cov_output_array_state_checks( + &mut body, + cov_id_name, + out_count_name, + declaration.to_expr.clone(), + shape.new_field_arrays, + contract_fields, + ); + } else { + let mut transition_shape: Option> = None; + if declaration.mode == CovenantMode::Transition && !contract_fields.is_empty() { + let shape = parse_transition_shape(policy, contract_fields, CovenantBinding::Cov)?; + leader_params = shape.entrypoint_params.clone(); + transition_shape = Some(shape); + } + append_cov_input_state_reads(&mut body, cov_id_name, in_count_name, declaration.from_expr.clone(), contract_fields); + let call_args = transition_shape + .map(|shape| shape.call_args) + .unwrap_or_else(|| policy.params.iter().map(|param| identifier_expr(¶m.name)).collect()); + let state_source = append_policy_call_and_capture_next_state( + &mut body, + policy, + policy_name, + declaration.mode, + declaration.singleton, + declaration.termination, + contract_fields, + call_args, + )?; + if !contract_fields.is_empty() { + match state_source { + OutputStateSource::Single(next_state_expr) => { + if declaration.mode == CovenantMode::Transition || declaration.singleton { + body.push(require_statement(binary_expr(BinaryOp::Eq, identifier_expr(out_count_name), Expr::int(1)))); + let out_idx_name = "__cov_out_idx"; + body.push(var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpCovOutputIdx", vec![identifier_expr(cov_id_name), Expr::int(0)]), + )); + body.push(call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr])); + } else { + body.push(require_statement(binary_expr( + BinaryOp::Le, + identifier_expr(out_count_name), + declaration.to_expr.clone(), + ))); + append_cov_output_state_checks( + &mut body, + cov_id_name, + out_count_name, + declaration.to_expr.clone(), + next_state_expr, + ); + } + } + OutputStateSource::PerOutputArrays { field_arrays, length_expr } => { + body.push(require_statement(binary_expr( + BinaryOp::Le, + identifier_expr(out_count_name), + declaration.to_expr.clone(), + ))); + body.push(require_statement(binary_expr(BinaryOp::Eq, identifier_expr(out_count_name), length_expr.clone()))); + append_cov_output_array_state_checks( + &mut body, + cov_id_name, + out_count_name, + declaration.to_expr.clone(), + field_arrays, + contract_fields, + ); + } + } + } else { + body.push(require_statement(binary_expr(BinaryOp::Le, identifier_expr(out_count_name), declaration.to_expr.clone()))); + } + } + } + + let params = if leader { leader_params } else { Vec::new() }; + Ok(generated_entrypoint(policy, entrypoint_name, params, body)) +} + +fn generated_entrypoint<'i>( + policy: &FunctionAst<'i>, + entrypoint_name: String, + params: Vec>, + body: Vec>, +) -> FunctionAst<'i> { + FunctionAst { + name: entrypoint_name, + attributes: Vec::new(), + params, + entrypoint: true, + return_types: Vec::new(), + body, + return_type_spans: Vec::new(), + span: policy.span, + name_span: policy.name_span, + body_span: policy.body_span, + } +} + +fn int_type_ref() -> TypeRef { + TypeRef { base: TypeBase::Int, array_dims: Vec::new() } +} + +fn bytes32_type_ref() -> TypeRef { + TypeRef { base: TypeBase::Byte, array_dims: vec![ArrayDim::Fixed(32)] } +} + +fn active_input_index_expr<'i>() -> Expr<'i> { + Expr::new(ExprKind::Nullary(NullaryOp::ActiveInputIndex), span::Span::default()) +} + +fn identifier_expr<'i>(name: &str) -> Expr<'i> { + Expr::new(ExprKind::Identifier(name.to_string()), span::Span::default()) +} + +fn binary_expr<'i>(op: BinaryOp, left: Expr<'i>, right: Expr<'i>) -> Expr<'i> { + Expr::new(ExprKind::Binary { op, left: Box::new(left), right: Box::new(right) }, span::Span::default()) +} + +fn var_def_statement<'i>(type_ref: TypeRef, name: &str, expr: Expr<'i>) -> Statement<'i> { + Statement::VariableDefinition { + type_ref, + modifiers: Vec::new(), + name: name.to_string(), + expr: Some(expr), + span: span::Span::default(), + type_span: span::Span::default(), + modifier_spans: Vec::new(), + name_span: span::Span::default(), + } +} + +fn var_decl_statement<'i>(type_ref: TypeRef, name: &str) -> Statement<'i> { + Statement::VariableDefinition { + type_ref, + modifiers: Vec::new(), + name: name.to_string(), + expr: None, + span: span::Span::default(), + type_span: span::Span::default(), + modifier_spans: Vec::new(), + name_span: span::Span::default(), + } +} + +fn require_statement<'i>(expr: Expr<'i>) -> Statement<'i> { + Statement::Require { expr, message: None, span: span::Span::default(), message_span: None } +} + +fn call_statement<'i>(name: &str, args: Vec>) -> Statement<'i> { + Statement::FunctionCall { name: name.to_string(), args, span: span::Span::default(), name_span: span::Span::default() } +} + +fn function_call_assign_statement<'i>(bindings: Vec>, name: &str, args: Vec>) -> Statement<'i> { + Statement::FunctionCallAssign { + bindings, + name: name.to_string(), + args, + span: span::Span::default(), + name_span: span::Span::default(), + } +} + +fn array_push_statement<'i>(name: &str, expr: Expr<'i>) -> Statement<'i> { + Statement::ArrayPush { name: name.to_string(), expr, span: span::Span::default(), name_span: span::Span::default() } +} + +fn typed_binding<'i>(type_ref: TypeRef, name: &str) -> crate::ast::ParamAst<'i> { + crate::ast::ParamAst { + type_ref, + name: name.to_string(), + span: span::Span::default(), + type_span: span::Span::default(), + name_span: span::Span::default(), + } +} + +fn if_statement<'i>(condition: Expr<'i>, then_branch: Vec>) -> Statement<'i> { + Statement::If { + condition, + then_branch, + else_branch: None, + span: span::Span::default(), + then_span: span::Span::default(), + else_span: None, + } +} + +fn for_statement<'i>(ident: &str, start: Expr<'i>, end: Expr<'i>, body: Vec>) -> Statement<'i> { + Statement::For { + ident: ident.to_string(), + start, + end, + body, + span: span::Span::default(), + ident_span: span::Span::default(), + body_span: span::Span::default(), + } +} + +fn state_binding<'i>(field_name: &str, type_ref: TypeRef, name: &str) -> StateBindingAst<'i> { + StateBindingAst { + field_name: field_name.to_string(), + type_ref, + name: name.to_string(), + span: span::Span::default(), + field_span: span::Span::default(), + type_span: span::Span::default(), + name_span: span::Span::default(), + } +} + +fn state_call_assign_statement<'i>(bindings: Vec>, name: &str, args: Vec>) -> Statement<'i> { + Statement::StateFunctionCallAssign { + bindings, + name: name.to_string(), + args, + span: span::Span::default(), + name_span: span::Span::default(), + } +} + +fn state_object_expr_from_contract_fields<'i>(contract_fields: &[ContractFieldAst<'i>]) -> Expr<'i> { + let fields = contract_fields + .iter() + .map(|field| StateFieldExpr { + name: field.name.clone(), + expr: identifier_expr(&field.name), + span: span::Span::default(), + name_span: span::Span::default(), + }) + .collect(); + Expr::new(ExprKind::StateObject(fields), span::Span::default()) +} + +fn state_object_expr_from_field_bindings<'i>( + contract_fields: &[ContractFieldAst<'i>], + binding_by_field: &HashMap, +) -> Expr<'i> { + let fields = contract_fields + .iter() + .map(|field| { + let binding_name = binding_by_field + .get(&field.name) + .cloned() + .unwrap_or_else(|| panic!("missing state binding for field '{}'", field.name)); + StateFieldExpr { + name: field.name.clone(), + expr: identifier_expr(&binding_name), + span: span::Span::default(), + name_span: span::Span::default(), + } + }) + .collect(); + Expr::new(ExprKind::StateObject(fields), span::Span::default()) +} + +fn state_object_expr_from_field_arrays_at_index<'i>( + contract_fields: &[ContractFieldAst<'i>], + field_arrays: &[(String, String)], + index_expr: Expr<'i>, +) -> Expr<'i> { + let by_field = field_arrays.iter().cloned().collect::>(); + let fields = contract_fields + .iter() + .map(|field| { + let array_name = + by_field.get(&field.name).cloned().unwrap_or_else(|| panic!("missing state array binding for field '{}'", field.name)); + StateFieldExpr { + name: field.name.clone(), + expr: Expr::new( + ExprKind::ArrayIndex { source: Box::new(identifier_expr(&array_name)), index: Box::new(index_expr.clone()) }, + span::Span::default(), + ), + span: span::Span::default(), + name_span: span::Span::default(), + } + }) + .collect(); + Expr::new(ExprKind::StateObject(fields), span::Span::default()) +} + +fn length_expr<'i>(expr: Expr<'i>) -> Expr<'i> { + Expr::new( + ExprKind::UnarySuffix { source: Box::new(expr), kind: UnarySuffixKind::Length, span: span::Span::default() }, + span::Span::default(), + ) +} + +fn return_type_is_per_output_array(return_type: &TypeRef, field_type: &TypeRef) -> bool { + return_type.base == field_type.base + && return_type.array_dims.len() == field_type.array_dims.len() + 1 + && return_type.array_dims[..field_type.array_dims.len()] == field_type.array_dims[..] +} + +fn dynamic_array_of(type_ref: &TypeRef) -> TypeRef { + let mut array_type = type_ref.clone(); + array_type.array_dims.push(ArrayDim::Dynamic); + array_type +} + +#[derive(Debug, Clone, Default)] +struct CovenantStateRewriteContext { + single_states: HashMap>, + state_arrays: HashMap>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CovenantReturnDesugaring { + Existing, + SingleState, + StateArray, +} + +fn is_state_type_ref(type_ref: &TypeRef) -> bool { + type_ref.array_dims.is_empty() && matches!(&type_ref.base, TypeBase::Custom(name) if name == "State") +} + +fn is_state_array_type_ref(type_ref: &TypeRef) -> bool { + !type_ref.array_dims.is_empty() && matches!(&type_ref.base, TypeBase::Custom(name) if name == "State") +} + +fn state_param_prefix(name: &str) -> String { + name.strip_suffix("_states").or_else(|| name.strip_suffix("_state")).map(ToOwned::to_owned).unwrap_or_else(|| name.to_string()) +} + +fn field_binding_name(base: &str, field_name: &str) -> String { + format!("{}_{}", state_param_prefix(base), field_name) +} + +fn append_desugared_state_param<'i>( + params: &mut Vec>, + ctx: &mut CovenantStateRewriteContext, + param: &crate::ast::ParamAst<'i>, + contract_fields: &[ContractFieldAst<'i>], +) { + if is_state_type_ref(¶m.type_ref) { + let bindings = + contract_fields.iter().map(|field| (field.name.clone(), field_binding_name(¶m.name, &field.name))).collect::>(); + ctx.single_states.insert(param.name.clone(), bindings.clone()); + for field in contract_fields { + params.push(typed_binding(field.type_ref.clone(), &field_binding_name(¶m.name, &field.name))); + } + } else if is_state_array_type_ref(¶m.type_ref) { + let bindings = + contract_fields.iter().map(|field| (field.name.clone(), field_binding_name(¶m.name, &field.name))).collect::>(); + ctx.state_arrays.insert(param.name.clone(), bindings.clone()); + for field in contract_fields { + params.push(typed_binding(dynamic_array_of(&field.type_ref), &field_binding_name(¶m.name, &field.name))); + } + } else { + params.push(param.clone()); + } +} + +fn append_desugared_state_params<'i>( + params: &mut Vec>, + ctx: &mut CovenantStateRewriteContext, + policy_params: &[crate::ast::ParamAst<'i>], + contract_fields: &[ContractFieldAst<'i>], +) { + for param in policy_params { + append_desugared_state_param(params, ctx, param, contract_fields); + } +} + +fn ordered_state_fields<'i>(expr: &Expr<'i>, contract_fields: &[ContractFieldAst<'i>]) -> Result>, CompilerError> { + let ExprKind::StateObject(entries) = &expr.kind else { + return Err(CompilerError::Unsupported("expected a State expression".to_string())); + }; + + let mut by_name = HashMap::new(); + for entry in entries { + if by_name.insert(entry.name.as_str(), entry.expr.clone()).is_some() { + return Err(CompilerError::Unsupported(format!("duplicate state field '{}'", entry.name))); + } + } + + let mut ordered = Vec::with_capacity(contract_fields.len()); + for field in contract_fields { + let expr = by_name + .remove(field.name.as_str()) + .ok_or_else(|| CompilerError::Unsupported(format!("missing state field '{}'", field.name)))?; + ordered.push(expr); + } + if let Some(extra) = by_name.keys().next() { + return Err(CompilerError::Unsupported(format!("unknown state field '{}'", extra))); + } + Ok(ordered) +} + +fn rewrite_state_expr_to_object<'i>( + expr: &Expr<'i>, + ctx: &CovenantStateRewriteContext, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + match &expr.kind { + ExprKind::Identifier(name) => { + if let Some(bindings) = ctx.single_states.get(name) { + let by_field = bindings.iter().cloned().collect::>(); + return Ok(state_object_expr_from_field_bindings(contract_fields, &by_field)); + } + } + ExprKind::ArrayIndex { source, index } => { + if let ExprKind::Identifier(name) = &source.kind { + if let Some(bindings) = ctx.state_arrays.get(name) { + return Ok(state_object_expr_from_field_arrays_at_index( + contract_fields, + bindings, + rewrite_covenant_policy_expr(index, ctx, contract_fields)?, + )); + } + } + } + _ => {} + } + + rewrite_covenant_policy_expr(expr, ctx, contract_fields) +} + +fn expand_state_expr<'i>( + expr: &Expr<'i>, + ctx: &CovenantStateRewriteContext, + contract_fields: &[ContractFieldAst<'i>], +) -> Result>, CompilerError> { + match &expr.kind { + ExprKind::Identifier(name) => { + if let Some(bindings) = ctx.single_states.get(name) { + let by_field = bindings.iter().cloned().collect::>(); + return Ok(contract_fields + .iter() + .map(|field| { + let binding = by_field + .get(&field.name) + .cloned() + .unwrap_or_else(|| panic!("missing state binding for field '{}'", field.name)); + identifier_expr(&binding) + }) + .collect()); + } + } + ExprKind::ArrayIndex { source, index } => { + if let ExprKind::Identifier(name) = &source.kind { + if let Some(bindings) = ctx.state_arrays.get(name) { + let index_expr = rewrite_covenant_policy_expr(index, ctx, contract_fields)?; + return Ok(contract_fields + .iter() + .map(|field| { + let array_name = bindings + .iter() + .find(|(field_name, _)| field_name == &field.name) + .map(|(_, binding_name)| binding_name.clone()) + .unwrap_or_else(|| panic!("missing state array binding for field '{}'", field.name)); + Expr::new( + ExprKind::ArrayIndex { + source: Box::new(identifier_expr(&array_name)), + index: Box::new(index_expr.clone()), + }, + expr.span, + ) + }) + .collect()); + } + } + } + _ => {} + } + + let rewritten = rewrite_state_expr_to_object(expr, ctx, contract_fields)?; + ordered_state_fields(&rewritten, contract_fields) +} + +fn expand_state_array_expr<'i>( + expr: &Expr<'i>, + ctx: &CovenantStateRewriteContext, + contract_fields: &[ContractFieldAst<'i>], +) -> Result>, CompilerError> { + let ExprKind::Identifier(name) = &expr.kind else { + return Err(CompilerError::Unsupported("State[] covenant returns currently must be a named State[] value".to_string())); + }; + + let Some(bindings) = ctx.state_arrays.get(name) else { + return Err(CompilerError::Unsupported("State[] covenant returns currently must refer to a State[] parameter".to_string())); + }; + + Ok(contract_fields + .iter() + .map(|field| { + let array_name = bindings + .iter() + .find(|(field_name, _)| field_name == &field.name) + .map(|(_, binding_name)| binding_name.clone()) + .unwrap_or_else(|| panic!("missing state array binding for field '{}'", field.name)); + identifier_expr(&array_name) + }) + .collect()) +} + +fn rewrite_covenant_policy_expr<'i>( + expr: &Expr<'i>, + ctx: &CovenantStateRewriteContext, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + match &expr.kind { + ExprKind::FieldAccess { source, field, field_span } => { + if let ExprKind::Identifier(name) = &source.kind { + if let Some(bindings) = ctx.single_states.get(name) { + let binding_name = bindings + .iter() + .find(|(field_name, _)| field_name == field) + .map(|(_, binding_name)| binding_name.clone()) + .ok_or_else(|| CompilerError::Unsupported(format!("State has no field '{}'", field)))?; + return Ok(Expr::new(ExprKind::Identifier(binding_name), expr.span)); + } + } + + if let ExprKind::ArrayIndex { source: array_source, index } = &source.kind { + if let ExprKind::Identifier(name) = &array_source.kind { + if let Some(bindings) = ctx.state_arrays.get(name) { + let array_name = bindings + .iter() + .find(|(field_name, _)| field_name == field) + .map(|(_, binding_name)| binding_name.clone()) + .ok_or_else(|| CompilerError::Unsupported(format!("State has no field '{}'", field)))?; + return Ok(Expr::new( + ExprKind::ArrayIndex { + source: Box::new(identifier_expr(&array_name)), + index: Box::new(rewrite_covenant_policy_expr(index, ctx, contract_fields)?), + }, + expr.span, + )); + } + } + } + + Ok(Expr::new( + ExprKind::FieldAccess { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + field: field.clone(), + field_span: *field_span, + }, + expr.span, + )) + } + ExprKind::ArrayIndex { source, index } => { + if let ExprKind::Identifier(name) = &source.kind { + if ctx.state_arrays.contains_key(name) { + return Ok(state_object_expr_from_field_arrays_at_index( + contract_fields, + ctx.state_arrays.get(name).expect("state array bindings exist"), + rewrite_covenant_policy_expr(index, ctx, contract_fields)?, + )); + } + } + + Ok(Expr::new( + ExprKind::ArrayIndex { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + index: Box::new(rewrite_covenant_policy_expr(index, ctx, contract_fields)?), + }, + expr.span, + )) + } + ExprKind::Identifier(name) => { + if let Some(bindings) = ctx.single_states.get(name) { + let by_field = bindings.iter().cloned().collect::>(); + return Ok(state_object_expr_from_field_bindings(contract_fields, &by_field)); + } + Ok(expr.clone()) + } + ExprKind::UnarySuffix { source, kind: UnarySuffixKind::Length, span } => { + if let ExprKind::Identifier(name) = &source.kind { + if let Some(bindings) = ctx.state_arrays.get(name) { + let first_field_array = bindings + .first() + .map(|(_, binding_name)| binding_name.clone()) + .ok_or_else(|| CompilerError::Unsupported("State[] requires at least one contract field".to_string()))?; + return Ok(length_expr(identifier_expr(&first_field_array))); + } + } + Ok(Expr::new( + ExprKind::UnarySuffix { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + kind: UnarySuffixKind::Length, + span: *span, + }, + expr.span, + )) + } + ExprKind::Unary { op, expr: inner } => Ok(Expr::new( + ExprKind::Unary { op: *op, expr: Box::new(rewrite_covenant_policy_expr(inner, ctx, contract_fields)?) }, + expr.span, + )), + ExprKind::Binary { op, left, right } => Ok(Expr::new( + ExprKind::Binary { + op: *op, + left: Box::new(rewrite_covenant_policy_expr(left, ctx, contract_fields)?), + right: Box::new(rewrite_covenant_policy_expr(right, ctx, contract_fields)?), + }, + expr.span, + )), + ExprKind::IfElse { condition, then_expr, else_expr } => Ok(Expr::new( + ExprKind::IfElse { + condition: Box::new(rewrite_covenant_policy_expr(condition, ctx, contract_fields)?), + then_expr: Box::new(rewrite_covenant_policy_expr(then_expr, ctx, contract_fields)?), + else_expr: Box::new(rewrite_covenant_policy_expr(else_expr, ctx, contract_fields)?), + }, + expr.span, + )), + ExprKind::Array(values) => Ok(Expr::new( + ExprKind::Array( + values.iter().map(|value| rewrite_covenant_policy_expr(value, ctx, contract_fields)).collect::, _>>()?, + ), + expr.span, + )), + ExprKind::StateObject(fields) => Ok(Expr::new( + ExprKind::StateObject( + fields + .iter() + .map(|field| { + Ok(StateFieldExpr { + name: field.name.clone(), + expr: rewrite_covenant_policy_expr(&field.expr, ctx, contract_fields)?, + span: field.span, + name_span: field.name_span, + }) + }) + .collect::, CompilerError>>()?, + ), + expr.span, + )), + ExprKind::Call { name, args, name_span } => Ok(Expr::new( + ExprKind::Call { + name: name.clone(), + args: args.iter().map(|arg| rewrite_covenant_policy_expr(arg, ctx, contract_fields)).collect::, _>>()?, + name_span: *name_span, + }, + expr.span, + )), + ExprKind::New { name, args, name_span } => Ok(Expr::new( + ExprKind::New { + name: name.clone(), + args: args.iter().map(|arg| rewrite_covenant_policy_expr(arg, ctx, contract_fields)).collect::, _>>()?, + name_span: *name_span, + }, + expr.span, + )), + ExprKind::Split { source, index, part, span } => Ok(Expr::new( + ExprKind::Split { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + index: Box::new(rewrite_covenant_policy_expr(index, ctx, contract_fields)?), + part: *part, + span: *span, + }, + expr.span, + )), + ExprKind::Slice { source, start, end, span } => Ok(Expr::new( + ExprKind::Slice { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + start: Box::new(rewrite_covenant_policy_expr(start, ctx, contract_fields)?), + end: Box::new(rewrite_covenant_policy_expr(end, ctx, contract_fields)?), + span: *span, + }, + expr.span, + )), + ExprKind::Introspection { kind, index, field_span } => Ok(Expr::new( + ExprKind::Introspection { + kind: *kind, + index: Box::new(rewrite_covenant_policy_expr(index, ctx, contract_fields)?), + field_span: *field_span, + }, + expr.span, + )), + ExprKind::UnarySuffix { source, kind, span } => Ok(Expr::new( + ExprKind::UnarySuffix { + source: Box::new(rewrite_covenant_policy_expr(source, ctx, contract_fields)?), + kind: *kind, + span: *span, + }, + expr.span, + )), + _ => Ok(expr.clone()), + } +} + +fn rewrite_covenant_policy_statement<'i>( + stmt: &Statement<'i>, + ctx: &CovenantStateRewriteContext, + contract_fields: &[ContractFieldAst<'i>], + return_desugaring: CovenantReturnDesugaring, +) -> Result, CompilerError> { + Ok(match stmt { + Statement::VariableDefinition { type_ref, modifiers, name, expr, span, type_span, modifier_spans, name_span } => { + Statement::VariableDefinition { + type_ref: type_ref.clone(), + modifiers: modifiers.clone(), + name: name.clone(), + expr: expr.as_ref().map(|expr| rewrite_covenant_policy_expr(expr, ctx, contract_fields)).transpose()?, + span: *span, + type_span: *type_span, + modifier_spans: modifier_spans.clone(), + name_span: *name_span, + } + } + Statement::TupleAssignment { + left_type_ref, + left_name, + right_type_ref, + right_name, + expr, + span, + left_type_span, + left_name_span, + right_type_span, + right_name_span, + } => Statement::TupleAssignment { + left_type_ref: left_type_ref.clone(), + left_name: left_name.clone(), + right_type_ref: right_type_ref.clone(), + right_name: right_name.clone(), + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + span: *span, + left_type_span: *left_type_span, + left_name_span: *left_name_span, + right_type_span: *right_type_span, + right_name_span: *right_name_span, + }, + Statement::ArrayPush { name, expr, span, name_span } => Statement::ArrayPush { + name: name.clone(), + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + span: *span, + name_span: *name_span, + }, + Statement::FunctionCall { name, args, span, name_span } => Statement::FunctionCall { + name: name.clone(), + args: args.iter().map(|arg| rewrite_covenant_policy_expr(arg, ctx, contract_fields)).collect::, _>>()?, + span: *span, + name_span: *name_span, + }, + Statement::FunctionCallAssign { bindings, name, args, span, name_span } => Statement::FunctionCallAssign { + bindings: bindings.clone(), + name: name.clone(), + args: args.iter().map(|arg| rewrite_covenant_policy_expr(arg, ctx, contract_fields)).collect::, _>>()?, + span: *span, + name_span: *name_span, + }, + Statement::StateFunctionCallAssign { bindings, name, args, span, name_span } => Statement::StateFunctionCallAssign { + bindings: bindings.clone(), + name: name.clone(), + args: args.iter().map(|arg| rewrite_covenant_policy_expr(arg, ctx, contract_fields)).collect::, _>>()?, + span: *span, + name_span: *name_span, + }, + Statement::StructDestructure { bindings, expr, span } => Statement::StructDestructure { + bindings: bindings.clone(), + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + span: *span, + }, + Statement::Assign { name, expr, span, name_span } => Statement::Assign { + name: name.clone(), + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + span: *span, + name_span: *name_span, + }, + Statement::TimeOp { tx_var, expr, message, span, tx_var_span, message_span } => Statement::TimeOp { + tx_var: *tx_var, + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + message: message.clone(), + span: *span, + tx_var_span: *tx_var_span, + message_span: *message_span, + }, + Statement::Require { expr, message, span, message_span } => Statement::Require { + expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, + message: message.clone(), + span: *span, + message_span: *message_span, + }, + Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => Statement::If { + condition: rewrite_covenant_policy_expr(condition, ctx, contract_fields)?, + then_branch: then_branch + .iter() + .map(|stmt| rewrite_covenant_policy_statement(stmt, ctx, contract_fields, return_desugaring)) + .collect::, _>>()?, + else_branch: else_branch + .as_ref() + .map(|branch| { + branch + .iter() + .map(|stmt| rewrite_covenant_policy_statement(stmt, ctx, contract_fields, return_desugaring)) + .collect::, CompilerError>>() + }) + .transpose()?, + span: *span, + then_span: *then_span, + else_span: *else_span, + }, + Statement::For { ident, start, end, body, span, ident_span, body_span } => Statement::For { + ident: ident.clone(), + start: rewrite_covenant_policy_expr(start, ctx, contract_fields)?, + end: rewrite_covenant_policy_expr(end, ctx, contract_fields)?, + body: body + .iter() + .map(|stmt| rewrite_covenant_policy_statement(stmt, ctx, contract_fields, return_desugaring)) + .collect::, _>>()?, + span: *span, + ident_span: *ident_span, + body_span: *body_span, + }, + Statement::Yield { expr, span } => { + Statement::Yield { expr: rewrite_covenant_policy_expr(expr, ctx, contract_fields)?, span: *span } + } + Statement::Return { exprs, span } => { + let rewritten_exprs = match return_desugaring { + CovenantReturnDesugaring::Existing => { + exprs.iter().map(|expr| rewrite_covenant_policy_expr(expr, ctx, contract_fields)).collect::, _>>()? + } + CovenantReturnDesugaring::SingleState => { + if exprs.len() != 1 { + return Err(CompilerError::Unsupported( + "State covenant returns must return exactly one State value".to_string(), + )); + } + expand_state_expr(&exprs[0], ctx, contract_fields)? + } + CovenantReturnDesugaring::StateArray => { + if exprs.len() != 1 { + return Err(CompilerError::Unsupported( + "State[] covenant returns must return exactly one State[] value".to_string(), + )); + } + expand_state_array_expr(&exprs[0], ctx, contract_fields)? + } + }; + Statement::Return { exprs: rewritten_exprs, span: *span } + } + Statement::Console { args, span } => Statement::Console { + args: args + .iter() + .map(|arg| match arg { + crate::ast::ConsoleArg::Identifier(name, ident_span) => { + Ok(crate::ast::ConsoleArg::Identifier(name.clone(), *ident_span)) + } + crate::ast::ConsoleArg::Literal(expr) => { + Ok(crate::ast::ConsoleArg::Literal(rewrite_covenant_policy_expr(expr, ctx, contract_fields)?)) + } + }) + .collect::, CompilerError>>()?, + span: *span, + }, + }) +} + +fn desugar_covenant_policy_state_syntax<'i>( + policy: &FunctionAst<'i>, + declaration: &CovenantDeclaration<'i>, + contract_fields: &[ContractFieldAst<'i>], +) -> Result, CompilerError> { + if contract_fields.is_empty() { + return Ok(policy.clone()); + } + + let mut ctx = CovenantStateRewriteContext::default(); + let mut params = Vec::new(); + + match (declaration.binding, declaration.mode) { + (CovenantBinding::Auth, CovenantMode::Verification) => { + if policy.params.len() < 2 + || !is_state_type_ref(&policy.params[0].type_ref) + || !is_state_array_type_ref(&policy.params[1].type_ref) + { + return Err(CompilerError::Unsupported(format!( + "mode=verification with binding=auth on function '{}' expects parameters '(State prev_state, State[] new_states, ...)'", + policy.name + ))); + } + + let prev_name = policy.params[0].name.clone(); + let new_name = policy.params[1].name.clone(); + let prev_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&prev_name, &field.name))) + .collect::>(); + let new_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&new_name, &field.name))) + .collect::>(); + ctx.single_states.insert(prev_name.clone(), prev_bindings.clone()); + ctx.state_arrays.insert(new_name.clone(), new_bindings.clone()); + + for field in contract_fields { + params.push(typed_binding(field.type_ref.clone(), &field_binding_name(&prev_name, &field.name))); + } + for field in contract_fields { + params.push(typed_binding(dynamic_array_of(&field.type_ref), &field_binding_name(&new_name, &field.name))); + } + append_desugared_state_params(&mut params, &mut ctx, &policy.params[2..], contract_fields); + } + (CovenantBinding::Cov, CovenantMode::Verification) => { + if policy.params.len() < 2 + || !is_state_array_type_ref(&policy.params[0].type_ref) + || !is_state_array_type_ref(&policy.params[1].type_ref) + { + return Err(CompilerError::Unsupported(format!( + "mode=verification with binding=cov on function '{}' expects parameters '(State[] prev_states, State[] new_states, ...)'", + policy.name + ))); + } + + let prev_name = policy.params[0].name.clone(); + let new_name = policy.params[1].name.clone(); + let prev_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&prev_name, &field.name))) + .collect::>(); + let new_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&new_name, &field.name))) + .collect::>(); + ctx.state_arrays.insert(prev_name.clone(), prev_bindings.clone()); + ctx.state_arrays.insert(new_name.clone(), new_bindings.clone()); + + for field in contract_fields { + params.push(typed_binding(dynamic_array_of(&field.type_ref), &field_binding_name(&prev_name, &field.name))); + } + for field in contract_fields { + params.push(typed_binding(dynamic_array_of(&field.type_ref), &field_binding_name(&new_name, &field.name))); + } + append_desugared_state_params(&mut params, &mut ctx, &policy.params[2..], contract_fields); + } + (CovenantBinding::Auth, CovenantMode::Transition) => { + if policy.params.is_empty() || !is_state_type_ref(&policy.params[0].type_ref) { + return Err(CompilerError::Unsupported(format!( + "mode=transition with binding=auth on function '{}' expects parameters '(State prev_state, ...)'", + policy.name + ))); + } + + let prev_name = policy.params[0].name.clone(); + let prev_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&prev_name, &field.name))) + .collect::>(); + ctx.single_states.insert(prev_name.clone(), prev_bindings.clone()); + + for field in contract_fields { + params.push(typed_binding(field.type_ref.clone(), &field_binding_name(&prev_name, &field.name))); + } + append_desugared_state_params(&mut params, &mut ctx, &policy.params[1..], contract_fields); + } + (CovenantBinding::Cov, CovenantMode::Transition) => { + if policy.params.is_empty() || !is_state_array_type_ref(&policy.params[0].type_ref) { + return Err(CompilerError::Unsupported(format!( + "mode=transition with binding=cov on function '{}' expects parameters '(State[] prev_states, ...)'", + policy.name + ))); + } + + let prev_name = policy.params[0].name.clone(); + let prev_bindings = contract_fields + .iter() + .map(|field| (field.name.clone(), field_binding_name(&prev_name, &field.name))) + .collect::>(); + ctx.state_arrays.insert(prev_name.clone(), prev_bindings.clone()); + + for field in contract_fields { + params.push(typed_binding(dynamic_array_of(&field.type_ref), &field_binding_name(&prev_name, &field.name))); + } + append_desugared_state_params(&mut params, &mut ctx, &policy.params[1..], contract_fields); + } + } + + let (return_types, return_desugaring) = match declaration.mode { + CovenantMode::Verification => (policy.return_types.clone(), CovenantReturnDesugaring::Existing), + CovenantMode::Transition => { + if policy.return_types.len() != 1 { + return Err(CompilerError::Unsupported(format!( + "mode=transition on function '{}' with contract state expects exactly one return type: 'State' or 'State[]'", + policy.name + ))); + } + + if is_state_type_ref(&policy.return_types[0]) { + (contract_fields.iter().map(|field| field.type_ref.clone()).collect(), CovenantReturnDesugaring::SingleState) + } else if is_state_array_type_ref(&policy.return_types[0]) { + (contract_fields.iter().map(|field| dynamic_array_of(&field.type_ref)).collect(), CovenantReturnDesugaring::StateArray) + } else { + return Err(CompilerError::Unsupported(format!( + "mode=transition on function '{}' with contract state expects return type 'State' or 'State[]'", + policy.name + ))); + } + } + }; + + let return_type_spans = match return_desugaring { + CovenantReturnDesugaring::Existing => policy.return_type_spans.clone(), + CovenantReturnDesugaring::SingleState | CovenantReturnDesugaring::StateArray => { + if let Some(span) = policy.return_type_spans.first().copied() { vec![span; contract_fields.len()] } else { Vec::new() } + } + }; + + let body = policy + .body + .iter() + .map(|stmt| rewrite_covenant_policy_statement(stmt, &ctx, contract_fields, return_desugaring)) + .collect::, _>>()?; + + Ok(FunctionAst { + name: policy.name.clone(), + attributes: policy.attributes.clone(), + params, + entrypoint: policy.entrypoint, + return_types, + body, + return_type_spans, + span: policy.span, + name_span: policy.name_span, + body_span: policy.body_span, + }) +} + +fn parse_verification_shape<'i>( + policy: &FunctionAst<'i>, + contract_fields: &[ContractFieldAst<'i>], + binding: CovenantBinding, +) -> Result, CompilerError> { + let field_count = contract_fields.len(); + let required = field_count * 2; + let binding_name = match binding { + CovenantBinding::Auth => "auth", + CovenantBinding::Cov => "cov", + }; + let prev_label = match binding { + CovenantBinding::Auth => "params", + CovenantBinding::Cov => "arrays", + }; + let new_label = match binding { + CovenantBinding::Auth => "array params", + CovenantBinding::Cov => "arrays", + }; + if policy.params.len() < required { + return Err(CompilerError::Unsupported(format!( + "mode=verification with binding={} on function '{}' requires {} prev-state {} + {} new-state {} (one per contract field)", + binding_name, policy.name, field_count, prev_label, field_count, new_label + ))); + } + + let mut prev_field_values = Vec::with_capacity(field_count); + let mut new_field_arrays = Vec::with_capacity(field_count); + for (idx, field) in contract_fields.iter().enumerate() { + let prev_expected = match binding { + CovenantBinding::Auth => field.type_ref.clone(), + CovenantBinding::Cov => dynamic_array_of(&field.type_ref), + }; + let prev_param = &policy.params[idx]; + if prev_param.type_ref != prev_expected { + return Err(CompilerError::Unsupported(format!( + "mode=verification with binding={} on function '{}' expects prev-state param '{}' to be '{}', got '{}'", + binding_name, + policy.name, + prev_param.name, + type_name_from_ref(&prev_expected), + type_name_from_ref(&prev_param.type_ref) + ))); + } + prev_field_values.push((field.name.clone(), prev_param.name.clone())); + + let new_expected = dynamic_array_of(&field.type_ref); + let new_param = &policy.params[field_count + idx]; + if new_param.type_ref != new_expected { + return Err(CompilerError::Unsupported(format!( + "mode=verification with binding={} on function '{}' expects new-state param '{}' to be '{}', got '{}'", + binding_name, + policy.name, + new_param.name, + type_name_from_ref(&new_expected), + type_name_from_ref(&new_param.type_ref) + ))); + } + new_field_arrays.push((field.name.clone(), new_param.name.clone())); + } + + let entrypoint_params = policy.params[field_count..].to_vec(); + let call_args = match binding { + CovenantBinding::Auth => { + let mut args = Vec::with_capacity(policy.params.len()); + for field in contract_fields { + args.push(identifier_expr(&field.name)); + } + for param in &entrypoint_params { + args.push(identifier_expr(¶m.name)); + } + args + } + CovenantBinding::Cov => policy.params.iter().map(|param| identifier_expr(¶m.name)).collect(), + }; + + Ok(VerificationShape { prev_field_values, new_field_arrays, entrypoint_params, call_args }) +} + +fn parse_transition_shape<'i>( + policy: &FunctionAst<'i>, + contract_fields: &[ContractFieldAst<'i>], + binding: CovenantBinding, +) -> Result, CompilerError> { + let field_count = contract_fields.len(); + let binding_name = match binding { + CovenantBinding::Auth => "auth", + CovenantBinding::Cov => "cov", + }; + let prev_label = match binding { + CovenantBinding::Auth => "params", + CovenantBinding::Cov => "arrays", + }; + if policy.params.len() < field_count { + return Err(CompilerError::Unsupported(format!( + "mode=transition with binding={} on function '{}' requires {} prev-state {} (one per contract field) before call args", + binding_name, policy.name, field_count, prev_label + ))); + } + + for (idx, field) in contract_fields.iter().enumerate() { + let prev_expected = match binding { + CovenantBinding::Auth => field.type_ref.clone(), + CovenantBinding::Cov => dynamic_array_of(&field.type_ref), + }; + let prev_param = &policy.params[idx]; + if prev_param.type_ref != prev_expected { + return Err(CompilerError::Unsupported(format!( + "mode=transition with binding={} on function '{}' expects prev-state param '{}' to be '{}', got '{}'", + binding_name, + policy.name, + prev_param.name, + type_name_from_ref(&prev_expected), + type_name_from_ref(&prev_param.type_ref) + ))); + } + } + + match binding { + CovenantBinding::Auth => { + let entrypoint_params = policy.params[field_count..].to_vec(); + let mut call_args = Vec::with_capacity(policy.params.len()); + for field in contract_fields { + call_args.push(identifier_expr(&field.name)); + } + for param in &entrypoint_params { + call_args.push(identifier_expr(¶m.name)); + } + Ok(TransitionShape { entrypoint_params, call_args }) + } + CovenantBinding::Cov => Ok(TransitionShape { + entrypoint_params: policy.params.clone(), + call_args: policy.params.iter().map(|param| identifier_expr(¶m.name)).collect(), + }), + } +} + +fn append_policy_call_and_capture_next_state<'i>( + body: &mut Vec>, + policy: &FunctionAst<'i>, + policy_name: &str, + mode: CovenantMode, + singleton: bool, + termination: CovenantTermination, + contract_fields: &[ContractFieldAst<'i>], + call_args: Vec>, +) -> Result, CompilerError> { + match mode { + CovenantMode::Verification => { + body.push(call_statement(policy_name, call_args)); + Ok(OutputStateSource::Single(state_object_expr_from_contract_fields(contract_fields))) + } + CovenantMode::Transition => { + if policy.return_types.len() != contract_fields.len() { + return Err(CompilerError::Unsupported(format!( + "transition mode policy function '{}' must return exactly {} values (one per contract field)", + policy.name, + contract_fields.len() + ))); + } + + let mut shape_is_single = true; + let mut shape_is_per_output_arrays = true; + for (field, return_type) in contract_fields.iter().zip(policy.return_types.iter()) { + shape_is_single &= type_name_from_ref(return_type) == type_name_from_ref(&field.type_ref); + shape_is_per_output_arrays &= return_type_is_per_output_array(return_type, &field.type_ref); + } + if !shape_is_single && !shape_is_per_output_arrays { + return Err(CompilerError::Unsupported(format!( + "transition mode policy function '{}' returns must be either exactly State fields or per-field arrays", + policy.name + ))); + } + if singleton && shape_is_per_output_arrays && termination != CovenantTermination::Allowed { + return Err(CompilerError::Unsupported(format!( + "transition mode singleton policy function '{}' must return a single State (arrays are not allowed unless termination=allowed)", + policy.name + ))); + } + + let mut bindings = Vec::new(); + let mut binding_by_field = HashMap::new(); + for (field, return_type) in contract_fields.iter().zip(policy.return_types.iter()) { + let binding_name = format!("__cov_new_{}", field.name); + bindings.push(typed_binding(return_type.clone(), &binding_name)); + binding_by_field.insert(field.name.clone(), binding_name); + } + + body.push(function_call_assign_statement(bindings, policy_name, call_args)); + if shape_is_single { + Ok(OutputStateSource::Single(state_object_expr_from_field_bindings(contract_fields, &binding_by_field))) + } else { + let first_field = &contract_fields[0].name; + let first_array_name = binding_by_field + .get(first_field) + .cloned() + .unwrap_or_else(|| panic!("missing transition binding for field '{}'", first_field)); + let expected_len_expr = length_expr(identifier_expr(&first_array_name)); + for field in contract_fields.iter().skip(1) { + let array_name = binding_by_field + .get(&field.name) + .cloned() + .unwrap_or_else(|| panic!("missing transition binding for field '{}'", field.name)); + body.push(require_statement(binary_expr( + BinaryOp::Eq, + length_expr(identifier_expr(&array_name)), + expected_len_expr.clone(), + ))); + } + + let field_arrays = contract_fields + .iter() + .map(|field| { + let name = binding_by_field + .get(&field.name) + .cloned() + .unwrap_or_else(|| panic!("missing transition binding for field '{}'", field.name)); + (field.name.clone(), name) + }) + .collect(); + Ok(OutputStateSource::PerOutputArrays { field_arrays, length_expr: expected_len_expr }) + } + } + } +} + +fn append_auth_output_state_checks<'i>( + body: &mut Vec>, + active_input: &Expr<'i>, + out_count_name: &str, + to_expr: Expr<'i>, + next_state_expr: Expr<'i>, +) { + let loop_var = "__cov_k"; + let out_idx_name = "__cov_out_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(out_count_name)); + let then_branch = vec![ + var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpAuthOutputIdx", vec![active_input.clone(), identifier_expr(loop_var)]), + ), + call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr]), + ]; + body.push(for_statement(loop_var, Expr::int(0), to_expr, vec![if_statement(cond, then_branch)])); +} + +fn append_cov_input_state_reads<'i>( + body: &mut Vec>, + cov_id_name: &str, + in_count_name: &str, + from_expr: Expr<'i>, + contract_fields: &[ContractFieldAst<'i>], +) { + if contract_fields.is_empty() { + return; + } + let loop_var = "__cov_in_k"; + let in_idx_name = "__cov_in_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(in_count_name)); + let mut then_branch = Vec::new(); + then_branch.push(var_def_statement( + int_type_ref(), + in_idx_name, + Expr::call("OpCovInputIdx", vec![identifier_expr(cov_id_name), identifier_expr(loop_var)]), + )); + let bindings = contract_fields + .iter() + .map(|field| state_binding(&field.name, field.type_ref.clone(), &format!("__cov_prev_{}", field.name))) + .collect(); + then_branch.push(state_call_assign_statement(bindings, "readInputState", vec![identifier_expr(in_idx_name)])); + body.push(for_statement(loop_var, Expr::int(0), from_expr, vec![if_statement(cond, then_branch)])); +} + +fn append_cov_input_state_reads_into_policy_prev_arrays<'i>( + body: &mut Vec>, + cov_id_name: &str, + in_count_name: &str, + from_expr: Expr<'i>, + contract_fields: &[ContractFieldAst<'i>], + prev_field_arrays: &[(String, String)], +) { + if contract_fields.is_empty() { + return; + } + let prev_by_field: HashMap<_, _> = prev_field_arrays.iter().cloned().collect(); + for field in contract_fields { + let array_name = + prev_by_field.get(&field.name).unwrap_or_else(|| panic!("missing prev-state array param for field '{}'", field.name)); + body.push(var_decl_statement(dynamic_array_of(&field.type_ref), array_name)); + } + + let loop_var = "__cov_in_k"; + let in_idx_name = "__cov_in_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(in_count_name)); + let mut then_branch = Vec::new(); + then_branch.push(var_def_statement( + int_type_ref(), + in_idx_name, + Expr::call("OpCovInputIdx", vec![identifier_expr(cov_id_name), identifier_expr(loop_var)]), + )); + let bindings = contract_fields + .iter() + .map(|field| state_binding(&field.name, field.type_ref.clone(), &format!("__cov_prev_{}", field.name))) + .collect(); + then_branch.push(state_call_assign_statement(bindings, "readInputState", vec![identifier_expr(in_idx_name)])); + for field in contract_fields { + let array_name = + prev_by_field.get(&field.name).unwrap_or_else(|| panic!("missing prev-state array param for field '{}'", field.name)); + then_branch.push(array_push_statement(array_name, identifier_expr(&format!("__cov_prev_{}", field.name)))); + } + body.push(for_statement(loop_var, Expr::int(0), from_expr, vec![if_statement(cond, then_branch)])); +} + +fn append_cov_output_state_checks<'i>( + body: &mut Vec>, + cov_id_name: &str, + out_count_name: &str, + to_expr: Expr<'i>, + next_state_expr: Expr<'i>, +) { + let loop_var = "__cov_k"; + let out_idx_name = "__cov_out_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(out_count_name)); + let then_branch = vec![ + var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpCovOutputIdx", vec![identifier_expr(cov_id_name), identifier_expr(loop_var)]), + ), + call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr]), + ]; + body.push(for_statement(loop_var, Expr::int(0), to_expr, vec![if_statement(cond, then_branch)])); +} + +fn append_auth_output_array_state_checks<'i>( + body: &mut Vec>, + active_input: &Expr<'i>, + out_count_name: &str, + to_expr: Expr<'i>, + field_arrays: Vec<(String, String)>, + contract_fields: &[ContractFieldAst<'i>], +) { + let loop_var = "__cov_k"; + let out_idx_name = "__cov_out_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(out_count_name)); + let mut then_branch = Vec::new(); + then_branch.push(var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpAuthOutputIdx", vec![active_input.clone(), identifier_expr(loop_var)]), + )); + let next_state_expr = state_object_expr_from_field_arrays_at_index(contract_fields, &field_arrays, identifier_expr(loop_var)); + then_branch.push(call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr])); + body.push(for_statement(loop_var, Expr::int(0), to_expr, vec![if_statement(cond, then_branch)])); +} + +fn append_cov_output_array_state_checks<'i>( + body: &mut Vec>, + cov_id_name: &str, + out_count_name: &str, + to_expr: Expr<'i>, + field_arrays: Vec<(String, String)>, + contract_fields: &[ContractFieldAst<'i>], +) { + let loop_var = "__cov_k"; + let out_idx_name = "__cov_out_idx"; + let cond = binary_expr(BinaryOp::Lt, identifier_expr(loop_var), identifier_expr(out_count_name)); + let mut then_branch = Vec::new(); + then_branch.push(var_def_statement( + int_type_ref(), + out_idx_name, + Expr::call("OpCovOutputIdx", vec![identifier_expr(cov_id_name), identifier_expr(loop_var)]), + )); + let next_state_expr = state_object_expr_from_field_arrays_at_index(contract_fields, &field_arrays, identifier_expr(loop_var)); + then_branch.push(call_statement("validateOutputState", vec![identifier_expr(out_idx_name), next_state_expr])); + body.push(for_statement(loop_var, Expr::int(0), to_expr, vec![if_statement(cond, then_branch)])); +} diff --git a/silverscript-lang/src/silverscript.pest b/silverscript-lang/src/silverscript.pest index efc0a62..514da2f 100644 --- a/silverscript-lang/src/silverscript.pest +++ b/silverscript-lang/src/silverscript.pest @@ -11,7 +11,11 @@ contract_item = { struct_definition | constant_definition | contract_field_defin struct_definition = { "struct" ~ Identifier ~ "{" ~ struct_field_definition* ~ "}" } struct_field_definition = { type_name ~ Identifier ~ ";" } entrypoint = { "entrypoint" } -function_definition = { entrypoint? ~ "function" ~ Identifier ~ parameter_list ~ return_type_list? ~ "{" ~ statement* ~ "}" } +function_definition = { function_attribute* ~ entrypoint? ~ "function" ~ Identifier ~ parameter_list ~ return_type_list? ~ "{" ~ statement* ~ "}" } +function_attribute = { "#[" ~ attribute_path ~ attribute_args? ~ "]" } +attribute_path = { Identifier ~ ("." ~ Identifier)* } +attribute_args = { "(" ~ (attribute_arg ~ ("," ~ attribute_arg)* ~ ","?)? ~ ")" } +attribute_arg = { Identifier ~ "=" ~ expression } constant_definition = { type_name ~ "constant" ~ Identifier ~ "=" ~ expression ~ ";" } contract_field_definition = { type_name ~ Identifier ~ "=" ~ expression ~ ";" } diff --git a/silverscript-lang/tests/ast_spans_tests.rs b/silverscript-lang/tests/ast_spans_tests.rs index 0dca9fa..4b2880d 100644 --- a/silverscript-lang/tests/ast_spans_tests.rs +++ b/silverscript-lang/tests/ast_spans_tests.rs @@ -58,3 +58,58 @@ fn populates_slice_expression_spans() { assert_span_text(source, start.span.as_str(), "1"); assert_span_text(source, end.span.as_str(), "3"); } + +#[test] +fn parses_function_attributes_and_for_ast() { + let source = r#" + contract Decls(int max_outs) { + #[covenant(binding = cov, from = 2, to = max_outs, mode = verification)] + function policy() { + int dyn = tx.outputs.length; + for(i, 0, dyn) { + require(i >= 0); + } + } + } + "#; + + let contract = parse_contract_ast(source).expect("contract should parse"); + let function = &contract.functions[0]; + assert_eq!(function.attributes.len(), 1); + + let attribute = &function.attributes[0]; + assert_eq!(attribute.path, vec!["covenant"]); + assert_eq!(attribute.args.len(), 4); + assert_eq!(attribute.args[0].name, "binding"); + assert_eq!(attribute.args[1].name, "from"); + assert_eq!(attribute.args[2].name, "to"); + assert_eq!(attribute.args[3].name, "mode"); + assert_span_text(source, attribute.path_spans[0].as_str(), "covenant"); +} + +#[test] +fn parses_multiple_and_noarg_function_attributes() { + let source = r#" + contract Attrs(int max_outs) { + #[covenant(binding = auth, from = 1, to = max_outs + 1, mode = verification)] + #[experimental] + function policy() { + require(true); + } + } + "#; + + let contract = parse_contract_ast(source).expect("contract should parse"); + let function = &contract.functions[0]; + assert_eq!(function.attributes.len(), 2); + + let first = &function.attributes[0]; + assert_eq!(first.path, vec!["covenant"]); + assert_eq!(first.args.len(), 4); + assert_eq!(first.args[2].name, "to"); + assert_span_text(source, first.args[2].expr.span.as_str(), "max_outs + 1"); + + let second = &function.attributes[1]; + assert_eq!(second.path, vec!["experimental"]); + assert!(second.args.is_empty()); +} diff --git a/silverscript-lang/tests/compiler_tests.rs b/silverscript-lang/tests/compiler_tests.rs index cb45bc1..adaa067 100644 --- a/silverscript-lang/tests/compiler_tests.rs +++ b/silverscript-lang/tests/compiler_tests.rs @@ -13,7 +13,8 @@ use kaspa_txscript::script_builder::ScriptBuilder; use kaspa_txscript::{EngineCtx, EngineFlags, SeqCommitAccessor, TxScriptEngine, pay_to_address_script, pay_to_script_hash_script}; use silverscript_lang::ast::{Expr, parse_contract_ast}; use silverscript_lang::compiler::{ - CompileOptions, CompiledContract, compile_contract, compile_contract_ast, function_branch_index, struct_object, + CompileOptions, CompiledContract, CovenantDeclCallOptions, FunctionAbiEntry, FunctionInputAbi, compile_contract, + compile_contract_ast, function_branch_index, struct_object, }; fn run_script_with_selector(script: Vec, selector: Option) -> Result<(), kaspa_txscript_errors::TxScriptError> { @@ -141,6 +142,32 @@ fn accepts_constructor_args_with_matching_types() { compile_contract(source, &args, CompileOptions::default()).expect("compile succeeds"); } +#[test] +fn supports_struct_contract_params_fields_and_constants() { + let source = r#" + contract TopLevelStructs(Pair init_pair) { + struct Pair { + int amount; + byte[2] code; + } + + Pair constant DEFAULT_PAIR = {amount: 7, code: 0x1234}; + Pair from_param = init_pair; + Pair from_constant = DEFAULT_PAIR; + + entrypoint function main() { + require(true); + } + } + "#; + + let args = vec![struct_object(vec![("amount", Expr::int(11)), ("code", Expr::bytes(vec![0xab, 0xcd]))])]; + let compiled = compile_contract(source, &args, CompileOptions::default()).expect("compile succeeds"); + let selector = selector_for(&compiled, "main"); + let result = run_script_with_selector(compiled.script, selector); + assert!(result.is_ok(), "top-level struct param/field/constant contract should run: {result:?}"); +} + #[test] fn compile_contract_omits_debug_info_when_recording_disabled() { let source = r#" @@ -388,6 +415,74 @@ fn build_sig_script_rejects_wrong_argument_type() { assert!(result.is_err()); } +#[test] +fn build_sig_script_for_covenant_decl_routes_to_hidden_auth_entrypoint() { + let source = r#" + contract Counter(int init_value) { + int value = init_value; + + #[covenant.singleton] + function step(State prev_state, State[] new_states) { + require(new_states.length <= 1); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(7)], CompileOptions::default()).expect("compile succeeds"); + let args = vec![vec![struct_object(vec![("value", Expr::int(8))])].into()]; + + let actual = compiled + .build_sig_script_for_covenant_decl("step", args.clone(), CovenantDeclCallOptions { is_leader: false }) + .expect("covenant sigscript builds"); + let expected = compiled.build_sig_script("__step", args).expect("hidden entrypoint sigscript builds"); + + assert_eq!(actual, expected); +} + +#[test] +fn build_sig_script_for_covenant_decl_routes_to_hidden_cov_entrypoints() { + let source = r#" + contract Pair(int init_value) { + int value = init_value; + + #[covenant(from = 2, to = 2)] + function rebalance(State[] prev_states, State[] new_states) { + require(new_states.length == 1); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(7)], CompileOptions::default()).expect("compile succeeds"); + let leader_args = vec![vec![struct_object(vec![("value", Expr::int(8))])].into()]; + + let leader = compiled + .build_sig_script_for_covenant_decl("rebalance", leader_args.clone(), CovenantDeclCallOptions { is_leader: true }) + .expect("leader sigscript builds"); + let expected_leader = compiled.build_sig_script("__leader_rebalance", leader_args).expect("hidden leader sigscript builds"); + assert_eq!(leader, expected_leader); + + let delegate = compiled + .build_sig_script_for_covenant_decl("rebalance", vec![], CovenantDeclCallOptions { is_leader: false }) + .expect("delegate sigscript builds"); + let expected_delegate = compiled.build_sig_script("__delegate_rebalance", vec![]).expect("hidden delegate sigscript builds"); + assert_eq!(delegate, expected_delegate); +} + +#[test] +fn build_sig_script_for_covenant_decl_rejects_unknown_declaration() { + let source = r#" + contract C() { + entrypoint function spend() { + require(true); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let result = compiled.build_sig_script_for_covenant_decl("missing", vec![], CovenantDeclCallOptions { is_leader: false }); + assert!(result.is_err()); +} + #[test] fn rejects_double_underscore_variable_names() { let source = r#" @@ -410,6 +505,57 @@ fn rejects_double_underscore_variable_names() { assert!(parse_contract_ast(source).is_err()); } +#[test] +fn rejects_double_underscore_function_names() { + let source = r#" + contract Bad() { + function __hidden() { + require(true); + } + + entrypoint function main() { + require(true); + } + } + "#; + + assert!(parse_contract_ast(source).is_err()); +} + +#[test] +fn rejects_double_underscore_struct_names() { + let source = r#" + contract Bad() { + struct __Hidden { + int value; + } + + entrypoint function main() { + require(true); + } + } + "#; + + assert!(parse_contract_ast(source).is_err()); +} + +#[test] +fn rejects_struct_named_state() { + let source = r#" + contract Bad() { + struct State { + int value; + } + + entrypoint function main() { + require(true); + } + } + "#; + + assert!(parse_contract_ast(source).is_err()); +} + #[test] fn rejects_yield_without_allow_option() { let source = r#" @@ -529,6 +675,37 @@ fn compiles_struct_sugar_for_locals_calls_and_field_access() { assert!(result.is_ok(), "script should execute successfully: {result:?}"); } +#[test] +fn compiles_struct_return_types_in_inline_calls() { + let source = r#" + contract C() { + struct S { + int a; + string b; + } + + function make(int a) : (S) { + return({a: a, b: "12345"}); + } + + function check(S x) { + require(x.a == 0); + require(x.b.length == 5); + } + + entrypoint function main() { + (S out) = make(0); + check(out); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let selector = selector_for(&compiled, "main"); + let result = run_script_with_selector(compiled.script, selector); + assert!(result.is_ok(), "struct-return inline call should execute successfully: {result:?}"); +} + #[test] fn build_sig_script_supports_struct_entrypoint_arguments() { let source = r#" @@ -575,6 +752,944 @@ fn build_sig_script_supports_state_entrypoint_arguments() { assert_eq!(sigscript, expected); } +fn struct_array_arg<'i>(values: Vec<(i64, Vec)>) -> Expr<'i> { + values.into_iter().map(|(a, b)| struct_object(vec![("a", Expr::int(a)), ("b", Expr::bytes(b))])).collect::>().into() +} + +fn state_array_arg<'i>(values: Vec) -> Expr<'i> { + values.into_iter().map(|value| struct_object(vec![("value", Expr::int(value))])).collect::>().into() +} + +fn matrix_state_array_arg<'i>(values: Vec<(i64, Vec)>) -> Expr<'i> { + values + .into_iter() + .map(|(amount, owner)| struct_object(vec![("amount", Expr::int(amount)), ("owner", Expr::bytes(owner))])) + .collect::>() + .into() +} + +fn replace_compiled_interface<'i>( + compiled: &mut CompiledContract<'i>, + source: &'i str, + entrypoint_name: &str, + inputs: &[(&str, &str)], +) { + compiled.ast = parse_contract_ast(source).expect("interface parses"); + compiled.abi = vec![FunctionAbiEntry { + name: entrypoint_name.to_string(), + inputs: inputs + .iter() + .map(|(name, type_name)| FunctionInputAbi { name: (*name).to_string(), type_name: (*type_name).to_string() }) + .collect(), + }]; +} + +#[test] +fn build_sig_script_for_covenant_decl_supports_all_covenant_ast_examples() { + struct Case { + source: &'static str, + constructor_args: Vec>, + function_name: &'static str, + args: Vec>, + options: CovenantDeclCallOptions, + generated_covenant_entrypoint_name: &'static str, + } + + let owner = vec![7u8; 32]; + let next_owner = vec![9u8; 32]; + let matrix_singleton_transition_source = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.singleton(mode = transition)] + function step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + } + "#; + let matrix_singleton_terminate_source = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.singleton(mode = transition, termination = allowed)] + function step(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + let matrix_fanout_verification_source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.fanout(to = max_outs, mode = verification)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + let matrix_all_source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] + function auth_verification_multi(State prev_state, State[] new_states, int nonce) { + require(nonce >= 0); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] + function auth_verification_single(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = transition)] + function auth_transition(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function cov_verification(State[] prev_states, State[] new_states, int nonce) { + require(nonce >= 0); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function cov_transition(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + + #[covenant(from = 1, to = max_outs)] + function inferred_auth(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = max_ins, to = max_outs)] + function inferred_cov(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = 1, to = 1)] + function inferred_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition)] + function singleton_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition, termination = allowed)] + function singleton_terminate(State prev_state, State[] next_states) : (State[]) { + require(prev_state.amount >= 0); + return(next_states); + } + + #[covenant.fanout(to = max_outs, mode = verification)] + function fanout_verification(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let cases = vec![ + Case { + source: r#" + contract Decls(int max_outs) { + int value = 0; + + #[covenant(binding = auth, from = 1, to = max_outs, groups = single)] + function split(State prev_state, State[] new_states, int amount) { + require(amount >= 0); + } + } + "#, + constructor_args: vec![Expr::int(4)], + function_name: "split", + args: vec![state_array_arg(vec![11]), Expr::int(3)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__split", + }, + Case { + source: r#" + contract Decls(int max_ins, int max_outs) { + int value = 0; + + #[covenant(from = max_ins, to = max_outs, mode = verification)] + function transition_ok(State[] prev_states, State[] new_states, int delta) { + require(delta >= 0); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(3)], + function_name: "transition_ok", + args: vec![state_array_arg(vec![10, 11]), Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_transition_ok", + }, + Case { + source: r#" + contract Decls(int max_ins, int max_outs) { + int value = 0; + + #[covenant(from = max_ins, to = max_outs, mode = verification)] + function transition_ok(State[] prev_states, State[] new_states, int delta) { + require(delta >= 0); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(3)], + function_name: "transition_ok", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_transition_ok", + }, + Case { + source: r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition)] + function bump(State prev_state, int delta) : (State) { + return({ value: prev_state.value + delta }); + } + } + "#, + constructor_args: vec![Expr::int(7)], + function_name: "bump", + args: vec![Expr::int(2)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__bump", + }, + Case { + source: r#" + contract Decls(int max_outs, int init_value) { + int value = init_value; + + #[covenant(from = 1, to = max_outs, mode = transition)] + function fanout(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#, + constructor_args: vec![Expr::int(4), Expr::int(10)], + function_name: "fanout", + args: vec![state_array_arg(vec![11, 12])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__fanout", + }, + Case { + source: r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition, termination = allowed)] + function bump_or_terminate(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#, + constructor_args: vec![Expr::int(10)], + function_name: "bump_or_terminate", + args: vec![state_array_arg(vec![13])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__bump_or_terminate", + }, + Case { + source: r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] + function step(State prev_state, State[] new_states, int nonce) { + require(nonce >= 0); + } + } + "#, + constructor_args: vec![Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())]), Expr::int(0)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#, + constructor_args: vec![Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = transition)] + function step(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + } + "#, + constructor_args: vec![Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function step(State[] prev_states, State[] new_states, int nonce) { + require(nonce >= 0); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())]), Expr::int(0)], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function step(State[] prev_states, State[] new_states, int nonce) { + require(nonce >= 0); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function step(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(10, owner.clone())]), Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function step(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_step", + }, + Case { + source: r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = 1, to = max_outs)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#, + constructor_args: vec![Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = max_ins, to = max_outs)] + function step(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_step", + }, + Case { + source: r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = max_ins, to = max_outs)] + function step(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_step", + }, + Case { + source: r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = 1, to = 1)] + function step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + } + "#, + constructor_args: vec![Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: matrix_singleton_transition_source, + constructor_args: vec![Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: matrix_singleton_terminate_source, + constructor_args: vec![Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: matrix_fanout_verification_source, + constructor_args: vec![Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "step", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__step", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "auth_verification_multi", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())]), Expr::int(0)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__auth_verification_multi", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "auth_verification_single", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__auth_verification_single", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "auth_transition", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__auth_transition", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "cov_verification", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())]), Expr::int(0)], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_cov_verification", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "cov_verification", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_cov_verification", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "cov_transition", + args: vec![matrix_state_array_arg(vec![(10, owner.clone())]), Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_cov_transition", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "cov_transition", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_cov_transition", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "inferred_auth", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__inferred_auth", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "inferred_cov", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: true }, + generated_covenant_entrypoint_name: "__leader_inferred_cov", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "inferred_cov", + args: vec![], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__delegate_inferred_cov", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "inferred_transition", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__inferred_transition", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "singleton_transition", + args: vec![Expr::int(1)], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__singleton_transition", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "singleton_terminate", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__singleton_terminate", + }, + Case { + source: matrix_all_source, + constructor_args: vec![Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(owner.clone())], + function_name: "fanout_verification", + args: vec![matrix_state_array_arg(vec![(11, next_owner.clone())])], + options: CovenantDeclCallOptions { is_leader: false }, + generated_covenant_entrypoint_name: "__fanout_verification", + }, + ]; + + for case in cases { + let compiled = compile_contract(case.source, &case.constructor_args, CompileOptions::default()).expect("compile succeeds"); + let sigscript = compiled + .build_sig_script_for_covenant_decl(case.function_name, case.args.clone(), case.options) + .expect("covenant declaration sigscript builds"); + let expected = compiled + .build_sig_script(case.generated_covenant_entrypoint_name, case.args) + .expect("generated entrypoint sigscript builds"); + assert_eq!(sigscript, expected, "covenant declaration sigscript should match generated entrypoint for {}", case.function_name); + } +} + +#[test] +fn runtime_rejects_regular_struct_array_entrypoint_arguments_without_struct_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + entrypoint function main(int[] items_a, byte[2][] items_b) { + require(items_a.length == 2); + require(items_b.length == 2); + require(items_a[0] == 7); + require(items_a[1] == 9); + require(items_b[0] == 0x0102); + require(items_b[1] == 0x0304); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let main_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "main") + .expect("main exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(main_param_types, vec!["int[]".to_string(), "byte[2][]".to_string()]); + + let err = compiled + .build_sig_script("main", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect_err("struct[] arguments should be rejected when the entrypoint signature is not struct-typed"); + assert!(err.to_string().contains("expects 2 arguments"), "unexpected error: {err}"); +} + +#[test] +fn runtime_supports_regular_struct_array_entrypoint_arguments_with_struct_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + entrypoint function main(int[] items_a, byte[2][] items_b) { + require(items_a.length == 2); + require(items_b.length == 2); + require(items_a[0] == 7); + require(items_a[1] == 9); + require(items_b[0] == 0x0102); + require(items_b[1] == 0x0304); + } + } + "#; + + let struct_signature_source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + entrypoint function main(S[] x) { + require(x.length == 2); + require(x[0].a == 7); + require(x[1].a == 9); + require(x[0].b == 0x0102); + require(x[1].b == 0x0304); + } + } + "#; + + let mut compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + replace_compiled_interface(&mut compiled, struct_signature_source, "main", &[("x", "S[]")]); + + let main_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "main") + .expect("main exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(main_param_types, vec!["S[]".to_string()]); + + let sigscript = compiled + .build_sig_script("main", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect("sigscript builds"); + let result = run_script_with_sigscript(compiled.script, sigscript); + + assert!(result.is_ok(), "regular struct[] entrypoint arg should execute successfully: {result:?}"); +} + +#[test] +fn runtime_supports_direct_struct_array_entrypoint_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + entrypoint function f(S[] x) { + require(x.length == 2); + require(x[0].a == 7); + require(x[1].a == 9); + require(x[0].b == 0x0102); + require(x[1].b == 0x0304); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let f_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "f") + .expect("f exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(f_param_types, vec!["S[]".to_string()]); + + let sigscript = compiled + .build_sig_script("f", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect("sigscript builds"); + let result = run_script_with_sigscript(compiled.script, sigscript); + + assert!(result.is_ok(), "direct struct[] entrypoint signature should execute successfully: {result:?}"); +} + +#[test] +fn runtime_rejects_regular_struct_array_non_entrypoint_arguments_without_struct_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + function verify(int[] items_a, byte[2][] items_b) { + require(items_a.length == 2); + require(items_b.length == 2); + require(items_a[0] == 7); + require(items_a[1] == 9); + require(items_b[0] == 0x0102); + require(items_b[1] == 0x0304); + } + + entrypoint function main(int[] items_a, byte[2][] items_b) { + verify(items_a, items_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let main_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "main") + .expect("main exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(main_param_types, vec!["int[]".to_string(), "byte[2][]".to_string()]); + + let verify_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "verify") + .expect("verify exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(verify_param_types, vec!["int[]".to_string(), "byte[2][]".to_string()]); + + let err = compiled + .build_sig_script("main", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect_err("struct[] arguments should be rejected when entrypoint and internal function signatures are not struct-typed"); + assert!(err.to_string().contains("expects 2 arguments"), "unexpected error: {err}"); +} + +#[test] +fn runtime_supports_regular_struct_array_non_entrypoint_arguments_with_struct_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + function verify(int[] items_a, byte[2][] items_b) { + require(items_a.length == 2); + require(items_b.length == 2); + require(items_a[0] == 7); + require(items_a[1] == 9); + require(items_b[0] == 0x0102); + require(items_b[1] == 0x0304); + } + + entrypoint function main(int[] items_a, byte[2][] items_b) { + verify(items_a, items_b); + } + } + "#; + + let struct_signature_source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + function verify(S[] x) { + require(x.length == 2); + require(x[0].a == 7); + require(x[1].a == 9); + require(x[0].b == 0x0102); + require(x[1].b == 0x0304); + } + + entrypoint function main(S[] x) { + verify(x); + } + } + "#; + + let mut compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + replace_compiled_interface(&mut compiled, struct_signature_source, "main", &[("x", "S[]")]); + + let main_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "main") + .expect("main exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(main_param_types, vec!["S[]".to_string()]); + + let verify_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "verify") + .expect("verify exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(verify_param_types, vec!["S[]".to_string()]); + + let sigscript = compiled + .build_sig_script("main", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect("sigscript builds"); + let result = run_script_with_sigscript(compiled.script, sigscript); + + assert!(result.is_ok(), "regular struct[] arg should flow through non-entrypoint calls at runtime: {result:?}"); +} + +#[test] +fn rejects_wrong_argument_type_for_direct_struct_array_non_entrypoint_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + function verify(S[] x) { + require(x.length == 2); + } + + entrypoint function main() { + int[] xs = [7, 9]; + verify(xs); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("wrong non-entrypoint struct[] argument type should be rejected"); + assert!(err.to_string().contains("expects S[]") || err.to_string().contains("expects struct S"), "unexpected error: {err}"); +} + +#[test] +fn runtime_supports_direct_struct_array_non_entrypoint_signature() { + let source = r#" + contract C() { + struct S { + int a; + byte[2] b; + } + + function verify(S[] x) { + require(x.length == 2); + require(x[0].a == 7); + require(x[1].a == 9); + require(x[0].b == 0x0102); + require(x[1].b == 0x0304); + } + + entrypoint function main(S[] x) { + verify(x); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + let verify_param_types: Vec = compiled + .ast + .functions + .iter() + .find(|function| function.name == "verify") + .expect("verify exists") + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect(); + assert_eq!(verify_param_types, vec!["S[]".to_string()]); + + let sigscript = compiled + .build_sig_script("main", vec![struct_array_arg(vec![(7, vec![0x01, 0x02]), (9, vec![0x03, 0x04])])]) + .expect("sigscript builds"); + let result = run_script_with_sigscript(compiled.script, sigscript); + + assert!(result.is_ok(), "direct struct[] non-entrypoint signature should execute successfully: {result:?}"); +} + #[test] fn rejects_struct_literal_with_wrong_field_type_in_function_call() { let source = r#" @@ -595,7 +1710,11 @@ fn rejects_struct_literal_with_wrong_field_type_in_function_call() { "#; let err = compile_contract(source, &[], CompileOptions::default()).expect_err("compile should fail"); - assert!(err.to_string().contains("function argument '__struct_x_a' expects int") || err.to_string().contains("expects int")); + assert!( + err.to_string().contains("function argument '__struct_x_a' expects int") + || err.to_string().contains("expects int") + || err.to_string().contains("expects S") + ); } #[test] @@ -1987,6 +3106,37 @@ fn runs_contract_with_fields_prolog() { assert!(run_script_with_selector(compiled.script, selector).is_ok()); } +#[test] +fn runs_selector_dispatch_with_contract_fields() { + let source = r#" + contract C() { + int x = 5; + byte[2] y = 0x1234; + + entrypoint function a() { + require(true); + } + + entrypoint function b() { + require(x == 5); + require(y == 0x1234); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + assert!(!compiled.without_selector, "test requires selector dispatch"); + + let sigscript_a = compiled.build_sig_script("a", vec![]).expect("sigscript a builds"); + let sigscript_b = compiled.build_sig_script("b", vec![]).expect("sigscript b builds"); + + let result_a = run_script_with_sigscript(compiled.script.clone(), sigscript_a); + assert!(result_a.is_ok(), "entrypoint a runtime failed: {}", result_a.unwrap_err()); + + let result_b = run_script_with_sigscript(compiled.script, sigscript_b); + assert!(result_b.is_ok(), "entrypoint b runtime failed: {}", result_b.unwrap_err()); +} + #[test] fn compiles_validate_output_state_to_expected_script() { let source = r#" @@ -2050,6 +3200,9 @@ fn compiles_validate_output_state_to_expected_script() { // resulting chunk: <0x02><0x3412> .add_op(OpCat) .unwrap() + // combine x_chunk || y_chunk + .add_op(OpCat) + .unwrap() // ---- Extract REST_OF_SCRIPT from current input signature script ---- // current input index @@ -2084,10 +3237,7 @@ fn compiles_validate_output_state_to_expected_script() { .unwrap() // ---- new_redeem_script = ---- - // concatenate y_chunk with rest - .add_op(OpCat) - .unwrap() - // prepend x_chunk + // append REST_OF_SCRIPT to merged new-state chunks .add_op(OpCat) .unwrap() diff --git a/silverscript-lang/tests/covenant_compiler_tests.rs b/silverscript-lang/tests/covenant_compiler_tests.rs new file mode 100644 index 0000000..c7b19d8 --- /dev/null +++ b/silverscript-lang/tests/covenant_compiler_tests.rs @@ -0,0 +1,434 @@ +use kaspa_txscript::opcodes::codes::{OpAuthOutputCount, OpCovInputCount, OpCovInputIdx, OpCovOutCount, OpInputCovenantId}; +use silverscript_lang::ast::Expr; +use silverscript_lang::compiler::{CompileOptions, compile_contract}; + +#[test] +fn lowers_auth_covenant_declaration_to_hidden_entrypoint_name() { + let source = r#" + contract Decls(int max_outs) { + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification)] + function spend(int amount) { + require(amount >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.without_selector); + assert_eq!(compiled.abi.len(), 1); + assert_eq!(compiled.abi[0].name, "__spend"); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__covenant_policy_spend" && !f.entrypoint)); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__spend" && f.entrypoint)); + assert!(compiled.script.contains(&OpAuthOutputCount)); +} + +#[test] +fn infers_auth_binding_from_from_equal_one_when_binding_omitted() { + let source = r#" + contract Decls(int max_outs) { + #[covenant(from = 1, to = max_outs)] + function spend(int amount) { + require(amount >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.without_selector); + assert_eq!(compiled.abi.len(), 1); + assert_eq!(compiled.abi[0].name, "__spend"); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__covenant_policy_spend" && !f.entrypoint)); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__spend" && f.entrypoint)); + assert!(compiled.script.contains(&OpAuthOutputCount)); +} + +#[test] +fn lowers_cov_covenant_to_leader_and_delegate_entrypoints() { + let source = r#" + contract Decls(int max_ins, int max_outs) { + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function transition_ok(int nonce) { + require(nonce >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(2), Expr::int(4)], CompileOptions::default()).expect("compile succeeds"); + let abi_names: Vec<&str> = compiled.abi.iter().map(|entry| entry.name.as_str()).collect(); + assert_eq!(abi_names, vec!["__leader_transition_ok", "__delegate_transition_ok"]); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__covenant_policy_transition_ok" && !f.entrypoint)); + assert!(compiled.script.contains(&OpCovInputCount)); + assert!(compiled.script.contains(&OpCovOutCount)); + assert!(compiled.script.contains(&OpCovInputIdx)); +} + +#[test] +fn infers_cov_binding_from_from_greater_than_one_when_binding_omitted() { + let source = r#" + contract Decls(int max_ins, int max_outs) { + #[covenant(from = max_ins, to = max_outs)] + function transition_ok(int nonce) { + require(nonce >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(2), Expr::int(4)], CompileOptions::default()).expect("compile succeeds"); + let abi_names: Vec<&str> = compiled.abi.iter().map(|entry| entry.name.as_str()).collect(); + assert_eq!(abi_names, vec!["__leader_transition_ok", "__delegate_transition_ok"]); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__covenant_policy_transition_ok" && !f.entrypoint)); + assert!(compiled.script.contains(&OpCovInputCount)); + assert!(compiled.script.contains(&OpCovOutCount)); + assert!(compiled.script.contains(&OpCovInputIdx)); +} + +#[test] +fn rejects_cov_verification_without_prev_new_field_arrays() { + let source = r#" + contract Decls() { + int value = 0; + + #[covenant(from = 2, to = 2, mode = verification)] + function transition_ok(int nonce) { + require(nonce >= 0); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("cov verification with state fields should require prev/new field arrays"); + assert!(err.to_string().contains("expects parameters '(State[] prev_states, State[] new_states, ...)'")); +} + +#[test] +fn rejects_cov_transition_without_prev_field_arrays() { + let source = r#" + contract Decls() { + int value = 0; + + #[covenant(from = 2, to = 2, mode = transition)] + function transition_ok(int nonce) : (int) { + return(value + nonce); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("cov transition with state fields should require prev-state field arrays"); + assert!(err.to_string().contains("expects parameters '(State[] prev_states, ...)'")); +} + +#[test] +fn rejects_auth_verification_without_prev_new_state_shape() { + let source = r#" + contract Decls() { + int value = 0; + + #[covenant(binding = auth, from = 1, to = 2, mode = verification)] + function split(int nonce) { + require(nonce >= 0); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("auth verification with state fields should require prev/new state params"); + assert!(err.to_string().contains("mode=verification with binding=auth")); +} + +#[test] +fn rejects_auth_transition_without_prev_state_shape() { + let source = r#" + contract Decls() { + int value = 0; + + #[covenant(binding = auth, from = 1, to = 2, mode = transition)] + function split(int[] nonce) : (int[]) { + return(nonce); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("auth transition with state fields should require prev-state params"); + assert!(err.to_string().contains("mode=transition with binding=auth")); +} + +#[test] +fn rejects_old_per_field_covenant_state_syntax() { + let source = r#" + contract Decls() { + int value = 0; + + #[covenant(binding = auth, from = 1, to = 2, mode = verification)] + function split(int prev_value, int[] new_values) { + require(prev_value >= 0); + require(new_values.length >= 0); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()) + .expect_err("old per-field covenant syntax should be rejected for stateful contracts"); + assert!(err.to_string().contains("expects parameters '(State prev_state, State[] new_states, ...)'")); +} + +#[test] +fn lowers_singleton_sugar_to_auth_one_to_one_defaults() { + let source = r#" + contract Decls() { + #[covenant.singleton] + function spend(int amount) { + require(amount >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.without_selector); + assert_eq!(compiled.abi[0].name, "__spend"); + assert!(compiled.script.contains(&OpAuthOutputCount)); +} + +#[test] +fn lowers_fanout_sugar_to_auth_with_to_bound() { + let source = r#" + contract Decls(int max_outs) { + #[covenant.fanout(to = max_outs)] + function split(int amount) { + require(amount >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.without_selector); + assert_eq!(compiled.abi[0].name, "__split"); + assert!(compiled.script.contains(&OpAuthOutputCount)); +} + +#[test] +fn rejects_fanout_sugar_without_to_argument() { + let source = r#" + contract Decls() { + #[covenant.fanout] + function split() { + require(true); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("fanout sugar requires to"); + assert!(err.to_string().contains("missing covenant attribute argument 'to'")); +} + +#[test] +fn rejects_singleton_sugar_with_from_or_to_arguments() { + let source = r#" + contract Decls() { + #[covenant.singleton(to = 2)] + function split() { + require(true); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("singleton sugar should reject from/to"); + assert!(err.to_string().contains("covenant.singleton is sugar and does not accept 'from' or 'to' arguments")); +} + +#[test] +fn rejects_auth_covenant_with_from_not_equal_one() { + let source = r#" + contract Decls() { + #[covenant(binding = auth, from = 2, to = 4, mode = verification)] + function split() { + require(true); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("auth binding must require from=1"); + assert!(err.to_string().contains("binding=auth requires from = 1")); +} + +#[test] +fn rejects_cov_covenant_groups_multiple_for_now() { + let source = r#" + contract Decls() { + #[covenant(binding = cov, from = 2, to = 4, mode = verification, groups = multiple)] + function step() { + require(true); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("cov groups=multiple should be rejected"); + assert!(err.to_string().contains("binding=cov with groups=multiple is not supported yet")); +} + +#[test] +fn infers_verification_mode_when_mode_omitted_and_no_returns() { + let source = r#" + contract Decls() { + #[covenant(from = 1, to = 2)] + function check(int x) { + require(x >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__check" && f.entrypoint)); +} + +#[test] +fn infers_transition_mode_when_mode_omitted_and_has_returns() { + let source = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant(from = 1, to = 1)] + function roll(State prev_state, int x) : (State) { + return({ value: prev_state.value + x }); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__roll" && f.entrypoint)); +} + +#[test] +fn rejects_singleton_transition_array_returns_without_termination_allowed() { + let source = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition)] + function roll(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let err = compile_contract(source, &[Expr::int(3)], CompileOptions::default()) + .expect_err("singleton transition arrays should require termination=allowed"); + assert!(err.to_string().contains("arrays are not allowed unless termination=allowed")); +} + +#[test] +fn allows_singleton_transition_array_returns_with_termination_allowed() { + let source = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition, termination = allowed)] + function roll(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.ast.functions.iter().any(|f| f.name == "__roll" && f.entrypoint)); +} + +#[test] +fn rejects_termination_allowed_for_non_singleton() { + let source = r#" + contract Decls(int max_outs, int init_value) { + int value = init_value; + + #[covenant(from = 1, to = max_outs, mode = transition, termination = allowed)] + function roll(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let err = compile_contract(source, &[Expr::int(3), Expr::int(10)], CompileOptions::default()) + .expect_err("termination=allowed should be singleton-only"); + assert!(err.to_string().contains("termination is only supported for singleton covenants")); +} + +#[test] +fn rejects_termination_disallowed_for_non_singleton() { + let source = r#" + contract Decls(int max_outs, int init_value) { + int value = init_value; + + #[covenant(from = 1, to = max_outs, mode = transition, termination = disallowed)] + function roll(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let err = compile_contract(source, &[Expr::int(3), Expr::int(10)], CompileOptions::default()) + .expect_err("termination arg should be singleton-only regardless of value"); + assert!(err.to_string().contains("termination is only supported for singleton covenants")); +} + +#[test] +fn rejects_termination_in_verification_mode() { + let source = r#" + contract Decls() { + #[covenant.singleton(mode = verification, termination = allowed)] + function check() { + require(true); + } + } + "#; + + let err = + compile_contract(source, &[], CompileOptions::default()).expect_err("termination should not be allowed in verification mode"); + assert!(err.to_string().contains("termination is only supported in mode=transition")); +} + +#[test] +fn rejects_transition_mode_without_return_values() { + let source = r#" + contract Decls() { + #[covenant(binding = auth, from = 1, to = 1, mode = transition)] + function roll() { + require(true); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("transition policy must return values"); + assert!(err.to_string().contains("transition mode policy functions must declare return values")); +} + +#[test] +fn rejects_verification_mode_with_return_values() { + let source = r#" + contract Decls() { + #[covenant(binding = auth, from = 1, to = 1, mode = verification)] + function check() : (int) { + return(1); + } + } + "#; + + let err = compile_contract(source, &[], CompileOptions::default()).expect_err("verification policy must not return values"); + assert!(err.to_string().contains("verification mode policy functions must not declare return values")); +} + +#[test] +fn auth_covenant_groups_single_injects_shared_count_check() { + let source = r#" + contract Decls() { + #[covenant(binding = auth, from = 1, to = 4, mode = verification, groups = single)] + function spend() { + require(true); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); + assert!(compiled.script.contains(&OpInputCovenantId)); + assert!(compiled.script.contains(&OpCovOutCount)); + assert!(compiled.script.contains(&OpAuthOutputCount)); +} diff --git a/silverscript-lang/tests/covenant_declaration_ast_tests.rs b/silverscript-lang/tests/covenant_declaration_ast_tests.rs new file mode 100644 index 0000000..c22d4a4 --- /dev/null +++ b/silverscript-lang/tests/covenant_declaration_ast_tests.rs @@ -0,0 +1,1395 @@ +use silverscript_lang::ast::visit::{AstVisitorMut, NameKind, visit_contract_mut}; +use silverscript_lang::ast::{ContractAst, Expr, FunctionAst, parse_contract_ast}; +use silverscript_lang::compiler::{CompileOptions, compile_contract}; +use silverscript_lang::span::Span; +use std::collections::HashSet; + +fn canonicalize_generated_name(name: &str) -> String { + if let Some(rest) = name.strip_prefix("__covenant_policy_") { + return format!("covenant_policy_{rest}"); + } + if let Some(rest) = name.strip_prefix("__cov_") { + return format!("cov_{rest}"); + } + if let Some(rest) = name.strip_prefix("__") { + return rest.to_string(); + } + name.to_string() +} + +struct GeneratedNameCanonicalizer; + +impl<'i> AstVisitorMut<'i> for GeneratedNameCanonicalizer { + fn visit_name(&mut self, name: &mut String, _kind: NameKind) { + *name = canonicalize_generated_name(name); + } + + fn visit_span(&mut self, span: &mut Span<'i>) { + *span = Span::default(); + } +} + +fn normalize_contract(contract: &mut ContractAst<'_>) { + visit_contract_mut(&mut GeneratedNameCanonicalizer, contract); +} + +fn compile_and_normalize_contract<'i>(source: &'i str, constructor_args: &[Expr<'i>]) -> ContractAst<'i> { + let compiled = compile_contract(source, constructor_args, CompileOptions::default()).expect("compile succeeds"); + let mut contract = compiled.ast; + normalize_contract(&mut contract); + contract +} + +fn parse_and_normalize_contract<'i>(source: &'i str) -> ContractAst<'i> { + let mut contract = parse_contract_ast(source).expect("expected contract parses"); + normalize_contract(&mut contract); + contract +} + +fn assert_lowers_to_expected_ast<'i>(source: &'i str, expected_lowered_source: &'i str, constructor_args: &[Expr<'i>]) { + let actual = compile_and_normalize_contract(source, constructor_args); + let expected = parse_and_normalize_contract(expected_lowered_source); + assert_eq!(actual, expected); +} + +fn function_by_name<'a, 'i>(functions: &'a [FunctionAst<'i>], name: &str) -> &'a FunctionAst<'i> { + functions.iter().find(|function| function.name == name).unwrap_or_else(|| panic!("missing function '{}'", name)) +} + +fn assert_param_names(function: &FunctionAst<'_>, expected: &[&str]) { + let actual: Vec<&str> = function.params.iter().map(|param| param.name.as_str()).collect(); + assert_eq!(actual, expected, "unexpected params for '{}'", function.name); +} + +#[test] +fn lowers_auth_groups_single_to_expected_wrapper_ast() { + let source = r#" + contract Decls(int max_outs) { + int value = 0; + + #[covenant(binding = auth, from = 1, to = max_outs, groups = single)] + function split(State prev_state, State[] new_states, int amount) { + require(amount >= 0); + } + } + "#; + + let expected_lowered = r#" + contract Decls(int max_outs) { + int value = 0; + + function covenant_policy_split(State prev_state, State[] new_states, int amount) { + require(amount >= 0); + } + + entrypoint function split(State[] new_states, int amount) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + int cov_shared_out_count = OpCovOutCount(cov_id); + require(cov_shared_out_count == cov_out_count); + + covenant_policy_split(value, new_value, amount); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { value: new_value[cov_k] }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4)]); +} + +#[test] +fn lowers_cov_to_leader_and_delegate_expected_wrapper_ast() { + let source = r#" + contract Decls(int max_ins, int max_outs) { + int value = 0; + + #[covenant(from = max_ins, to = max_outs, mode = verification)] + function transition_ok(State[] prev_states, State[] new_states, int delta) { + require(delta >= 0); + } + } + "#; + + let expected_lowered = r#" + contract Decls(int max_ins, int max_outs) { + int value = 0; + + function covenant_policy_transition_ok(State[] prev_states, State[] new_states, int delta) { + require(delta >= 0); + } + + entrypoint function leader_transition_ok(State[] new_states, int delta) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + int[] prev_value; + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { value: int cov_prev_value } = readInputState(cov_in_idx); + prev_value.push(cov_prev_value); + } + } + + covenant_policy_transition_ok(prev_value, new_value, delta); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { value: new_value[cov_k] }); + } + } + } + + entrypoint function delegate_transition_ok() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(2), Expr::int(3)]); +} + +#[test] +fn lowers_singleton_transition_uses_returned_state_in_validation() { + let source = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition)] + function bump(State prev_state, int delta) : (State) { + return({ value: prev_state.value + delta }); + } + } + "#; + + let expected_lowered = r#" + contract Decls(int init_value) { + int value = init_value; + + function covenant_policy_bump(State prev_state, int delta) : (State) { + return({ value: prev_state.value + delta }); + } + + entrypoint function bump(int delta) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_value) = covenant_policy_bump(value, delta); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { value: cov_new_value }); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(7)]); +} + +#[test] +fn lowers_transition_array_return_to_exact_output_count_match() { + let source = r#" + contract Decls(int max_outs, int init_value) { + int value = init_value; + + #[covenant(from = 1, to = max_outs, mode = transition)] + function fanout(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let expected_lowered = r#" + contract Decls(int max_outs, int init_value) { + int value = init_value; + + function covenant_policy_fanout(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + + entrypoint function fanout(State[] next_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int[] cov_new_value) = covenant_policy_fanout(value, next_value); + require(cov_out_count <= max_outs); + require(cov_out_count == cov_new_value.length); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { value: cov_new_value[cov_k] }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10)]); +} + +#[test] +fn lowers_singleton_transition_with_termination_allowed_to_array_cardinality_checks() { + let source = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition, termination = allowed)] + function bump_or_terminate(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } + "#; + + let expected_lowered = r#" + contract Decls(int init_value) { + int value = init_value; + + function covenant_policy_bump_or_terminate(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + + entrypoint function bump_or_terminate(State[] next_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int[] cov_new_value) = covenant_policy_bump_or_terminate(value, next_value); + require(cov_out_count <= 1); + require(cov_out_count == cov_new_value.length); + + for(cov_k, 0, 1) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { value: cov_new_value[cov_k] }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(10)]); +} + +#[test] +fn lowers_auth_verification_groups_multiple_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] + function step(State prev_state, State[] new_states, int nonce) { + require(nonce >= 0); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, State[] new_states, int nonce) { + require(nonce >= 0); + } + + entrypoint function step(State[] new_states, int nonce) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_step(amount, owner, new_amount, new_owner, nonce); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_auth_verification_groups_single_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function step(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + int cov_shared_out_count = OpCovOutCount(cov_id); + require(cov_shared_out_count == cov_out_count); + + covenant_policy_step(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_auth_transition_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = transition)] + function step(State prev_state, int fee) : (State) { + return({ + amount: prev_state.amount - fee, + owner: prev_state.owner + }); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + + entrypoint function step(int fee) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_step(amount, owner, fee); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_cov_verification_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function step(State[] prev_states, State[] new_states, int nonce) { + require(nonce >= 0); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State[] prev_states, State[] new_states, int nonce) { + require(nonce >= 0); + } + + entrypoint function leader_step(State[] new_states, int nonce) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + int[] prev_amount; + byte[32][] prev_owner; + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + prev_amount.push(cov_prev_amount); + prev_owner.push(cov_prev_owner); + } + } + + covenant_policy_step(prev_amount, prev_owner, new_amount, new_owner, nonce); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_step() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_cov_transition_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function step(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + + entrypoint function leader_step(State[] prev_states, int fee) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + } + } + + (int[] cov_new_amount, byte[32][] cov_new_owner) = covenant_policy_step(prev_amount, prev_owner, fee); + require(cov_new_owner.length == cov_new_amount.length); + require(cov_out_count <= max_outs); + require(cov_out_count == cov_new_amount.length); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: cov_new_amount[cov_k], + owner: cov_new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_step() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_inferred_auth_verification_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = 1, to = max_outs)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function step(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_step(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_inferred_cov_verification_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = max_ins, to = max_outs)] + function step(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function leader_step(State[] new_states) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + int[] prev_amount; + byte[32][] prev_owner; + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + prev_amount.push(cov_prev_amount); + prev_owner.push(cov_prev_owner); + } + } + + covenant_policy_step(prev_amount, prev_owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_step() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_inferred_singleton_transition_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(from = 1, to = 1)] + function step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + entrypoint function step(int delta) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_step(amount, owner, delta); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_singleton_sugar_transition_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.singleton(mode = transition)] + function step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + entrypoint function step(int delta) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_step(amount, owner, delta); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_singleton_sugar_transition_termination_allowed_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.singleton(mode = transition, termination = allowed)] + function step( + State prev_state, + State[] next_states + ) : (State[]) { + return(next_states); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step( + State prev_state, + State[] next_states + ) : (State[]) { + return(next_states); + } + + entrypoint function step(State[] next_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int[] cov_new_amount, byte[32][] cov_new_owner) = covenant_policy_step(amount, owner, next_amount, next_owner); + require(cov_new_owner.length == cov_new_amount.length); + require(cov_out_count <= 1); + require(cov_out_count == cov_new_amount.length); + + for(cov_k, 0, 1) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: cov_new_amount[cov_k], + owner: cov_new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_fanout_sugar_verification_two_field_state_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant.fanout(to = max_outs, mode = verification)] + function step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_step(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function step(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_step(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn lowers_many_covenant_declarations_in_one_contract_to_expected_wrapper_ast() { + let source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] + function auth_verification_multi( + State prev_state, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] + function auth_verification_single(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = transition)] + function auth_transition(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function cov_verification( + State[] prev_states, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function cov_transition(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + + #[covenant(from = 1, to = max_outs)] + function inferred_auth(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = max_ins, to = max_outs)] + function inferred_cov(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = 1, to = 1)] + function inferred_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition)] + function singleton_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition, termination = allowed)] + function singleton_terminate(State prev_state, State[] next_states) : (State[]) { + require(prev_state.amount >= 0); + return(next_states); + } + + #[covenant.fanout(to = max_outs, mode = verification)] + function fanout_verification(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let expected_lowered = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + function covenant_policy_auth_verification_multi( + State prev_state, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + entrypoint function auth_verification_multi(State[] new_states, int nonce) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_auth_verification_multi(amount, owner, new_amount, new_owner, nonce); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + function covenant_policy_auth_verification_single(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function auth_verification_single(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + int cov_shared_out_count = OpCovOutCount(cov_id); + require(cov_shared_out_count == cov_out_count); + + covenant_policy_auth_verification_single(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + function covenant_policy_auth_transition(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + + entrypoint function auth_transition(int fee) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_auth_transition(amount, owner, fee); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + + function covenant_policy_cov_verification( + State[] prev_states, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + entrypoint function leader_cov_verification(State[] new_states, int nonce) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + int[] prev_amount; + byte[32][] prev_owner; + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + prev_amount.push(cov_prev_amount); + prev_owner.push(cov_prev_owner); + } + } + + covenant_policy_cov_verification(prev_amount, prev_owner, new_amount, new_owner, nonce); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_cov_verification() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + + function covenant_policy_cov_transition(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + + entrypoint function leader_cov_transition(State[] prev_states, int fee) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + } + } + + (int[] cov_new_amount, byte[32][] cov_new_owner) = covenant_policy_cov_transition(prev_amount, prev_owner, fee); + require(cov_new_owner.length == cov_new_amount.length); + require(cov_out_count <= max_outs); + require(cov_out_count == cov_new_amount.length); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: cov_new_amount[cov_k], + owner: cov_new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_cov_transition() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + + function covenant_policy_inferred_auth(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function inferred_auth(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_inferred_auth(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + function covenant_policy_inferred_cov(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function leader_inferred_cov(State[] new_states) { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) == this.activeInputIndex); + + int cov_in_count = OpCovInputCount(cov_id); + require(cov_in_count <= max_ins); + + int cov_out_count = OpCovOutCount(cov_id); + int[] prev_amount; + byte[32][] prev_owner; + + for(cov_in_k, 0, max_ins) { + if (cov_in_k < cov_in_count) { + int cov_in_idx = OpCovInputIdx(cov_id, cov_in_k); + { + amount: int cov_prev_amount, + owner: byte[32] cov_prev_owner + } = readInputState(cov_in_idx); + prev_amount.push(cov_prev_amount); + prev_owner.push(cov_prev_owner); + } + } + + covenant_policy_inferred_cov(prev_amount, prev_owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpCovOutputIdx(cov_id, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + + entrypoint function delegate_inferred_cov() { + byte[32] cov_id = OpInputCovenantId(this.activeInputIndex); + + require(OpCovInputIdx(cov_id, 0) != this.activeInputIndex); + } + + function covenant_policy_inferred_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + entrypoint function inferred_transition(int delta) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_inferred_transition(amount, owner, delta); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + + function covenant_policy_singleton_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + entrypoint function singleton_transition(int delta) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int cov_new_amount, byte[32] cov_new_owner) = covenant_policy_singleton_transition(amount, owner, delta); + require(cov_out_count == 1); + + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, 0); + validateOutputState(cov_out_idx, { + amount: cov_new_amount, + owner: cov_new_owner + }); + } + + function covenant_policy_singleton_terminate(State prev_state, State[] next_states) : (State[]) { + require(prev_state.amount >= 0); + return(next_states); + } + + entrypoint function singleton_terminate(State[] next_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + (int[] cov_new_amount, byte[32][] cov_new_owner) = covenant_policy_singleton_terminate(amount, owner, next_amount, next_owner); + require(cov_new_owner.length == cov_new_amount.length); + require(cov_out_count <= 1); + require(cov_out_count == cov_new_amount.length); + + for(cov_k, 0, 1) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: cov_new_amount[cov_k], + owner: cov_new_owner[cov_k] + }); + } + } + } + + function covenant_policy_fanout_verification(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + entrypoint function fanout_verification(State[] new_states) { + int cov_out_count = OpAuthOutputCount(this.activeInputIndex); + + covenant_policy_fanout_verification(amount, owner, new_amount, new_owner); + require(cov_out_count <= max_outs); + + for(cov_k, 0, max_outs) { + if (cov_k < cov_out_count) { + int cov_out_idx = OpAuthOutputIdx(this.activeInputIndex, cov_k); + validateOutputState(cov_out_idx, { + amount: new_amount[cov_k], + owner: new_owner[cov_k] + }); + } + } + } + } + "#; + + assert_lowers_to_expected_ast(source, expected_lowered, &[Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); +} + +#[test] +fn covers_attribute_config_combinations_with_two_field_state() { + let source = r#" + contract Matrix(int max_ins, int max_outs, int init_amount, byte[32] init_owner) { + int amount = init_amount; + byte[32] owner = init_owner; + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = multiple)] + function auth_verification_multi( + State prev_state, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification, groups = single)] + function auth_verification_single(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(binding = auth, from = 1, to = max_outs, mode = transition)] + function auth_transition(State prev_state, int fee) : (State) { + return({ amount: prev_state.amount - fee, owner: prev_state.owner }); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = verification)] + function cov_verification( + State[] prev_states, + State[] new_states, + int nonce + ) { + require(nonce >= 0); + } + + #[covenant(binding = cov, from = max_ins, to = max_outs, mode = transition)] + function cov_transition(State[] prev_states, int fee) : (State[]) { + require(fee >= 0); + return(prev_states); + } + + #[covenant(from = 1, to = max_outs)] + function inferred_auth(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = max_ins, to = max_outs)] + function inferred_cov(State[] prev_states, State[] new_states) { + require(new_states.length == new_states.length); + } + + #[covenant(from = 1, to = 1)] + function inferred_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition)] + function singleton_transition(State prev_state, int delta) : (State) { + return({ amount: prev_state.amount + delta, owner: prev_state.owner }); + } + + #[covenant.singleton(mode = transition, termination = allowed)] + function singleton_terminate(State prev_state, State[] next_states) : (State[]) { + require(prev_state.amount >= 0); + return(next_states); + } + + #[covenant.fanout(to = max_outs, mode = verification)] + function fanout_verification(State prev_state, State[] new_states) { + require(new_states.length == new_states.length); + } + } + "#; + + let contract = compile_and_normalize_contract(source, &[Expr::int(2), Expr::int(4), Expr::int(10), Expr::bytes(vec![7u8; 32])]); + let functions = &contract.functions; + + let expected_entrypoints: HashSet<&str> = vec![ + "auth_verification_multi", + "auth_verification_single", + "auth_transition", + "leader_cov_verification", + "delegate_cov_verification", + "leader_cov_transition", + "delegate_cov_transition", + "inferred_auth", + "leader_inferred_cov", + "delegate_inferred_cov", + "inferred_transition", + "singleton_transition", + "singleton_terminate", + "fanout_verification", + ] + .into_iter() + .collect(); + let actual_entrypoints: HashSet<&str> = + functions.iter().filter(|function| function.entrypoint).map(|function| function.name.as_str()).collect(); + assert_eq!(actual_entrypoints, expected_entrypoints); + + for policy_name in [ + "covenant_policy_auth_verification_multi", + "covenant_policy_auth_verification_single", + "covenant_policy_auth_transition", + "covenant_policy_cov_verification", + "covenant_policy_cov_transition", + "covenant_policy_inferred_auth", + "covenant_policy_inferred_cov", + "covenant_policy_inferred_transition", + "covenant_policy_singleton_transition", + "covenant_policy_singleton_terminate", + "covenant_policy_fanout_verification", + ] { + let policy = function_by_name(functions, policy_name); + assert!(!policy.entrypoint, "policy '{}' must not be an entrypoint", policy_name); + } + + assert_param_names(function_by_name(functions, "auth_verification_multi"), &["new_states", "nonce"]); + assert_param_names(function_by_name(functions, "auth_verification_single"), &["new_states"]); + assert_param_names(function_by_name(functions, "auth_transition"), &["fee"]); + assert_param_names(function_by_name(functions, "leader_cov_verification"), &["new_states", "nonce"]); + assert_param_names(function_by_name(functions, "delegate_cov_verification"), &[]); + assert_param_names(function_by_name(functions, "leader_cov_transition"), &["prev_states", "fee"]); + assert_param_names(function_by_name(functions, "delegate_cov_transition"), &[]); + assert_param_names(function_by_name(functions, "inferred_auth"), &["new_states"]); + assert_param_names(function_by_name(functions, "leader_inferred_cov"), &["new_states"]); + assert_param_names(function_by_name(functions, "delegate_inferred_cov"), &[]); + assert_param_names(function_by_name(functions, "inferred_transition"), &["delta"]); + assert_param_names(function_by_name(functions, "singleton_transition"), &["delta"]); + assert_param_names(function_by_name(functions, "singleton_terminate"), &["next_states"]); + assert_param_names(function_by_name(functions, "fanout_verification"), &["new_states"]); +} diff --git a/silverscript-lang/tests/covenant_declaration_security_tests.rs b/silverscript-lang/tests/covenant_declaration_security_tests.rs new file mode 100644 index 0000000..eaa9d9a --- /dev/null +++ b/silverscript-lang/tests/covenant_declaration_security_tests.rs @@ -0,0 +1,558 @@ +use kaspa_consensus_core::Hash; +use kaspa_consensus_core::hashing::sighash::SigHashReusedValuesUnsync; +use kaspa_consensus_core::tx::{ + CovenantBinding, PopulatedTransaction, ScriptPublicKey, Transaction, TransactionId, TransactionInput, TransactionOutpoint, + TransactionOutput, UtxoEntry, VerifiableTransaction, +}; +use kaspa_txscript::caches::Cache; +use kaspa_txscript::covenants::CovenantsContext; +use kaspa_txscript::opcodes::codes::OpTrue; +use kaspa_txscript::script_builder::ScriptBuilder; +use kaspa_txscript::{EngineCtx, EngineFlags, TxScriptEngine, pay_to_script_hash_script}; +use kaspa_txscript_errors::TxScriptError; +use silverscript_lang::ast::Expr; +use silverscript_lang::compiler::{CompileOptions, CompiledContract, CovenantDeclCallOptions, compile_contract, struct_object}; + +const COV_A: Hash = Hash::from_bytes(*b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"); +const COV_B: Hash = Hash::from_bytes(*b"BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"); + +const AUTH_SINGLETON_SOURCE: &str = r#" + contract Counter(int init_value) { + int value = init_value; + + #[covenant.singleton] + function step(State prev_state, State[] new_states) { + require(prev_state.value >= 0); + require(new_states.length <= 1); + require(OpAuthOutputIdx(this.activeInputIndex, 0) >= 0); + } + } +"#; + +const AUTH_SINGLE_GROUP_SOURCE: &str = r#" + contract Counter(int init_value) { + int value = init_value; + + #[covenant(binding = auth, from = 1, to = 1, groups = single)] + function step(State prev_state, State[] new_states) { + require(prev_state.value >= 0); + require(new_states.length <= 1); + require(OpAuthOutputIdx(this.activeInputIndex, 0) >= 0); + } + } +"#; + +const AUTH_SINGLETON_TRANSITION_SOURCE: &str = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition)] + function bump(State prev_state, int delta) : (State) { + return({ value: prev_state.value + delta }); + } + } +"#; + +const AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE: &str = r#" + contract Decls(int init_value) { + int value = init_value; + + #[covenant.singleton(mode = transition, termination = allowed)] + function bump_or_terminate(State prev_state, State[] next_states) : (State[]) { + return(next_states); + } + } +"#; + +const COV_N_TO_M_SOURCE: &str = r#" + contract Pair(int init_value) { + int value = init_value; + + #[covenant(from = 2, to = 2)] + function rebalance(State[] prev_states, State[] new_states) { + require(true); + } + } +"#; + +const AUTH_SINGLETON_ARRAY_RUNTIME_SOURCE: &str = r#" + contract Counter(int init_value) { + int value = init_value; + + #[covenant.singleton] + function step(State prev_state, State[] new_states) { + require(new_states.length == 1); + require(new_states[0].value == prev_state.value + 1); + require(OpAuthOutputIdx(this.activeInputIndex, 0) >= 0); + } + } +"#; + +fn compile_state(source: &'static str, value: i64) -> CompiledContract<'static> { + compile_contract(source, &[Expr::int(value)], CompileOptions::default()).expect("compile succeeds") +} + +fn function_param_type_names(compiled: &CompiledContract<'_>, function_name: &str) -> Vec { + compiled + .ast + .functions + .iter() + .find(|function| function.name == function_name) + .unwrap_or_else(|| panic!("missing function '{function_name}'")) + .params + .iter() + .map(|param| param.type_ref.type_name()) + .collect() +} + +fn push_redeem_script(script: &[u8]) -> Vec { + ScriptBuilder::new().add_data(script).expect("push redeem script").drain() +} + +fn generated_auth_entrypoint_name(function_name: &str) -> String { + format!("__{function_name}") +} + +fn covenant_decl_sigscript(compiled: &CompiledContract<'_>, function_name: &str, args: Vec>, is_leader: bool) -> Vec { + let mut sigscript = compiled + .build_sig_script_for_covenant_decl(function_name, args, CovenantDeclCallOptions { is_leader }) + .expect("build covenant declaration sigscript"); + sigscript.extend_from_slice(&push_redeem_script(&compiled.script)); + sigscript +} + +fn state_array_arg(values: Vec) -> Expr<'static> { + values.into_iter().map(|value| struct_object(vec![("value", Expr::int(value))])).collect::>().into() +} + +fn cov_decl_nm_leader_sigscript(compiled: &CompiledContract<'_>, next_values: Vec) -> Vec { + covenant_decl_sigscript(compiled, "rebalance", vec![state_array_arg(next_values)], true) +} + +fn redeem_only_sigscript(compiled: &CompiledContract<'_>) -> Vec { + push_redeem_script(&compiled.script) +} + +fn tx_input(index: u32, signature_script: Vec) -> TransactionInput { + TransactionInput { + previous_outpoint: TransactionOutpoint { transaction_id: TransactionId::from_bytes([index as u8 + 1; 32]), index }, + signature_script, + sequence: 0, + sig_op_count: 0, + } +} + +fn covenant_output(compiled: &CompiledContract<'_>, authorizing_input: u16, covenant_id: Hash) -> TransactionOutput { + TransactionOutput { + value: 1_000, + script_public_key: pay_to_script_hash_script(&compiled.script), + covenant: Some(CovenantBinding { authorizing_input, covenant_id }), + } +} + +fn plain_covenant_output(authorizing_input: u16, covenant_id: Hash) -> TransactionOutput { + TransactionOutput { + value: 1_000, + script_public_key: ScriptPublicKey::new(0, vec![OpTrue].into()), + covenant: Some(CovenantBinding { authorizing_input, covenant_id }), + } +} + +fn covenant_utxo(compiled: &CompiledContract<'_>, covenant_id: Hash) -> UtxoEntry { + UtxoEntry::new(1_500, pay_to_script_hash_script(&compiled.script), 0, false, Some(covenant_id)) +} + +fn plain_utxo(covenant_id: Hash) -> UtxoEntry { + UtxoEntry::new(1_500, ScriptPublicKey::new(0, vec![OpTrue].into()), 0, false, Some(covenant_id)) +} + +fn execute_input_with_covenants(tx: Transaction, entries: Vec, input_idx: usize) -> Result<(), TxScriptError> { + let reused_values = SigHashReusedValuesUnsync::new(); + let sig_cache = Cache::new(10_000); + let input = tx.inputs[input_idx].clone(); + let populated = PopulatedTransaction::new(&tx, entries); + let cov_ctx = CovenantsContext::from_tx(&populated).map_err(TxScriptError::from)?; + let utxo = populated.utxo(input_idx).expect("selected input utxo"); + + let mut vm = TxScriptEngine::from_transaction_input( + &populated, + &input, + input_idx, + utxo, + EngineCtx::new(&sig_cache).with_reused(&reused_values).with_covenants_ctx(&cov_ctx), + EngineFlags { covenants_enabled: true }, + ); + vm.execute() +} + +fn assert_verify_like_error(err: TxScriptError) { + assert!(matches!(err, TxScriptError::VerifyError | TxScriptError::EvalFalse), "expected verify/eval-false, got {err:?}"); +} + +#[test] +fn singleton_allows_exactly_one_authorized_output() { + let active = compile_state(AUTH_SINGLETON_SOURCE, 10); + let out = compile_state(AUTH_SINGLETON_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![10])], false)); + let outputs = vec![covenant_output(&out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!(result.is_ok(), "singleton transition should succeed: {}", result.unwrap_err()); +} + +#[test] +fn singleton_rejects_two_authorized_outputs_from_same_input() { + let active = compile_state(AUTH_SINGLETON_SOURCE, 10); + let out0 = compile_state(AUTH_SINGLETON_SOURCE, 10); + let out1 = compile_state(AUTH_SINGLETON_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![10])], false)); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("singleton must reject two auth outputs from one input"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_transition_allows_correct_state_update() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 10); + let out = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 13); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump", vec![Expr::int(3)], false)); + let outputs = vec![covenant_output(&out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!(result.is_ok(), "singleton transition should accept the correct new state: {}", result.unwrap_err()); +} + +#[test] +fn singleton_transition_rejects_mismatched_output_state() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 10); + let wrong_out = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 12); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump", vec![Expr::int(3)], false)); + let outputs = vec![covenant_output(&wrong_out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("singleton transition must reject mismatched next state"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_transition_rejects_two_authorized_outputs() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 10); + let out0 = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 13); + let out1 = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 13); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump", vec![Expr::int(3)], false)); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("singleton transition must reject two authorized outputs"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_transition_rejects_missing_authorized_output() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump", vec![Expr::int(3)], false)); + let tx = Transaction::new(1, vec![input0], vec![], 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("singleton transition must reject missing authorized output"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_transition_termination_allowed_accepts_zero_outputs() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump_or_terminate", vec![state_array_arg(vec![])], false)); + let tx = Transaction::new(1, vec![input0], vec![], 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!( + result.is_ok(), + "singleton transition with termination=allowed should accept empty successor set: {}", + result.unwrap_err() + ); +} + +#[test] +fn singleton_transition_termination_allowed_accepts_one_output() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 10); + let out = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 13); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump_or_terminate", vec![state_array_arg(vec![13])], false)); + let outputs = vec![covenant_output(&out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!(result.is_ok(), "singleton transition with one successor should succeed: {}", result.unwrap_err()); +} + +#[test] +fn singleton_transition_termination_allowed_rejects_two_outputs() { + let active = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 10); + let out0 = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 13); + let out1 = compile_state(AUTH_SINGLETON_TRANSITION_TERMINATION_ALLOWED_SOURCE, 14); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "bump_or_terminate", vec![state_array_arg(vec![13, 14])], false)); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0) + .expect_err("singleton transition with termination=allowed must still reject >1 authorized outputs"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_missing_authorized_output_returns_invalid_auth_index_error() { + let active = compile_state(AUTH_SINGLETON_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![])], false)); + let tx = Transaction::new(1, vec![input0], vec![], 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("policy must fail when auth output slot 0 does not exist"); + assert!( + matches!(err, TxScriptError::CovenantsError(kaspa_txscript_errors::CovenantsError::InvalidAuthCovOutIndex(0, 0, 0))), + "unexpected error: {err:?}" + ); +} + +#[test] +fn auth_groups_single_rejects_parallel_group_with_same_covenant_id() { + let active = compile_state(AUTH_SINGLE_GROUP_SOURCE, 10); + let out = compile_state(AUTH_SINGLE_GROUP_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![10])], false)); + let input1 = tx_input(1, vec![]); + let outputs = vec![covenant_output(&out, 0, COV_A), plain_covenant_output(1, COV_A)]; + let tx = Transaction::new(1, vec![input0, input1], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A), plain_utxo(COV_A)]; + + let err = + execute_input_with_covenants(tx, entries, 0).expect_err("groups=single must reject a second auth group for same covenant id"); + assert_verify_like_error(err); +} + +#[test] +fn auth_groups_single_allows_other_covenant_id() { + let active = compile_state(AUTH_SINGLE_GROUP_SOURCE, 10); + let out = compile_state(AUTH_SINGLE_GROUP_SOURCE, 10); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![10])], false)); + let input1 = tx_input(1, vec![]); + let outputs = vec![covenant_output(&out, 0, COV_A), plain_covenant_output(1, COV_B)]; + let tx = Transaction::new(1, vec![input0, input1], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A), plain_utxo(COV_B)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!(result.is_ok(), "groups=single should not reject unrelated covenant ids: {}", result.unwrap_err()); +} + +fn build_nm_tx_for_source( + source: &'static str, + input0_sigscript: Vec, + input1_sigscript: Vec, + outputs: Vec, +) -> (Transaction, Vec) { + let in0 = compile_state(source, 10); + let in1 = compile_state(source, 7); + let tx = Transaction::new( + 1, + vec![tx_input(0, input0_sigscript), tx_input(1, input1_sigscript)], + outputs, + 0, + Default::default(), + 0, + vec![], + ); + let entries = vec![covenant_utxo(&in0, COV_A), covenant_utxo(&in1, COV_A)]; + (tx, entries) +} + +fn build_nm_tx( + input0_sigscript: Vec, + input1_sigscript: Vec, + outputs: Vec, +) -> (Transaction, Vec) { + build_nm_tx_for_source(COV_N_TO_M_SOURCE, input0_sigscript, input1_sigscript, outputs) +} + +#[test] +fn many_to_many_rejects_wrong_entrypoint_role() { + let in0 = compile_state(COV_N_TO_M_SOURCE, 10); + let in1 = compile_state(COV_N_TO_M_SOURCE, 7); + let out0 = compile_state(COV_N_TO_M_SOURCE, 10); + let out1 = compile_state(COV_N_TO_M_SOURCE, 10); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 0, COV_A)]; + + let delegate_on_leader = { + let input0_sigscript = covenant_decl_sigscript(&in0, "rebalance", vec![], false); + let input1_sigscript = covenant_decl_sigscript(&in1, "rebalance", vec![], false); + let (tx, entries) = build_nm_tx(input0_sigscript, input1_sigscript, outputs.clone()); + execute_input_with_covenants(tx, entries, 0).expect_err("leader input must reject delegate entrypoint") + }; + assert_verify_like_error(delegate_on_leader); + + let leader_on_delegate = { + let input0_sigscript = cov_decl_nm_leader_sigscript(&in0, vec![10, 10]); + let input1_sigscript = cov_decl_nm_leader_sigscript(&in1, vec![10, 10]); + let (tx, entries) = build_nm_tx(input0_sigscript, input1_sigscript, outputs); + execute_input_with_covenants(tx, entries, 1).expect_err("delegate input must reject leader entrypoint") + }; + assert_verify_like_error(leader_on_delegate); +} + +#[test] +fn many_to_many_happy_path_currently_fails_with_validate_output_state() { + let in0 = compile_state(COV_N_TO_M_SOURCE, 10); + let in1 = compile_state(COV_N_TO_M_SOURCE, 7); + let out0 = compile_state(COV_N_TO_M_SOURCE, 10); + let out1 = compile_state(COV_N_TO_M_SOURCE, 10); + assert_eq!(in0.script, out0.script, "leader input and output[0] script should match"); + assert_eq!(in0.script, out1.script, "leader input and output[1] script should match"); + + // Intended valid shape: two covenant inputs in the same id, two covenant outputs in the same id, + // leader path on input 0 and delegate path on input 1. + let input0_sigscript = cov_decl_nm_leader_sigscript(&in0, vec![10, 10]); + let input1_sigscript = covenant_decl_sigscript(&in1, "rebalance", vec![], false); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 1, COV_A)]; + let (tx, entries) = build_nm_tx(input0_sigscript, input1_sigscript, outputs); + + let leader_err = execute_input_with_covenants(tx.clone(), entries.clone(), 0) + .expect_err("leader path is expected to fail until validateOutputState fully supports selector-dispatched scripts"); + assert_verify_like_error(leader_err); + + let delegate_result = execute_input_with_covenants(tx, entries, 1); + assert!(delegate_result.is_ok(), "delegate path unexpectedly failed: {}", delegate_result.unwrap_err()); +} + +#[test] +fn many_to_many_rejects_input_count_above_from_bound() { + let in0 = compile_state(COV_N_TO_M_SOURCE, 10); + let in1 = compile_state(COV_N_TO_M_SOURCE, 7); + let in2 = compile_state(COV_N_TO_M_SOURCE, 6); + let out0 = compile_state(COV_N_TO_M_SOURCE, 10); + let out1 = compile_state(COV_N_TO_M_SOURCE, 10); + + let input0_sigscript = cov_decl_nm_leader_sigscript(&in0, vec![10, 10]); + let input1_sigscript = redeem_only_sigscript(&in1); + let input2_sigscript = redeem_only_sigscript(&in2); + let tx = Transaction::new( + 1, + vec![tx_input(0, input0_sigscript), tx_input(1, input1_sigscript), tx_input(2, input2_sigscript)], + vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 1, COV_A)], + 0, + Default::default(), + 0, + vec![], + ); + let entries = vec![covenant_utxo(&in0, COV_A), covenant_utxo(&in1, COV_A), covenant_utxo(&in2, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("wrapper must reject cov input count above from bound"); + assert_verify_like_error(err); +} + +#[test] +fn many_to_many_rejects_output_count_above_to_bound() { + let in0 = compile_state(COV_N_TO_M_SOURCE, 10); + let in1 = compile_state(COV_N_TO_M_SOURCE, 7); + let out0 = compile_state(COV_N_TO_M_SOURCE, 10); + let out1 = compile_state(COV_N_TO_M_SOURCE, 10); + + let input0_sigscript = cov_decl_nm_leader_sigscript(&in0, vec![10, 11]); + let input1_sigscript = covenant_decl_sigscript(&in1, "rebalance", vec![], false); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1, 1, COV_A), plain_covenant_output(0, COV_A)]; + let (tx, entries) = build_nm_tx(input0_sigscript, input1_sigscript, outputs); + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("wrapper must reject cov output count above to bound"); + assert_verify_like_error(err); +} + +#[test] +fn singleton_rejects_authorized_output_with_different_script() { + let active = compile_state(AUTH_SINGLETON_SOURCE, 10); + let different = compile_state(AUTH_SINGLETON_SOURCE, 11); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![10])], false)); + let tx = Transaction::new(1, vec![input0], vec![covenant_output(&different, 0, COV_A)], 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("wrapper should reject authorized output with different script"); + assert_verify_like_error(err); +} + +#[test] +fn many_to_many_leader_rejects_cov_output_with_different_script() { + let in0 = compile_state(COV_N_TO_M_SOURCE, 10); + let in1 = compile_state(COV_N_TO_M_SOURCE, 7); + let out0 = compile_state(COV_N_TO_M_SOURCE, 10); + let out1_different = compile_state(COV_N_TO_M_SOURCE, 11); + + let input0_sigscript = cov_decl_nm_leader_sigscript(&in0, vec![10, 11]); + let input1_sigscript = covenant_decl_sigscript(&in1, "rebalance", vec![], false); + let outputs = vec![covenant_output(&out0, 0, COV_A), covenant_output(&out1_different, 1, COV_A)]; + let (tx, entries) = build_nm_tx(input0_sigscript, input1_sigscript, outputs); + + let err = execute_input_with_covenants(tx, entries, 0).expect_err("leader wrapper should reject cov output with different script"); + assert_verify_like_error(err); +} + +#[test] +fn runtime_accepts_state_array_entrypoint_argument_for_generated_wrapper() { + let active = compile_state(AUTH_SINGLETON_ARRAY_RUNTIME_SOURCE, 10); + let out = compile_state(AUTH_SINGLETON_ARRAY_RUNTIME_SOURCE, 11); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![11])], false)); + let outputs = vec![covenant_output(&out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let result = execute_input_with_covenants(tx, entries, 0); + assert!(result.is_ok(), "generated wrapper should accept State[] entrypoint args at runtime: {}", result.unwrap_err()); +} + +#[test] +fn runtime_passes_state_array_into_generated_policy_function() { + let active = compile_state(AUTH_SINGLETON_ARRAY_RUNTIME_SOURCE, 10); + let out = compile_state(AUTH_SINGLETON_ARRAY_RUNTIME_SOURCE, 11); + + let wrapper_name = generated_auth_entrypoint_name("step"); + let wrapper_param_types = function_param_type_names(&active, &wrapper_name); + assert_eq!(wrapper_param_types, vec!["State[]".to_string()]); + + let policy = active + .ast + .functions + .iter() + .find(|function| !function.entrypoint && function.name == "__covenant_policy_step") + .expect("generated covenant policy exists"); + assert!(!policy.entrypoint, "generated covenant policy must remain non-entrypoint"); + let policy_param_types: Vec = policy.params.iter().map(|param| param.type_ref.type_name()).collect(); + assert_eq!(policy_param_types, vec!["State".to_string(), "State[]".to_string()]); + + let input0 = tx_input(0, covenant_decl_sigscript(&active, "step", vec![state_array_arg(vec![12])], false)); + let outputs = vec![covenant_output(&out, 0, COV_A)]; + let tx = Transaction::new(1, vec![input0], outputs, 0, Default::default(), 0, vec![]); + let entries = vec![covenant_utxo(&active, COV_A)]; + + let err = execute_input_with_covenants(tx, entries, 0) + .expect_err("generated policy should reject when the State[] argument content is wrong"); + assert_verify_like_error(err); +} diff --git a/silverscript-lang/tests/examples/covenant_id.sil b/silverscript-lang/tests/examples/covenant_id.sil index c0ea1b0..080f862 100644 --- a/silverscript-lang/tests/examples/covenant_id.sil +++ b/silverscript-lang/tests/examples/covenant_id.sil @@ -2,6 +2,7 @@ pragma silverscript ^0.1.0; contract CovenantId(int max_ins, int max_outs, int init_amount) { int amount = init_amount; + entrypoint function main(int[] output_amounts) { require(output_amounts.length <= max_outs); byte[32] covid = OpInputCovenantId(this.activeInputIndex); @@ -9,22 +10,25 @@ contract CovenantId(int max_ins, int max_outs, int init_amount) { int in_count = OpCovInputCount(covid); require(in_count <= max_ins); + int out_count = OpCovOutCount(covid); + require(out_count <= max_outs); + int in_sum = 0; - for(i,0,max_ins){ - if( i < in_count ){ + for(i, 0, max_ins) { + if( i < in_count ) { int in_idx = OpCovInputIdx(covid, i); - {amount: int in_amount} = readInputState(in_idx); + { amount: int in_amount } = readInputState(in_idx); in_sum = in_sum + in_amount; } } int out_sum = 0; - for(i,0,max_outs){ - if( i < output_amounts.length ){ + for(i, 0, max_outs) { + if( i < output_amounts.length ) { int out_idx = OpCovOutputIdx(covid, i); int out_amount = output_amounts[i]; out_sum = out_sum + out_amount; - validateOutputState(out_idx, {amount: out_amount}); + validateOutputState(out_idx, { amount: out_amount }); } } diff --git a/silverscript-lang/tests/parser_tests.rs b/silverscript-lang/tests/parser_tests.rs index 55e344c..7810092 100644 --- a/silverscript-lang/tests/parser_tests.rs +++ b/silverscript-lang/tests/parser_tests.rs @@ -118,3 +118,79 @@ fn parses_struct_destructuring() { assert!(parse_source_file(input).is_ok()); } + +#[test] +fn rejects_bounded_for_syntax() { + let input = r#" + contract Decls(int max_outs) { + #[covenant(binding = auth, from = 1, to = max_outs, mode = verification)] + function split() { + int dyn = tx.outputs.length; + for(i, 0, dyn, max_outs) { + require(i >= 0); + } + } + } + "#; + + let result = parse_source_file(input); + assert!(result.is_err()); +} + +#[test] +fn rejects_malformed_function_attributes() { + let bad_path_start = r#" + contract Decls() { + #[.covenant(binding = auth, from = 1, to = 1, mode = transition)] + function main() { + require(true); + } + } + "#; + assert!(parse_source_file(bad_path_start).is_err()); + + let bad_path_double_dot = r#" + contract Decls() { + #[covenant..transition(binding = auth, from = 1, to = 1, mode = transition)] + function main() { + require(true); + } + } + "#; + assert!(parse_source_file(bad_path_double_dot).is_err()); + + let bad_arg_missing_equals = r#" + contract Decls(int max_outs) { + #[covenant(binding, from = 1, to = max_outs, mode = verification)] + function main() { + require(max_outs >= 0); + } + } + "#; + assert!(parse_source_file(bad_arg_missing_equals).is_err()); +} + +#[test] +fn rejects_invalid_for_arities() { + let trailing_comma = r#" + contract Loops() { + function main() { + for(i, 0, 1,) { + require(i >= 0); + } + } + } + "#; + assert!(parse_source_file(trailing_comma).is_err()); + + let too_few_args = r#" + contract Loops() { + function main() { + for(i, 0) { + require(i >= 0); + } + } + } + "#; + assert!(parse_source_file(too_few_args).is_err()); +} diff --git a/silverscript-lang/tests/tutorial_examples_tests.rs b/silverscript-lang/tests/tutorial_examples_tests.rs index fb91585..6bb6985 100644 --- a/silverscript-lang/tests/tutorial_examples_tests.rs +++ b/silverscript-lang/tests/tutorial_examples_tests.rs @@ -2,9 +2,9 @@ use silverscript_lang::ast::parse_contract_ast; #[test] fn tutorial_contract_examples_parse() { - let markdown = include_str!("../../TUTORIAL.md"); + let markdown = include_str!("../../docs/TUTORIAL.md"); let blocks = extract_code_blocks(markdown, "javascript"); - assert!(!blocks.is_empty(), "no contract examples found in TUTORIAL.md"); + assert!(!blocks.is_empty(), "no contract examples found in docs/TUTORIAL.md"); for (index, snippet) in blocks { let source = wrap_snippet(&snippet);