diff --git a/apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx b/apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx index 5c68c0ecb8..6a4c392697 100644 --- a/apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx +++ b/apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx @@ -163,6 +163,63 @@ Note how the function passed into `comptime` doesn't have to be marked with `'use gpu'` and can use `Math`. That's because the function doesn't execute on the GPU, it gets executed before the shader code gets sent to the GPU. +## Branch pruning + +If a condition is known at resolution time (*comptime*), then typegpu prunes the unvisited block. +Comptime-known conditions include: +- referenced js values and operations on such, like `std.pow(userSelection, 2) < THRESHOLD` (note that these values will be concretized during shader resolution, if `userSelection` may change over time, use buffers to provide it), +- values provided via slots, +- values returned by `comptime` functions. + +```ts twoslash +import tgpu from 'typegpu'; +import * as d from 'typegpu/data'; +const root = await tgpu.init(); +// ---cut--- +const counterEnabledSlot = tgpu.slot(false); +const counter = root.createMutable(d.u32); + +const myFunction = tgpu.fn([])(() => { + if (counterEnabledSlot.$) counter.$++; +}); +// fn myFunction() { +// +// } + +const myFunctionWithCounter = myFunction.with(counterEnabledSlot, true); +// "@group(0) @binding(0) var counter: u32; +// +// fn myFunction() { +// { +// counter++; +// } +// }" +``` + +:::note +Conditions involving `tgpu.const` won't be automatically pruned, +since we treat `tgpu.const` as a way to opt-out of inlining. +::: + +Branch pruning also works for ternary operators. + +```ts twoslash +import tgpu from 'typegpu'; +import * as d from 'typegpu/data'; +const root = await tgpu.init(); +const counterEnabledSlot = tgpu.slot(false); +const counter = root.createMutable(d.u32); +// ---cut--- +const myFunction = tgpu.fn([])(() => { + counterEnabledSlot.$ ? counter.$++ : undefined; +}); +``` + +:::caution +Ternary operator's condition must be a comptime-known value. +This restriction is due to WGSL having no ternary operator equivalent. +::: + ## *tgpu.rawCodeSnippet* When working on top of some existing shader code, sometimes you may know for certain that some variable will be already defined and should be accessible in the code. diff --git a/packages/tinyest-for-wgsl/src/parsers.ts b/packages/tinyest-for-wgsl/src/parsers.ts index 3f0713365d..ffb6170321 100644 --- a/packages/tinyest-for-wgsl/src/parsers.ts +++ b/packages/tinyest-for-wgsl/src/parsers.ts @@ -150,6 +150,14 @@ const Transpilers: Partial< return [NODE.postUpdate, operator, argument]; }, + ConditionalExpression(ctx, node) { + const test = transpile(ctx, node.test) as tinyest.Expression; + const consequent = transpile(ctx, node.consequent) as tinyest.Expression; + const alternative = transpile(ctx, node.alternate) as tinyest.Expression; + + return [NODE.conditionalExpr, test, consequent, alternative]; + }, + Literal(ctx, node) { if (typeof node.value === 'boolean') { return node.value; diff --git a/packages/tinyest/src/nodes.ts b/packages/tinyest/src/nodes.ts index 40bb78efab..4381ae3ea0 100644 --- a/packages/tinyest/src/nodes.ts +++ b/packages/tinyest/src/nodes.ts @@ -30,6 +30,7 @@ export const NodeTypeCatalog = { postUpdate: 102, stringLiteral: 103, objectExpr: 104, + conditionalExpr: 105, } as const; export type NodeTypeCatalog = typeof NodeTypeCatalog; @@ -205,6 +206,13 @@ export type ArrayExpression = readonly [ values: Expression[], ]; +export type ConditionalExpression = readonly [ + type: NodeTypeCatalog['conditionalExpr'], + test: Expression, + consequent: Expression, + alternative: Expression, +]; + export type MemberAccess = readonly [ type: NodeTypeCatalog['memberAccess'], object: Expression, @@ -254,6 +262,7 @@ export type Expression = | MemberAccess | IndexAccess | ArrayExpression + | ConditionalExpression | PreUpdate | PostUpdate | Call diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 2c95b96d9e..9bc14da108 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -160,6 +160,7 @@ function operatorToType< const unaryOpCodeToCodegen = { '-': neg[$internal].gpuImpl, + 'void': () => snip('', wgsl.Void, 'constant'), } satisfies Partial< Record unknown> >; @@ -770,6 +771,21 @@ ${this.ctx.pre}}`; ); } + if (expression[0] === NODE.conditionalExpr) { + // ternary operator + const [_, test, consequent, alternative] = expression; + const testExpression = this.expression(test); + if (isKnownAtComptime(testExpression)) { + return testExpression.value + ? this.expression(consequent) + : this.expression(alternative); + } else { + throw new Error( + `Ternary operator is only supported for comptime-known checks (used with '${testExpression.value}'). For runtime checks, please use 'std.select' or if/else statements.`, + ); + } + } + if (expression[0] === NODE.stringLiteral) { return snip(expression[1], UnknownData, /* origin */ 'constant'); } @@ -791,9 +807,8 @@ ${this.ctx.pre}}`; statement: tinyest.Statement, ): string { if (typeof statement === 'string') { - return `${this.ctx.pre}${ - this.ctx.resolve(this.identifier(statement).value).value - };`; + const resolved = this.ctx.resolve(this.expression(statement).value).value; + return resolved.length === 0 ? '' : `${this.ctx.pre}${resolved};`; } if (typeof statement === 'boolean') { @@ -1066,9 +1081,8 @@ ${this.ctx.pre}else ${alternate}`; return `${this.ctx.pre}break;`; } - return `${this.ctx.pre}${ - this.ctx.resolve(this.expression(statement).value).value - };`; + const resolved = this.ctx.resolve(this.expression(statement).value).value; + return resolved.length === 0 ? '' : `${this.ctx.pre}${resolved};`; } } diff --git a/packages/typegpu/tests/declare.test.ts b/packages/typegpu/tests/declare.test.ts index 3985550671..20a792ea82 100644 --- a/packages/typegpu/tests/declare.test.ts +++ b/packages/typegpu/tests/declare.test.ts @@ -112,7 +112,7 @@ struct Output { "@group(0) @binding(0) var val: f32; fn main() -> f32 { - ; + return 2f; }" `); diff --git a/packages/typegpu/tests/tgsl/ternaryOperator.test.ts b/packages/typegpu/tests/tgsl/ternaryOperator.test.ts new file mode 100644 index 0000000000..62d85a7555 --- /dev/null +++ b/packages/typegpu/tests/tgsl/ternaryOperator.test.ts @@ -0,0 +1,151 @@ +import { describe, expect } from 'vitest'; +import { it } from '../utils/extendedIt.ts'; +import * as d from '../../src/data/index.ts'; +import * as std from '../../src/std/index.ts'; +import tgpu from '../../src/index.ts'; + +describe('ternary operator', () => { + it('should resolve to one of the branches', () => { + const mySlot = tgpu.slot(); + const myFn = tgpu.fn([], d.u32)(() => { + return mySlot.$ ? 10 : 20; + }); + + expect( + tgpu.resolve([ + myFn.with(mySlot, true).$name('trueFn'), + myFn.with(mySlot, false).$name('falseFn'), + ]), + ) + .toMatchInlineSnapshot(` + "fn falseFn() -> u32 { + return 10u; + } + + fn falseFn_1() -> u32 { + return 20u; + }" + `); + }); + + it('should work for different comptime known expressions', () => { + const condition = true; + const comptime = tgpu['~unstable'].comptime(() => true); + const slot = tgpu.slot(true); + const derived = tgpu['~unstable'].derived(() => slot.$); + + const myFn = tgpu.fn([])(() => { + // biome-ignore lint/correctness/noConstantCondition: it's a test + const a = true ? 1 : 0; + const b = std.allEq(d.vec2f(1, 2), d.vec2f(1, 2)) ? 1 : 0; + const c = condition ? 1 : 0; + const dd = comptime() ? 1 : 0; + const e = slot.$ ? 1 : 0; + const f = derived.$ ? 1 : 0; + }); + + expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(` + "fn myFn() { + const a = 1; + const b = 1; + const c = 1; + const dd = 1; + const e = 1; + const f = 1; + }" + `); + }); + + it('should resolve nested operators', () => { + const mySlot = tgpu.slot(0); + const myFn = tgpu.fn([], d.u32)(() => { + return mySlot.$ === 1 + ? 10 + : mySlot.$ === 2 + ? 20 + : mySlot.$ === 3 + ? 30 + : -1; + }); + + expect( + tgpu.resolve([ + myFn, + myFn.with(mySlot, 1).$name('oneFn'), + myFn.with(mySlot, 2).$name('twoFn'), + myFn.with(mySlot, 3).$name('threeFn'), + ]), + ) + .toMatchInlineSnapshot(` + "fn threeFn() -> u32 { + return -1u; + } + + fn threeFn_1() -> u32 { + return 10u; + } + + fn threeFn_2() -> u32 { + return 20u; + } + + fn threeFn_3() -> u32 { + return 30u; + }" + `); + }); + + it('should not include unused dependencies', ({ root }) => { + const mySlot = tgpu.slot(); + const myUniform = root.createUniform(d.u32); + const myReadonly = root.createReadonly(d.u32); + + const myFn = tgpu.fn([], d.u32)(() => { + return mySlot.$ ? myUniform.$ : myReadonly.$; + }); + + expect(tgpu.resolve([myFn.with(mySlot, true).$name('trueFn')])) + .toMatchInlineSnapshot(` + "@group(0) @binding(0) var myUniform: u32; + + fn trueFn() -> u32 { + return myUniform; + }" + `); + + expect(tgpu.resolve([myFn.with(mySlot, false).$name('falseFn')])) + .toMatchInlineSnapshot(` + "@group(0) @binding(0) var myReadonly: u32; + + fn falseFn() -> u32 { + return myReadonly; + }" + `); + }); + + it('should handle undefined', ({ root }) => { + const counter = root.createMutable(d.u32); + + const myFunction = tgpu.fn([])(() => { + // biome-ignore lint/correctness/noConstantCondition: it's a test + false ? counter.$++ : undefined; + }); + expect(tgpu.resolve([myFunction])).toMatchInlineSnapshot(` + "fn myFunction() { + + }" + `); + }); + + it('should throw when test is not comptime known', () => { + const myFn = tgpu.fn([d.u32], d.u32)((n) => { + return n > 0 ? n : -n; + }); + + expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:myFn: Ternary operator is only supported for comptime-known checks (used with '(n > 0u)'). For runtime checks, please use 'std.select' or if/else statements.] + `); + }); +});