Skip to content

Commit 982bff1

Browse files
committed
Added IScheduler.Marshal()
This is to allow the scheduler to ensure you're within an expected execution context, without needing to force a yield if you're already in the expected or desired context This includes major refactoring to collapse the implementation into Scheduler for simplicity. Added unit tests to check behavior of Marshal() and Yield() with respect to threads and the sync context
1 parent be0f458 commit 982bff1

File tree

6 files changed

+271
-121
lines changed

6 files changed

+271
-121
lines changed

src/CoroutineScheduler.Core/IScheduler.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ public interface IScheduler
2323
/// </summary>
2424
/// <returns>An awaitable whose awaiter completes at the mercy of this scheduler.</returns>
2525
YieldTask Yield();
26+
27+
/// <summary>
28+
/// Marshal to this scheduler.<br/>
29+
/// </summary>
30+
/// <returns>An awaitable whose awaiter completes at the mercy of this scheduler for the purposes of marshalling in some way, such as between threads.</returns>
31+
YieldTask Marshal();
2632
}
2733

2834
/// <summary>
@@ -36,7 +42,6 @@ public record struct YieldTask
3642
/// <summary>
3743
/// A struct container to provide the compiler duck-typeing so an <see cref="IYieldAwaiter"/> can be <c>await</c>ed.
3844
/// </summary>
39-
/// <param name="awaiter"></param>
4045
[MethodImpl(MethodImplOptions.AggressiveInlining)]
4146
[DebuggerNonUserCode]
4247
public YieldTask(IYieldAwaiter awaiter)

src/CoroutineScheduler.Core/PublicAPI.Unshipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#nullable enable
22
CoroutineScheduler.IScheduler
3+
CoroutineScheduler.IScheduler.Marshal() -> CoroutineScheduler.YieldTask
34
CoroutineScheduler.IScheduler.SpawnTask(System.Func<System.Threading.Tasks.Task!>! func) -> System.Threading.Tasks.Task!
45
CoroutineScheduler.IScheduler.Yield() -> CoroutineScheduler.YieldTask
56
CoroutineScheduler.IYieldAwaiter

