diff --git a/crates/csharp/src/AsyncSupport.cs b/crates/csharp/src/AsyncSupport.cs index 1bb5145e9..1fd93b804 100644 --- a/crates/csharp/src/AsyncSupport.cs +++ b/crates/csharp/src/AsyncSupport.cs @@ -21,25 +21,13 @@ public enum CallbackCode : uint //#define TEST_CALLBACK_CODE_WAIT(set) (2 | (set << 4)) } -public class WaitableSet(int handle) : IDisposable +// The context that we will create in unmanaged memory and pass to context_set. +// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also. +[StructLayout(LayoutKind.Sequential)] +public struct ContextTask { - public int Handle { get; } = handle; - - void Dispose(bool _disposing) - { - AsyncSupport.WaitableSetDrop(handle); - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - ~WaitableSet() - { - Dispose(false); - } + public int WaitableSetHandle; + public int FutureHandle; } public static class AsyncSupport @@ -51,9 +39,6 @@ internal static class PollWasmInterop internal static extern void wasmImportPoll(nint p0, int p1, nint p2); } - // TODO: How do we allow multiple waitable sets? - internal static WaitableSet WaitableSet; - private static class Interop { [global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-new]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] @@ -78,11 +63,11 @@ private static class Interop internal static unsafe extern ContextTask* ContextGet(); } - public static WaitableSet WaitableSetNew() + public static int WaitableSetNew() { var waitableSet = Interop.WaitableSetNew(); Console.WriteLine($"WaitableSet created with number {waitableSet}"); - return new WaitableSet(waitableSet); + return waitableSet; } public static unsafe void WaitableSetPoll(int waitableHandle) @@ -94,16 +79,16 @@ public static unsafe void WaitableSetPoll(int waitableHandle) } } - internal static void Join(SubtaskStatus subtask, WaitableSet set, WaitableInfoState waitableInfoState) + internal static void Join(SubtaskStatus subtask, int waitableSetHandle, WaitableInfoState waitableInfoState) { - AddTaskToWaitables(set.Handle, subtask.Handle, waitableInfoState); - Interop.WaitableJoin(subtask.Handle, set.Handle); + AddTaskToWaitables(waitableSetHandle, subtask.Handle, waitableInfoState); + Interop.WaitableJoin(subtask.Handle, waitableSetHandle); } - internal static void Join(int readerWriterHandle, WaitableSet set, WaitableInfoState waitableInfoState) + internal static void Join(int readerWriterHandle, int waitableHandle, WaitableInfoState waitableInfoState) { - AddTaskToWaitables(set.Handle, readerWriterHandle, waitableInfoState); - Interop.WaitableJoin(readerWriterHandle, set.Handle); + AddTaskToWaitables(waitableHandle, readerWriterHandle, waitableInfoState); + Interop.WaitableJoin(readerWriterHandle, waitableHandle); } // TODO: Revisit this to see if we can remove it. @@ -120,10 +105,10 @@ private static void AddTaskToWaitables(int waitableSetHandle, int waitableHandle waitableSetOfTasks[waitableHandle] = waitableInfoState; } - public unsafe static EventWaitable WaitableSetWait(WaitableSet set) + public unsafe static EventWaitable WaitableSetWait(int waitableSetHandle) { uint* buffer = stackalloc uint[2]; - var eventCode = (EventCode)Interop.WaitableSetWait(set.Handle, buffer); + var eventCode = (EventCode)Interop.WaitableSetWait(waitableSetHandle, buffer); return new EventWaitable(eventCode, buffer[0], buffer[1]); } @@ -132,15 +117,6 @@ public static void WaitableSetDrop(int handle) Interop.WaitableSetDrop(handle); } - // The context that we will create in unmanaged memory and pass to context_set. - // TODO: C has world specific types for these pointers, perhaps C# would benefit from those also. - [StructLayout(LayoutKind.Sequential)] - public struct ContextTask - { - public int Set; - public int FutureHandle; - } - public static unsafe void ContextSet(ContextTask* contextTask) { Interop.ContextSet(contextTask); @@ -148,18 +124,15 @@ public static unsafe void ContextSet(ContextTask* contextTask) public static unsafe ContextTask* ContextGet() { - ContextTask* contextTaskPtr = Interop.ContextGet(); - if(contextTaskPtr == null) - { - throw new Exception("null context returned."); - } - return contextTaskPtr; + return Interop.ContextGet(); } public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Action taskReturn) { Console.WriteLine($"Callback Event code {e.EventCode} Code {e.Code} Waitable {e.Waitable} Waitable Status {e.WaitableStatus.State}, Count {e.WaitableCount}"); - var waitables = pendingTasks[WaitableSet.Handle]; + ContextTask* contextTaskPtr = ContextGet(); + + var waitables = pendingTasks[contextTaskPtr->WaitableSetHandle]; var waitableInfoState = waitables[e.Waitable]; if (e.IsDropped) @@ -195,32 +168,38 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act if (waitables.Count == 0) { - Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {WaitableSet.Handle}"); + Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {contextTaskPtr->WaitableSetHandle}"); taskReturn(); + ContextSet(null); + Marshal.FreeHGlobal((IntPtr)contextTaskPtr); return (uint)CallbackCode.Exit; } Console.WriteLine("More waitables in the set."); - return (uint)CallbackCode.Wait | (uint)(WaitableSet.Handle << 4); + return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4); } - throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {WaitableSet.Handle}"); + throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {contextTaskPtr->WaitableSetHandle}"); } - public static Task TaskFromStatus(uint status) + public static unsafe Task TaskFromStatus(uint status) { var subtaskStatus = new SubtaskStatus(status); status = status & 0xF; if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) { - if (WaitableSet == null) { - WaitableSet = WaitableSetNew(); - Console.WriteLine($"TaskFromStatus creating WaitableSet {WaitableSet.Handle}"); + ContextTask* contextTaskPtr = ContextGet(); + if (contextTaskPtr == null) { + contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); + + contextTaskPtr->WaitableSetHandle = WaitableSetNew(); + ContextSet(contextTaskPtr); + Console.WriteLine($"TaskFromStatus creating WaitableSet {contextTaskPtr->WaitableSetHandle}"); } TaskCompletionSource tcs = new TaskCompletionSource(); - AsyncSupport.Join(subtaskStatus, WaitableSet, new WaitableInfoState(tcs)); + Join(subtaskStatus, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs)); return tcs.Task; } else if (subtaskStatus.IsSubtaskReturned) @@ -233,7 +212,7 @@ public static Task TaskFromStatus(uint status) } } - public static Task TaskFromStatus(uint status, Func liftFunc) + public static unsafe Task TaskFromStatus(uint status, Func liftFunc) { var subtaskStatus = new SubtaskStatus(status); status = status & 0xF; @@ -242,9 +221,12 @@ public static Task TaskFromStatus(uint status, Func liftFunc) var tcs = new TaskCompletionSource(); if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) { - if (WaitableSet == null) { + ContextTask* contextTaskPtr = ContextGet(); + if (contextTaskPtr == null) { + contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); Console.WriteLine("TaskFromStatus creating WaitableSet"); - WaitableSet = AsyncSupport.WaitableSetNew(); + contextTaskPtr->WaitableSetHandle = WaitableSetNew(); + ContextSet(contextTaskPtr); } return tcs.Task; @@ -389,14 +371,17 @@ internal unsafe Task ReadInternal(Func liftBuffer, int length) { Console.WriteLine("Read Blocked"); var tcs = new TaskCompletionSource(); - if(AsyncSupport.WaitableSet == null) + ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); + if(contextTaskPtr == null) { Console.WriteLine("FutureReader Read Blocked creating WaitableSet"); - AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew(); + contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); + contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew(); + AsyncSupport.ContextSet(contextTaskPtr); } Console.WriteLine("blocked read before join"); - AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this)); + AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this)); Console.WriteLine("blocked read after join"); return tcs.Task; } @@ -470,7 +455,7 @@ public class FutureReader(int handle, FutureVTable vTable) : ReaderBase(handl { public FutureVTable VTable { get; private set; } = vTable; - private GCHandle LiftBuffer(T buffer) + private GCHandle LiftBuffer(T buffer) { if(typeof(T) == typeof(byte)) { @@ -483,7 +468,7 @@ private GCHandle LiftBuffer(T buffer) } } - public unsafe Task Read(T buffer) + public unsafe Task Read(T buffer) { return ReadInternal(() => LiftBuffer(buffer), 1); } @@ -528,7 +513,7 @@ public class StreamReader(int handle, StreamVTable vTable) : ReaderBase(hand { public StreamVTable VTable { get; private set; } = vTable; - private GCHandle LiftBuffer(T[] buffer) + private GCHandle LiftBuffer(T[] buffer) { if(typeof(T) == typeof(byte)) { @@ -541,7 +526,7 @@ private GCHandle LiftBuffer(T[] buffer) } } - public unsafe Task Read(T[] buffer) + public unsafe Task Read(T[] buffer) { return ReadInternal(() => LiftBuffer(buffer), buffer.Length); } @@ -600,12 +585,15 @@ internal unsafe Task WriteInternal(Func lowerPayload, int length { Console.WriteLine("blocked write"); var tcs = new TaskCompletionSource(); - if(AsyncSupport.WaitableSet == null) + ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); + if(contextTaskPtr == null) { - AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew(); + contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); + contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew(); + AsyncSupport.ContextSet(contextTaskPtr); } Console.WriteLine("blocked write before join"); - AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this)); + AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this)); Console.WriteLine("blocked write after join"); return tcs.Task; } @@ -679,7 +667,6 @@ public class FutureWriter(int handle, FutureVTable vTable) : WriterBase(handl // TODO: Generate per type for this instrinsic. public Task Write() { - // TODO: Lower T return WriteInternal(() => null, 1); } @@ -719,7 +706,7 @@ public class StreamWriter(int handle, StreamVTable vTable) : WriterBase(handl private GCHandle bufferHandle; public StreamVTable VTable { get; private set; } = vTable; - private GCHandle LowerPayload(T[] payload) + private GCHandle LowerPayload(T[] payload) { if (VTable.Lower == null) { diff --git a/crates/csharp/src/function.rs b/crates/csharp/src/function.rs index ad1285585..a98a4fc4d 100644 --- a/crates/csharp/src/function.rs +++ b/crates/csharp/src/function.rs @@ -1163,7 +1163,8 @@ impl Bindgen for FunctionBindgen<'_, '_> { }}); // TODO: Defer dropping borrowed resources until a result is returned. - return (uint)CallbackCode.Wait | (uint)(AsyncSupport.WaitableSet.Handle << 4); + ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); + return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4); "#); } diff --git a/crates/csharp/src/interface.rs b/crates/csharp/src/interface.rs index 9c9e982bc..4da5eaefa 100644 --- a/crates/csharp/src/interface.rs +++ b/crates/csharp/src/interface.rs @@ -880,7 +880,7 @@ var {async_status_var} = {raw_name}({wasm_params}); uwriteln!( self.csharp_interop_src, r#" - return (uint)AsyncSupport.Callback(e, (AsyncSupport.ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn()); + return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn()); }} "# );