Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 57 additions & 70 deletions crates/csharp/src/AsyncSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,13 @@ public enum CallbackCode : uint
//#define TEST_CALLBACK_CODE_WAIT(set) (2 | (set << 4))
}

public class WaitableSet(int handle) : IDisposable
// The context that we will create in unmanaged memory and pass to context_set.
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
[StructLayout(LayoutKind.Sequential)]
public struct ContextTask
{
public int Handle { get; } = handle;

void Dispose(bool _disposing)
{
AsyncSupport.WaitableSetDrop(handle);
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

~WaitableSet()
{
Dispose(false);
}
public int WaitableSetHandle;
public int FutureHandle;
}

public static class AsyncSupport
Expand All @@ -51,9 +39,6 @@ internal static class PollWasmInterop
internal static extern void wasmImportPoll(nint p0, int p1, nint p2);
}

// TODO: How do we allow multiple waitable sets?
internal static WaitableSet WaitableSet;

private static class Interop
{
[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-new]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
Expand All @@ -78,11 +63,11 @@ private static class Interop
internal static unsafe extern ContextTask* ContextGet();
}

public static WaitableSet WaitableSetNew()
public static int WaitableSetNew()
{
var waitableSet = Interop.WaitableSetNew();
Console.WriteLine($"WaitableSet created with number {waitableSet}");
return new WaitableSet(waitableSet);
return waitableSet;
}

public static unsafe void WaitableSetPoll(int waitableHandle)
Expand All @@ -94,16 +79,16 @@ public static unsafe void WaitableSetPoll(int waitableHandle)
}
}

internal static void Join(SubtaskStatus subtask, WaitableSet set, WaitableInfoState waitableInfoState)
internal static void Join(SubtaskStatus subtask, int waitableSetHandle, WaitableInfoState waitableInfoState)
{
AddTaskToWaitables(set.Handle, subtask.Handle, waitableInfoState);
Interop.WaitableJoin(subtask.Handle, set.Handle);
AddTaskToWaitables(waitableSetHandle, subtask.Handle, waitableInfoState);
Interop.WaitableJoin(subtask.Handle, waitableSetHandle);
}

internal static void Join(int readerWriterHandle, WaitableSet set, WaitableInfoState waitableInfoState)
internal static void Join(int readerWriterHandle, int waitableHandle, WaitableInfoState waitableInfoState)
{
AddTaskToWaitables(set.Handle, readerWriterHandle, waitableInfoState);
Interop.WaitableJoin(readerWriterHandle, set.Handle);
AddTaskToWaitables(waitableHandle, readerWriterHandle, waitableInfoState);
Interop.WaitableJoin(readerWriterHandle, waitableHandle);
}

// TODO: Revisit this to see if we can remove it.
Expand All @@ -120,10 +105,10 @@ private static void AddTaskToWaitables(int waitableSetHandle, int waitableHandle
waitableSetOfTasks[waitableHandle] = waitableInfoState;
}

public unsafe static EventWaitable WaitableSetWait(WaitableSet set)
public unsafe static EventWaitable WaitableSetWait(int waitableSetHandle)
{
uint* buffer = stackalloc uint[2];
var eventCode = (EventCode)Interop.WaitableSetWait(set.Handle, buffer);
var eventCode = (EventCode)Interop.WaitableSetWait(waitableSetHandle, buffer);
return new EventWaitable(eventCode, buffer[0], buffer[1]);
}

Expand All @@ -132,34 +117,22 @@ public static void WaitableSetDrop(int handle)
Interop.WaitableSetDrop(handle);
}

// The context that we will create in unmanaged memory and pass to context_set.
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
[StructLayout(LayoutKind.Sequential)]
public struct ContextTask
{
public int Set;
public int FutureHandle;
}

public static unsafe void ContextSet(ContextTask* contextTask)
{
Interop.ContextSet(contextTask);
}

public static unsafe ContextTask* ContextGet()
{
ContextTask* contextTaskPtr = Interop.ContextGet();
if(contextTaskPtr == null)
{
throw new Exception("null context returned.");
}
return contextTaskPtr;
return Interop.ContextGet();
}

public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Action taskReturn)
{
Console.WriteLine($"Callback Event code {e.EventCode} Code {e.Code} Waitable {e.Waitable} Waitable Status {e.WaitableStatus.State}, Count {e.WaitableCount}");
var waitables = pendingTasks[WaitableSet.Handle];
ContextTask* contextTaskPtr = ContextGet();

var waitables = pendingTasks[contextTaskPtr->WaitableSetHandle];
var waitableInfoState = waitables[e.Waitable];

if (e.IsDropped)
Expand Down Expand Up @@ -195,32 +168,38 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act

if (waitables.Count == 0)
{
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {WaitableSet.Handle}");
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {contextTaskPtr->WaitableSetHandle}");
taskReturn();
ContextSet(null);
Marshal.FreeHGlobal((IntPtr)contextTaskPtr);
return (uint)CallbackCode.Exit;
}

Console.WriteLine("More waitables in the set.");
return (uint)CallbackCode.Wait | (uint)(WaitableSet.Handle << 4);
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
}

throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {WaitableSet.Handle}");
throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {contextTaskPtr->WaitableSetHandle}");
}