src/CoroutineScheduler/AsyncManualResetEvent.cs

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/CoroutineScheduler/PublicAPI.Unshipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ CoroutineScheduler.InternalReleaseException.InternalReleaseException(string! mes
1414
CoroutineScheduler.InternalReleaseException.InternalReleaseException(string! message, System.Exception! innerException) -> void
1515
CoroutineScheduler.InternalReleaseException.InternalReleaseException(System.Runtime.Serialization.SerializationInfo! info, System.Runtime.Serialization.StreamingContext context) -> void
1616
CoroutineScheduler.Scheduler
17+
CoroutineScheduler.Scheduler.Marshal() -> CoroutineScheduler.YieldTask
1718
CoroutineScheduler.Scheduler.Resume() -> void
1819
CoroutineScheduler.Scheduler.Scheduler() -> void
1920
CoroutineScheduler.Scheduler.SpawnTask(System.Func<System.Threading.Tasks.Task!>! func) -> System.Threading.Tasks.Task!

src/CoroutineScheduler/Scheduler.cs

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#if DEBUG
22
#define DEBUG_ASYNC
33
#endif
4-
54
#if !DEBUG_ASYNC
65
using System.Diagnostics;
76
#endif
87

8+
using System.Collections.Concurrent;
99
using System.Diagnostics;
1010
using System.Runtime.CompilerServices;
1111

@@ -46,17 +46,25 @@ internal UnhandledExceptionEventArgs(Exception ex)
4646
/// </summary>
4747
public sealed class Scheduler : IScheduler
4848
{
49-
private AsyncManualResetEvent Awaiter { get; } = new AsyncManualResetEvent();
5049
private SchedulerSynchronizationContext SyncContext { get; }
50+
private YieldAwaiter NormalAwaiter { get; }
51+
private YieldAwaiter MarshalAwaiter { get; }
5152

5253
/// <summary>
5354
/// Create a new scheduler
5455
/// </summary>
5556
public Scheduler()
5657
{
57-
SyncContext = new(SyncronizationContextPost);
58+
SyncContext = new(PostContinuation);
59+
NormalAwaiter = new(this, true);
60+
MarshalAwaiter = new(this, false);
5861
}
5962

63+
private ContextCallback? Runner { get; set; }
64+
65+
private int? ResumingOnThread { get; set; }
66+
private bool RequiresMarshalling => !ResumingOnThread.HasValue || ResumingOnThread.Value != Thread.CurrentThread.ManagedThreadId;
67+
6068
/// <summary>
6169
/// Resume any tasks that are at the point of invocation are currently awaiting `Yield()`, and <br />
6270
/// accept any posted continuations from the syncronization context for the duration of the call.
@@ -69,19 +77,59 @@ public void Resume()
6977
{
7078
var original = SynchronizationContext.Current;
7179
SynchronizationContext.SetSynchronizationContext(SyncContext);
80+
ResumingOnThread = Thread.CurrentThread.ManagedThreadId;
7281
{
73-
Awaiter.Release();
82+
// cache this allocation
83+
[DebuggerNonUserCode]
84+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
85+
static void executionContextRunner(object obj) => (obj as Action)!();
86+
Runner ??= executionContextRunner!;
87+
88+
ReleaseImplicit();
89+
90+
int alreadyQueuedCount = explicitQueue.Count;
91+
while (alreadyQueuedCount > 0)
92+
{
93+
if (!explicitQueue.TryDequeue(out var workItem))
94+
throw new InternalReleaseException("Failed to dequeue the next continuation.");
95+
96+
alreadyQueuedCount--;
97+
98+
if (workItem.Context is null)
99+
workItem.Continuation();
100+
else
101+
ExecutionContext.Run(workItem.Context, Runner, workItem.Continuation);
102+
103+
ReleaseImplicit();
104+
}
74105
}
106+
ResumingOnThread = null;
75107
SynchronizationContext.SetSynchronizationContext(original);
76108
}
77109

110+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
111+
private void ReleaseImplicit()
112+
{
113+
while (implicitQueue.TryDequeue(out var workItem))
114+
workItem.Continuation(workItem.Context);
115+
}
116+
78117
/// <summary>
79118
/// Wait until <see cref="Resume()"/> is next invoked.
80119
/// </summary>
81120
[MethodImpl(MethodImplOptions.AggressiveInlining)]
82121
public YieldTask Yield()
83122
{
84-
return new(Awaiter);
123+
return new(NormalAwaiter);
124+
}
125+
126+
/// <summary>
127+
/// Wait until <see cref="Resume()"/> is next invoked.
128+
/// </summary>
129+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
130+
public YieldTask Marshal()
131+
{
132+
return new(MarshalAwaiter);
85133
}
86134

87135
/// <inheritdoc/>
@@ -129,8 +177,75 @@ async void SpawnTaskInternal()
129177
#if !DEBUG_ASYNC
130178
[DebuggerNonUserCode]
131179
#endif
132-
private void SyncronizationContextPost(SendOrPostCallback cb, object? state)
180+
private void PostContinuation(SendOrPostCallback cb, object? state)
181+
{
182+
if (RequiresMarshalling)
183+
implicitQueue.Enqueue(new() { Continuation = cb, Context = state });
184+
else
185+
cb(state);
186+
}
187+
188+
#if !DEBUG_ASYNC
189+
[DebuggerNonUserCode]
190+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
191+
#endif
192+
private void Yield(Action action, ExecutionContext? ctx)
133193
{
134-
Awaiter.Post(cb, state);
194+
explicitQueue.Enqueue(new() { Continuation = action, Context = ctx });
195+
}
196+
197+
private struct PostedWorkItem
198+
{
199+
public SendOrPostCallback Continuation { get; set; }
200+
public object? Context { get; set; }
201+
}
202+
203+
private struct YieldedWorkItem
204+
{
205+
public Action Continuation { get; set; }
206+
public ExecutionContext? Context { get; set; }
207+
}
208+
209+
private readonly ConcurrentQueue<YieldedWorkItem> explicitQueue = new();
210+
private readonly ConcurrentQueue<PostedWorkItem> implicitQueue = new();
211+
212+
private class YieldAwaiter : IYieldAwaiter
213+
{
214+
public Scheduler Self { get; }
215+
public bool ForceYield { get; }
216+
217+
public YieldAwaiter(Scheduler self, bool forceYield)
218+
{
219+
Self = self;
220+
ForceYield = forceYield;
221+
}
222+
223+
bool IYieldAwaiter.IsCompleted => ForceYield switch
224+
{
225+
true => false,
226+
false => !Self.RequiresMarshalling,
227+
};
228+
#if !DEBUG_ASYNC
229+
[DebuggerNonUserCode]
230+
#endif
231+
void INotifyCompletion.OnCompleted(Action continuation)
232+
{
233+
Self.Yield(continuation, ExecutionContext.Capture());
234+
}
235+
#if !DEBUG_ASYNC
236+
[DebuggerNonUserCode]
237+
#endif
238+
void ICriticalNotifyCompletion.UnsafeOnCompleted(Action continuation)
239+
{
240+
Self.Yield(continuation, null);
241+
}
242+
243+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
244+
#if !DEBUG_ASYNC
245+
[DebuggerNonUserCode]
246+
#endif
247+
void IYieldAwaiter.GetResult()
248+
{
249+
}
135250
}
136251
}

0 commit comments

Comments
 (0)