Skip to content

Commit 91b738f

Browse files
authored
feat: Dollarless de-ref (#2119)
1 parent f8b23cf commit 91b738f

File tree

5 files changed

+128
-37
lines changed

5 files changed

+128
-37
lines changed

packages/typegpu/src/data/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ export { unstruct } from './unstruct.ts';
191191
export { mat2x2f, mat3x3f, mat4x4f, matToArray } from './matrix.ts';
192192
export * from './vertexFormatData.ts';
193193
export { atomic } from './atomic.ts';
194-
export { ref } from './ref.ts';
194+
export { _ref as ref } from './ref.ts';
195195
export {
196196
align,
197197
type AnyAttribute,

packages/typegpu/src/data/ref.ts

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,8 @@ import {
2222
// Public API
2323
// ----------
2424

25-
/**
26-
* A reference to a value `T`. Can be passed to other functions to give them
27-
* mutable access to the underlying value.
28-
*
29-
* Conceptually, it represents a WGSL pointer.
30-
*/
31-
export interface ref<T> {
32-
readonly [$internal]: unknown;
33-
readonly type: 'ref';
25+
interface ref<T> {
26+
readonly [$internal]: { type: 'ref' };
3427

3528
/**
3629
* Derefences the reference, and gives access to the underlying value.
@@ -47,10 +40,18 @@ export interface ref<T> {
4740
$: T;
4841
}
4942

50-
type RefFn = DualFn<(<T>(value: T) => ref<T>)> & { [$internal]: true };
43+
/**
44+
* A reference to a value `T`. Can be passed to other functions to give them
45+
* mutable access to the underlying value.
46+
*
47+
* Conceptually, it represents a WGSL pointer.
48+
*/
49+
export type _ref<T> = T extends object ? T & ref<T> : ref<T>;
50+
51+
type RefFn = DualFn<(<T>(value: T) => _ref<T>)> & { [$internal]: true };
5152

52-
export const ref = (() => {
53-
const impl = (<T>(value: T) => new refImpl(value)) as unknown as RefFn;
53+
export const _ref = (() => {
54+
const impl = (<T>(value: T) => INTERNAL_createRef(value)) as unknown as RefFn;
5455

5556
setName(impl, 'ref');
5657
impl.toString = () => 'ref';
@@ -94,39 +95,60 @@ export const ref = (() => {
9495
})();
9596

9697
export function isRef<T>(value: unknown | ref<T>): value is ref<T> {
97-
return value instanceof refImpl;
98+
return (value as ref<T>)?.[$internal]?.type === 'ref';
9899
}
99100

100101
// --------------
101102
// Implementation
102103
// --------------
103104

104-
class refImpl<T> implements ref<T> {
105-
readonly [$internal]: true;
106-
readonly type: 'ref';
107-
#value: T;
105+
export function INTERNAL_createRef<T>(value: T): ref<T> {
106+
const target = {
107+
[$internal]: { type: 'ref' },
108108

109-
constructor(value: T) {
110-
this[$internal] = true;
111-
this.type = 'ref';
112-
this.#value = value;
113-
}
109+
get $(): T {
110+
return value;
111+
},
112+
113+
set $(newValue: T) {
114+
if (newValue && typeof newValue === 'object') {
115+
// Setting an object means updating the properties of the original object.
116+
// e.g.: foo.$ = Boid();
117+
for (const key of Object.keys(newValue) as (keyof T)[]) {
118+
value[key] = newValue[key];
119+
}
120+
} else {
121+
value = newValue;
122+
}
123+
},
124+
};
114125

115-
get $(): T {
116-
return this.#value as T;
126+
if (value === undefined || value === null) {
127+
throw new Error('Cannot create a ref from undefined or null');
117128
}
118129

119-
set $(value: T) {
120-
if (value && typeof value === 'object') {
121-
// Setting an object means updating the properties of the original object.
122-
// e.g.: foo.$ = Boid();
123-
for (const key of Object.keys(value) as (keyof T)[]) {
124-
this.#value[key] = value[key];
125-
}
126-
} else {
127-
this.#value = value;
128-
}
130+
if (typeof value === 'object') {
131+
return new Proxy(target, {
132+
get(target, prop) {
133+
if (prop in target) {
134+
return target[prop as keyof typeof target];
135+
}
136+
return value[prop as keyof T];
137+
},
138+
set(_target, prop, propValue) {
139+
if (prop === $internal) {
140+
return false;
141+
}
142+
if (prop === '$') {
143+
console.log('Setting ref value:', propValue);
144+
return Reflect.set(target, prop, propValue);
145+
}
146+
return Reflect.set(value as object, prop, propValue);
147+
},
148+
}) as ref<T>;
129149
}
150+
151+
return target as ref<T>;
130152
}
131153

132154
/**

packages/typegpu/src/data/wgslTypes.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import type {
3131
WgslTexture,
3232
} from './texture.ts';
3333
import type { WgslComparisonSampler, WgslSampler } from './sampler.ts';
34-
import type { ref } from './ref.ts';
34+
import type { _ref as ref } from './ref.ts';
3535
import type { DualFn } from '../types.ts';
3636

3737
type DecoratedLocation<T extends BaseData> = Decorated<T, Location[]>;

packages/typegpu/src/std/array.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { dualImpl } from '../core/function/dualImpl.ts';
22
import { stitch } from '../core/resolve/stitch.ts';
33
import { abstractInt, u32 } from '../data/numeric.ts';
44
import { ptrFn } from '../data/ptr.ts';
5-
import { isRef, type ref } from '../data/ref.ts';
5+
import { type _ref as ref, isRef } from '../data/ref.ts';
66
import { isPtr, isWgslArray, type StorableData } from '../data/wgslTypes.ts';
77

88
const sizeOfPointedToArray = (dataType: unknown) =>

packages/typegpu/tests/ref.test.ts

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,75 @@ describe('d.ref', () => {
108108
`);
109109
});
110110

111+
it('allows updating a struct property from another function', () => {
112+
type Entity = d.Infer<typeof Entity>;
113+
const Entity = d.struct({ pos: d.vec3f });
114+
115+
const clearPosition = (entity: d.ref<Entity>) => {
116+
'use gpu';
117+
entity.pos = d.vec3f();
118+
};
119+
120+
const main = () => {
121+
'use gpu';
122+
const entity = Entity({ pos: d.vec3f(1, 2, 3) });
123+
clearPosition(d.ref(entity));
124+
// entity.pos should be vec3f(0, 0, 0)
125+
return entity;
126+
};
127+
128+
// Works in JS
129+
expect(main().pos).toStrictEqual(d.vec3f(0, 0, 0));
130+
131+
// And on the GPU
132+
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
133+
"struct Entity {
134+
pos: vec3f,
135+
}
136+
137+
fn clearPosition(entity: ptr<function, Entity>) {
138+
(*entity).pos = vec3f();
139+
}
140+
141+
fn main() -> Entity {
142+
var entity = Entity(vec3f(1, 2, 3));
143+
clearPosition((&entity));
144+
return entity;
145+
}"
146+
`);
147+
});
148+
149+
it('allows updating a vector component from another function', () => {
150+
const clearX = (pos: d.ref<d.v3f>) => {
151+
'use gpu';
152+
pos.x = 0;
153+
};
154+
155+
const main = () => {
156+
'use gpu';
157+
const pos = d.vec3f(1, 0, 0);
158+
clearX(d.ref(pos));
159+
// pos should be vec3f(0, 0, 0)
160+
return pos;
161+
};
162+
163+
// Works in JS
164+
expect(main()).toStrictEqual(d.vec3f(0, 0, 0));
165+
166+
// And on the GPU
167+
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
168+
"fn clearX(pos: ptr<function, vec3f>) {
169+
(*pos).x = 0f;
170+
}
171+
172+
fn main() -> vec3f {
173+
var pos = vec3f(1, 0, 0);
174+
clearX((&pos));
175+
return pos;
176+
}"
177+
`);
178+
});
179+
111180
it('allows updating a number from another function', () => {
112181
const increment = (value: d.ref<number>) => {
113182
'use gpu';

0 commit comments

Comments
 (0)