public static Task TaskFromStatus(uint status)
public static unsafe Task TaskFromStatus(uint status)
{
var subtaskStatus = new SubtaskStatus(status);
status = status & 0xF;

if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
{
if (WaitableSet == null) {
WaitableSet = WaitableSetNew();
Console.WriteLine($"TaskFromStatus creating WaitableSet {WaitableSet.Handle}");
ContextTask* contextTaskPtr = ContextGet();
if (contextTaskPtr == null) {
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());

contextTaskPtr->WaitableSetHandle = WaitableSetNew();
ContextSet(contextTaskPtr);
Console.WriteLine($"TaskFromStatus creating WaitableSet {contextTaskPtr->WaitableSetHandle}");
}

TaskCompletionSource tcs = new TaskCompletionSource();
AsyncSupport.Join(subtaskStatus, WaitableSet, new WaitableInfoState(tcs));
Join(subtaskStatus, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs));
return tcs.Task;
}
else if (subtaskStatus.IsSubtaskReturned)
Expand All @@ -233,7 +212,7 @@ public static Task TaskFromStatus(uint status)
}
}

public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
public static unsafe Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
{
var subtaskStatus = new SubtaskStatus(status);
status = status & 0xF;
Expand All @@ -242,9 +221,12 @@ public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
var tcs = new TaskCompletionSource<T>();
if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
{
if (WaitableSet == null) {
ContextTask* contextTaskPtr = ContextGet();
if (contextTaskPtr == null) {
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
Console.WriteLine("TaskFromStatus<T> creating WaitableSet");
WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr->WaitableSetHandle = WaitableSetNew();
ContextSet(contextTaskPtr);
}

return tcs.Task;
Expand Down Expand Up @@ -389,14 +371,17 @@ internal unsafe Task<int> ReadInternal(Func<GCHandle?> liftBuffer, int length)
{
Console.WriteLine("Read Blocked");
var tcs = new TaskCompletionSource<int>();
if(AsyncSupport.WaitableSet == null)
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
if(contextTaskPtr == null)
{
Console.WriteLine("FutureReader Read Blocked creating WaitableSet");
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew();
AsyncSupport.ContextSet(contextTaskPtr);
}
Console.WriteLine("blocked read before join");

AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
Console.WriteLine("blocked read after join");
return tcs.Task;
}
Expand Down Expand Up @@ -470,7 +455,7 @@ public class FutureReader<T>(int handle, FutureVTable vTable) : ReaderBase(handl
{
public FutureVTable VTable { get; private set; } = vTable;

private GCHandle LiftBuffer<T>(T buffer)
private GCHandle LiftBuffer(T buffer)
{
if(typeof(T) == typeof(byte))
{
Expand All @@ -483,7 +468,7 @@ private GCHandle LiftBuffer<T>(T buffer)
}
}

public unsafe Task Read<T>(T buffer)
public unsafe Task Read(T buffer)
{
return ReadInternal(() => LiftBuffer(buffer), 1);
}
Expand Down Expand Up @@ -528,7 +513,7 @@ public class StreamReader<T>(int handle, StreamVTable vTable) : ReaderBase(hand
{
public StreamVTable VTable { get; private set; } = vTable;

private GCHandle LiftBuffer<T>(T[] buffer)
private GCHandle LiftBuffer(T[] buffer)
{
if(typeof(T) == typeof(byte))
{
Expand All @@ -541,7 +526,7 @@ private GCHandle LiftBuffer<T>(T[] buffer)
}
}

public unsafe Task<int> Read<T>(T[] buffer)
public unsafe Task<int> Read(T[] buffer)
{
return ReadInternal(() => LiftBuffer(buffer), buffer.Length);
}
Expand Down Expand Up @@ -600,12 +585,15 @@ internal unsafe Task<int> WriteInternal(Func<GCHandle?> lowerPayload, int length
{
Console.WriteLine("blocked write");
var tcs = new TaskCompletionSource<int>();
if(AsyncSupport.WaitableSet == null)
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
if(contextTaskPtr == null)
{
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew();
AsyncSupport.ContextSet(contextTaskPtr);
}
Console.WriteLine("blocked write before join");
AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
Console.WriteLine("blocked write after join");
return tcs.Task;
}
Expand Down Expand Up @@ -679,7 +667,6 @@ public class FutureWriter<T>(int handle, FutureVTable vTable) : WriterBase(handl
// TODO: Generate per type for this instrinsic.
public Task Write()
{
// TODO: Lower T
return WriteInternal(() => null, 1);
}

Expand Down Expand Up @@ -719,7 +706,7 @@ public class StreamWriter<T>(int handle, StreamVTable vTable) : WriterBase(handl
private GCHandle bufferHandle;
public StreamVTable VTable { get; private set; } = vTable;

private GCHandle LowerPayload<T>(T[] payload)
private GCHandle LowerPayload(T[] payload)
{
if (VTable.Lower == null)
{
Expand Down
3 changes: 2 additions & 1 deletion crates/csharp/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,8 @@ impl Bindgen for FunctionBindgen<'_, '_> {
}});

// TODO: Defer dropping borrowed resources until a result is returned.
return (uint)CallbackCode.Wait | (uint)(AsyncSupport.WaitableSet.Handle << 4);
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
"#);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/csharp/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ var {async_status_var} = {raw_name}({wasm_params});
uwriteln!(
self.csharp_interop_src,
r#"
return (uint)AsyncSupport.Callback(e, (AsyncSupport.ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
}}
"#
);
Expand Down
Loading