Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,63 @@ Note how the function passed into `comptime` doesn't have to be marked with
`'use gpu'` and can use `Math`. That's because the function doesn't execute on the GPU, it gets
executed before the shader code gets sent to the GPU.

## Branch pruning

If a condition is known at resolution time (*comptime*), then typegpu prunes the unvisited block.
Comptime-known conditions include:
- referenced js values and operations on such, like `std.pow(userSelection, 2) < THRESHOLD` (note that these values will be concretized during shader resolution, if `userSelection` may change over time, use buffers to provide it),
- values provided via slots,
- values returned by `comptime` functions.

```ts twoslash
import tgpu from 'typegpu';
import * as d from 'typegpu/data';
const root = await tgpu.init();
// ---cut---
const counterEnabledSlot = tgpu.slot<boolean>(false);
const counter = root.createMutable(d.u32);

const myFunction = tgpu.fn([])(() => {
if (counterEnabledSlot.$) counter.$++;
});
// fn myFunction() {
//
// }

const myFunctionWithCounter = myFunction.with(counterEnabledSlot, true);
// "@group(0) @binding(0) var<storage, read_write> counter: u32;
//
// fn myFunction() {
// {
// counter++;
// }
// }"
```

:::note
Conditions involving `tgpu.const` won't be automatically pruned,
since we treat `tgpu.const` as a way to opt-out of inlining.
:::

Branch pruning also works for ternary operators.

```ts twoslash
import tgpu from 'typegpu';
import * as d from 'typegpu/data';
const root = await tgpu.init();
const counterEnabledSlot = tgpu.slot<boolean>(false);
const counter = root.createMutable(d.u32);
// ---cut---
const myFunction = tgpu.fn([])(() => {
counterEnabledSlot.$ ? counter.$++ : undefined;
});
```

:::caution
Ternary operator's condition must be a comptime-known value.
This restriction is due to WGSL having no ternary operator equivalent.
:::

## *tgpu.rawCodeSnippet*

When working on top of some existing shader code, sometimes you may know for certain that some variable will be already defined and should be accessible in the code.
Expand Down
8 changes: 8 additions & 0 deletions packages/tinyest-for-wgsl/src/parsers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ const Transpilers: Partial<
return [NODE.postUpdate, operator, argument];
},

ConditionalExpression(ctx, node) {
const test = transpile(ctx, node.test) as tinyest.Expression;
const consequent = transpile(ctx, node.consequent) as tinyest.Expression;
const alternative = transpile(ctx, node.alternate) as tinyest.Expression;

return [NODE.conditionalExpr, test, consequent, alternative];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should call it "conditional" if we want to follow the pattern

Suggested change
return [NODE.conditionalExpr, test, consequent, alternative];
return [NODE.conditional, test, consequent, alternative];

},

