Skip to content

Commit e04950f

Browse files
committed
Added IScheduler.Marshal()
This is to allow the scheduler to ensure you're within an expected execution context, without needing to force a yield if you're already in the expected or desired context This includes major refactoring to collapse the implementation into Scheduler for simplicity. Added unit tests to check behavior of Marshal() and Yield() with respect to threads and the sync context
1 parent be0f458 commit e04950f

File tree

6 files changed

+272
-121
lines changed

6 files changed

+272
-121
lines changed

src/CoroutineScheduler.Core/IScheduler.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ public interface IScheduler
2323
/// </summary>
2424
/// <returns>An awaitable whose awaiter completes at the mercy of this scheduler.</returns>
2525
YieldTask Yield();
26+
27+
/// <summary>
28+
/// Marshal to this scheduler.<br/>
29+
/// </summary>
30+
/// <returns>An awaitable whose awaiter completes at the mercy of this scheduler for the purposes of marshalling in some way, such as between threads.</returns>
31+
YieldTask Marshal();
2632
}
2733

2834
/// <summary>
@@ -36,7 +42,6 @@ public record struct YieldTask
3642
/// <summary>
3743
/// A struct container to provide the compiler duck-typeing so an <see cref="IYieldAwaiter"/> can be <c>await</c>ed.
3844
/// </summary>
39-
/// <param name="awaiter"></param>
4045
[MethodImpl(MethodImplOptions.AggressiveInlining)]
4146
[DebuggerNonUserCode]
4247
public YieldTask(IYieldAwaiter awaiter)

src/CoroutineScheduler.Core/PublicAPI.Unshipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#nullable enable
22
CoroutineScheduler.IScheduler
3+
CoroutineScheduler.IScheduler.Marshal() -> CoroutineScheduler.YieldTask
34
CoroutineScheduler.IScheduler.SpawnTask(System.Func<System.Threading.Tasks.Task!>! func) -> System.Threading.Tasks.Task!
45
CoroutineScheduler.IScheduler.Yield() -> CoroutineScheduler.YieldTask
56
CoroutineScheduler.IYieldAwaiter