Literal(ctx, node) {
if (typeof node.value === 'boolean') {
return node.value;
Expand Down
9 changes: 9 additions & 0 deletions packages/tinyest/src/nodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export const NodeTypeCatalog = {
postUpdate: 102,
stringLiteral: 103,
objectExpr: 104,
conditionalExpr: 105,
} as const;

export type NodeTypeCatalog = typeof NodeTypeCatalog;
Expand Down Expand Up @@ -205,6 +206,13 @@ export type ArrayExpression = readonly [
values: Expression[],
];

export type ConditionalExpression = readonly [
type: NodeTypeCatalog['conditionalExpr'],
test: Expression,
consequent: Expression,
alternative: Expression,
];

export type MemberAccess = readonly [
type: NodeTypeCatalog['memberAccess'],
object: Expression,
Expand Down Expand Up @@ -254,6 +262,7 @@ export type Expression =
| MemberAccess
| IndexAccess
| ArrayExpression
| ConditionalExpression
| PreUpdate
| PostUpdate
| Call
Expand Down
26 changes: 20 additions & 6 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ function operatorToType<

const unaryOpCodeToCodegen = {
'-': neg[$internal].gpuImpl,
'void': () => snip('', wgsl.Void, 'constant'),
} satisfies Partial<
Record<tinyest.UnaryOperator, (...args: never[]) => unknown>
>;
Expand Down Expand Up @@ -770,6 +771,21 @@ ${this.ctx.pre}}`;
);
}

if (expression[0] === NODE.conditionalExpr) {
// ternary operator
const [_, test, consequent, alternative] = expression;
const testExpression = this.expression(test);
if (isKnownAtComptime(testExpression)) {
return testExpression.value
? this.expression(consequent)
: this.expression(alternative);
} else {
throw new Error(
`Ternary operator is only supported for comptime-known checks (used with '${testExpression.value}'). For runtime checks, please use 'std.select' or if/else statements.`,
);
}
}

if (expression[0] === NODE.stringLiteral) {
return snip(expression[1], UnknownData, /* origin */ 'constant');
}
Expand All @@ -791,9 +807,8 @@ ${this.ctx.pre}}`;
statement: tinyest.Statement,
): string {
if (typeof statement === 'string') {
return `${this.ctx.pre}${
this.ctx.resolve(this.identifier(statement).value).value
};`;
const resolved = this.ctx.resolve(this.expression(statement).value).value;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does changing it from "identifier" to "expression" do? Was it intentional?

return resolved.length === 0 ? '' : `${this.ctx.pre}${resolved};`;
}

if (typeof statement === 'boolean') {
Expand Down Expand Up @@ -1066,9 +1081,8 @@ ${this.ctx.pre}else ${alternate}`;
return `${this.ctx.pre}break;`;
}

return `${this.ctx.pre}${
this.ctx.resolve(this.expression(statement).value).value
};`;
const resolved = this.ctx.resolve(this.expression(statement).value).value;
return resolved.length === 0 ? '' : `${this.ctx.pre}${resolved};`;
}
}

Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/tests/declare.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ struct Output {
"@group(0) @binding(0) var<uniform> val: f32;

fn main() -> f32 {
;

return 2f;
}"
`);
Expand Down
151 changes: 151 additions & 0 deletions packages/typegpu/tests/tgsl/ternaryOperator.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import { describe, expect } from 'vitest';
import { it } from '../utils/extendedIt.ts';
import * as d from '../../src/data/index.ts';
import * as std from '../../src/std/index.ts';
import tgpu from '../../src/index.ts';

describe('ternary operator', () => {
it('should resolve to one of the branches', () => {
const mySlot = tgpu.slot<boolean>();
const myFn = tgpu.fn([], d.u32)(() => {
return mySlot.$ ? 10 : 20;
});

expect(
tgpu.resolve([
myFn.with(mySlot, true).$name('trueFn'),
myFn.with(mySlot, false).$name('falseFn'),
]),
)
.toMatchInlineSnapshot(`
"fn falseFn() -> u32 {
return 10u;
}

fn falseFn_1() -> u32 {
return 20u;
}"
`);
});

it('should work for different comptime known expressions', () => {
const condition = true;
const comptime = tgpu['~unstable'].comptime(() => true);
const slot = tgpu.slot(true);
const derived = tgpu['~unstable'].derived(() => slot.$);

const myFn = tgpu.fn([])(() => {
// biome-ignore lint/correctness/noConstantCondition: it's a test
const a = true ? 1 : 0;
const b = std.allEq(d.vec2f(1, 2), d.vec2f(1, 2)) ? 1 : 0;
const c = condition ? 1 : 0;
const dd = comptime() ? 1 : 0;
const e = slot.$ ? 1 : 0;
const f = derived.$ ? 1 : 0;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn() {
const a = 1;
const b = 1;
const c = 1;
const dd = 1;
const e = 1;
const f = 1;
}"
`);
});

it('should resolve nested operators', () => {
const mySlot = tgpu.slot<number>(0);
const myFn = tgpu.fn([], d.u32)(() => {
return mySlot.$ === 1
? 10
: mySlot.$ === 2
? 20
: mySlot.$ === 3
? 30
: -1;
});

expect(
tgpu.resolve([
myFn,
myFn.with(mySlot, 1).$name('oneFn'),
myFn.with(mySlot, 2).$name('twoFn'),
myFn.with(mySlot, 3).$name('threeFn'),
]),
)
.toMatchInlineSnapshot(`
"fn threeFn() -> u32 {
return -1u;
}

fn threeFn_1() -> u32 {
return 10u;
}

fn threeFn_2() -> u32 {
return 20u;
}

fn threeFn_3() -> u32 {
return 30u;
}"
`);
});

it('should not include unused dependencies', ({ root }) => {
const mySlot = tgpu.slot<boolean>();
const myUniform = root.createUniform(d.u32);
const myReadonly = root.createReadonly(d.u32);

const myFn = tgpu.fn([], d.u32)(() => {
return mySlot.$ ? myUniform.$ : myReadonly.$;
});

expect(tgpu.resolve([myFn.with(mySlot, true).$name('trueFn')]))
.toMatchInlineSnapshot(`
"@group(0) @binding(0) var<uniform> myUniform: u32;

fn trueFn() -> u32 {
return myUniform;
}"
`);

expect(tgpu.resolve([myFn.with(mySlot, false).$name('falseFn')]))
.toMatchInlineSnapshot(`
"@group(0) @binding(0) var<storage, read> myReadonly: u32;

fn falseFn() -> u32 {
return myReadonly;
}"
`);
});

it('should handle undefined', ({ root }) => {
const counter = root.createMutable(d.u32);

const myFunction = tgpu.fn([])(() => {
// biome-ignore lint/correctness/noConstantCondition: it's a test
false ? counter.$++ : undefined;
});
expect(tgpu.resolve([myFunction])).toMatchInlineSnapshot(`
"fn myFunction() {

}"
`);
});

it('should throw when test is not comptime known', () => {
const myFn = tgpu.fn([d.u32], d.u32)((n) => {
return n > 0 ? n : -n;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Ternary operator is only supported for comptime-known checks (used with '(n > 0u)'). For runtime checks, please use 'std.select' or if/else statements.]
`);
});
});