src/CoroutineScheduler/AsyncManualResetEvent.cs

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/CoroutineScheduler/PublicAPI.Unshipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ CoroutineScheduler.InternalReleaseException.InternalReleaseException(string! mes
1414
CoroutineScheduler.InternalReleaseException.InternalReleaseException(string! message, System.Exception! innerException) -> void
1515
CoroutineScheduler.InternalReleaseException.InternalReleaseException(System.Runtime.Serialization.SerializationInfo! info, System.Runtime.Serialization.StreamingContext context) -> void
1616
CoroutineScheduler.Scheduler
17+
CoroutineScheduler.Scheduler.Marshal() -> CoroutineScheduler.YieldTask
1718
CoroutineScheduler.Scheduler.Resume() -> void
1819
CoroutineScheduler.Scheduler.Scheduler() -> void
1920
CoroutineScheduler.Scheduler.SpawnTask(System.Func<System.Threading.Tasks.Task!>! func) -> System.Threading.Tasks.Task!

src/CoroutineScheduler/Scheduler.cs

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#if DEBUG
22
#define DEBUG_ASYNC
33
#endif
4-
54
#if !DEBUG_ASYNC
65
using System.Diagnostics;
76
#endif
87

8+
using System.Collections.Concurrent;
99
using System.Diagnostics;
1010
using System.Runtime.CompilerServices;
1111

@@ -46,17 +46,29 @@ internal UnhandledExceptionEventArgs(Exception ex)
4646
/// </summary>
4747
public sealed class Scheduler : IScheduler
4848
{
49-
private AsyncManualResetEvent Awaiter { get; } = new AsyncManualResetEvent();
5049
private SchedulerSynchronizationContext SyncContext { get; }
50+
private YieldAwaiter NormalAwaiter { get; }
51+
private YieldAwaiter MarshalAwaiter { get; }
5152

5253
/// <summary>
5354
/// Create a new scheduler
5455
/// </summary>
5556
public Scheduler()
5657
{
57-
SyncContext = new(SyncronizationContextPost);
58+
SyncContext = new(PostContinuation);
59+
NormalAwaiter = new(this, true);
60+
MarshalAwaiter = new(this, false);
5861
}
5962

63+
private ContextCallback? Runner { get; set; }
64+
#if !DEBUG_ASYNC
65+
[DebuggerNonUserCode]
66+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
67+
#endif
68+
69+
private int? ResumingOnThread { get; set; }
70+
private bool RequiresMarshalling => !ResumingOnThread.HasValue || ResumingOnThread.Value != Thread.CurrentThread.ManagedThreadId;
71+
6072
/// <summary>
6173
/// Resume any tasks that are at the point of invocation are currently awaiting `Yield()`, and <br />
6274
/// accept any posted continuations from the syncronization context for the duration of the call.
@@ -69,19 +81,60 @@ public void Resume()
6981
{
7082
var original = SynchronizationContext.Current;
7183
SynchronizationContext.SetSynchronizationContext(SyncContext);
84+
ResumingOnThread = Thread.CurrentThread.ManagedThreadId;
7285
{
73-
Awaiter.Release();
86+
// cache this allocation
87+
[DebuggerNonUserCode]
88+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
89+
static void executionContextRunner(object obj) => (obj as Action)!();
90+
if (Runner is null)
91+
Runner = executionContextRunner!;
92+
93+
ReleaseImplicit();
94+
95+
int alreadyQueuedCount = explicitQueue.Count;
96+
while (alreadyQueuedCount > 0)
97+
{
98+
if (!explicitQueue.TryDequeue(out var workItem))
99+
throw new InternalReleaseException("Failed to dequeue the next continuation.");
100+
101+
alreadyQueuedCount--;
102+
103+
if (workItem.Context is null)
104+
workItem.Continuation();
105+
else
106+
ExecutionContext.Run(workItem.Context, Runner, workItem.Continuation);
107+
108+
ReleaseImplicit();
109+
}
74110
}
111+
ResumingOnThread = null;
75112
SynchronizationContext.SetSynchronizationContext(original);
76113
}
77114

115+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
116+
private void ReleaseImplicit()
117+
{
118+
while (implicitQueue.TryDequeue(out var workItem))
119+
workItem.Continuation(workItem.Context);
120+
}
121+
78122
/// <summary>
79123
/// Wait until <see cref="Resume()"/> is next invoked.
80124
/// </summary>
81125
[MethodImpl(MethodImplOptions.AggressiveInlining)]
82126
public YieldTask Yield()
83127
{
84-
return new(Awaiter);
128+
return new(NormalAwaiter);
129+
}
130+
131+
/// <summary>
132+
/// Wait until <see cref="Resume()"/> is next invoked.
133+
/// </summary>
134+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
135+
public YieldTask Marshal()
136+
{
137+
return new(MarshalAwaiter);
85138
}
86139

87140
/// <inheritdoc/>
@@ -129,8 +182,71 @@ async void SpawnTaskInternal()
129182
#if !DEBUG_ASYNC
130183
[DebuggerNonUserCode]
131184
#endif
132-
private void SyncronizationContextPost(SendOrPostCallback cb, object? state)
185+
private void PostContinuation(SendOrPostCallback cb, object? state)
186+
{
187+
if (RequiresMarshalling)
188+
implicitQueue.Enqueue(new() { Continuation = cb, Context = state });
189+
else
190+
cb(state);
191+
}
192+
193+
private void Yield(Action action, ExecutionContext? ctx)
194+
{
195+
explicitQueue.Enqueue(new() { Continuation = action, Context = ctx });
196+
}
197+
198+
private struct PostedWorkItem
199+
{
200+
public SendOrPostCallback Continuation { get; set; }
201+
public object? Context { get; set; }
202+
}
203+
204+
private struct YieldedWorkItem
133205
{
134-
Awaiter.Post(cb, state);
206+
public Action Continuation { get; set; }
207+
public ExecutionContext? Context { get; set; }
208+
}
209+
210+
private readonly ConcurrentQueue<YieldedWorkItem> explicitQueue = new();
211+
private readonly ConcurrentQueue<PostedWorkItem> implicitQueue = new();
212+
213+
private class YieldAwaiter : IYieldAwaiter
214+
{
215+
public Scheduler Self { get; }
216+
public bool ForceYield { get; }
217+
218+
public YieldAwaiter(Scheduler self, bool forceYield)
219+
{
220+
Self = self;
221+
ForceYield = forceYield;
222+
}
223+
224+
bool IYieldAwaiter.IsCompleted => ForceYield switch
225+
{
226+
true => false,
227+
false => !Self.RequiresMarshalling,
228+
};
229+
#if !DEBUG_ASYNC
230+
[DebuggerNonUserCode]
231+
#endif
232+
void INotifyCompletion.OnCompleted(Action continuation)
233+
{
234+
Self.Yield(continuation, ExecutionContext.Capture());
235+
}
236+
#if !DEBUG_ASYNC
237+
[DebuggerNonUserCode]
238+
#endif
239+
void ICriticalNotifyCompletion.UnsafeOnCompleted(Action continuation)
240+
{
241+
Self.Yield(continuation, null);
242+
}
243+
244+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
245+
#if !DEBUG_ASYNC
246+
[DebuggerNonUserCode]
247+
#endif
248+
void IYieldAwaiter.GetResult()
249+
{
250+
}
135251
}
136252
}

0 commit comments

Comments
 (0)