Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Execution;

internal struct SynchronizationContextPreservingAsyncTaskMethodBuilder
{
private AsyncTaskMethodBuilder _inner;

public static SynchronizationContextPreservingAsyncTaskMethodBuilder Create()
=> new() { _inner = AsyncTaskMethodBuilder.Create() };

public SynchronizationContextPreservingTask Task
=> new SynchronizationContextPreservingTask(_inner.Task);

public void SetResult()
=> _inner.SetResult();

public void SetException(Exception ex)
=> _inner.SetException(ex);

public void SetStateMachine(IAsyncStateMachine stateMachine)
=> _inner.SetStateMachine(stateMachine);

#pragma warning disable
public void Start<TStateMachine>(ref TStateMachine stateMachine)
#pragma warning restore
where TStateMachine : IAsyncStateMachine
// Start is the whole reason why we have this custom builder.
// BCL implementation restores back SynchronizationContext.
// See https://github.com/dotnet/runtime/blob/c591f971241e7074f8a31ccde744aec9794e2500/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncMethodBuilderCore.cs#L45-L46
// We want to avoid restoring the SynchronizationContext for the case when the task completes synchronously on the same thread.
// This allows TestInitialize to set SynchronizationContext, and lets us be still fully async in our implementation.
// But then TestMethod can see the correct SynchronizationContext in the case of TestInitialize completing synchronously.
=> stateMachine.MoveNext();

public void AwaitOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
=> _inner.AwaitOnCompleted(ref awaiter, ref stateMachine);

public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
=> _inner.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);
}

internal struct SynchronizationContextPreservingAsyncTaskMethodBuilder<TResult>
{
private AsyncTaskMethodBuilder<TResult> _inner;

public static SynchronizationContextPreservingAsyncTaskMethodBuilder<TResult> Create()
=> new() { _inner = AsyncTaskMethodBuilder<TResult>.Create() };

public SynchronizationContextPreservingTask<TResult> Task
=> new SynchronizationContextPreservingTask<TResult>(_inner.Task);

public void SetResult(TResult result)
=> _inner.SetResult(result);

public void SetException(Exception ex)
=> _inner.SetException(ex);

public void SetStateMachine(IAsyncStateMachine stateMachine)
=> _inner.SetStateMachine(stateMachine);

#pragma warning disable
public void Start<TStateMachine>(ref TStateMachine stateMachine)
#pragma warning restore
where TStateMachine : IAsyncStateMachine
// Start is the whole reason why we have this custom builder.
// BCL implementation restores back SynchronizationContext.
// See https://github.com/dotnet/runtime/blob/c591f971241e7074f8a31ccde744aec9794e2500/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncMethodBuilderCore.cs#L45-L46
// We want to avoid restoring the SynchronizationContext for the case when the task completes synchronously on the same thread.
// This allows TestInitialize to set SynchronizationContext, and lets us be still fully async in our implementation.
// But then TestMethod can see the correct SynchronizationContext in the case of TestInitialize completing synchronously.
=> stateMachine.MoveNext();

public void AwaitOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
=> _inner.AwaitOnCompleted(ref awaiter, ref stateMachine);

public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
=> _inner.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Execution;

[AsyncMethodBuilder(typeof(SynchronizationContextPreservingAsyncTaskMethodBuilder<>))]
internal sealed class SynchronizationContextPreservingTask<TResult>
{
private readonly Task<TResult> _innerTask;

public SynchronizationContextPreservingTask(Task<TResult> innerTask)
=> _innerTask = innerTask;

public TaskAwaiter<TResult> GetAwaiter()
=> _innerTask.GetAwaiter();

public ConfiguredTaskAwaitable<TResult> ConfigureAwait(bool continueOnCapturedContext)
=> _innerTask.ConfigureAwait(continueOnCapturedContext);
}

[AsyncMethodBuilder(typeof(SynchronizationContextPreservingAsyncTaskMethodBuilder))]
internal sealed class SynchronizationContextPreservingTask
{
private readonly Task _innerTask;

public SynchronizationContextPreservingTask(Task innerTask)
=> _innerTask = innerTask;

public TaskAwaiter GetAwaiter()
=> _innerTask.GetAwaiter();

public ConfiguredTaskAwaitable ConfigureAwait(bool continueOnCapturedContext)
=> _innerTask.ConfigureAwait(continueOnCapturedContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Helpers;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Execution;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Extensions;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Interface;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand Down Expand Up @@ -601,7 +602,7 @@ private static TestFailedException HandleMethodException(Exception ex, Exception
/// <param name="result">Instance of TestResult.</param>
/// <param name="timeoutTokenSource">The timeout token source.</param>
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Requirement is to handle all kinds of user exceptions and message appropriately.")]
private async Task RunTestCleanupMethodAsync(TestResult result, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask RunTestCleanupMethodAsync(TestResult result, CancellationTokenSource? timeoutTokenSource)
{
DebugEx.Assert(result != null, "result != null");

Expand Down Expand Up @@ -708,7 +709,7 @@ _classInstance is IAsyncDisposable ||
/// <param name="timeoutTokenSource">The timeout token source.</param>
/// <returns>True if the TestInitialize method(s) did not throw an exception.</returns>
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Requirement is to handle all kinds of user exceptions and message appropriately.")]
private async Task<bool> RunTestInitializeMethodAsync(object classInstance, TestResult result, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask<bool> RunTestInitializeMethodAsync(object classInstance, TestResult result, CancellationTokenSource? timeoutTokenSource)
{
DebugEx.Assert(classInstance != null, "classInstance != null");
DebugEx.Assert(result != null, "result != null");
Expand Down Expand Up @@ -786,7 +787,7 @@ private async Task<bool> RunTestInitializeMethodAsync(object classInstance, Test
return false;
}

private async Task<TestFailedException?> InvokeInitializeMethodAsync(MethodInfo methodInfo, object classInstance, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask<TestFailedException?> InvokeInitializeMethodAsync(MethodInfo methodInfo, object classInstance, CancellationTokenSource? timeoutTokenSource)
{
TimeoutInfo? timeout = null;
if (Parent.TestInitializeMethodTimeoutMilliseconds.TryGetValue(methodInfo, out TimeoutInfo localTimeout))
Expand All @@ -807,7 +808,11 @@ private async Task<bool> RunTestInitializeMethodAsync(object classInstance, Test
await task.ConfigureAwait(false);
}

_executionContext = ExecutionContext.Capture() ?? _executionContext;
if (timeout?.CooperativeCancellation == false || _executionContext is not null)
{
_executionContext = ExecutionContext.Capture() ?? _executionContext;
}

#if NETFRAMEWORK
_hostContext = CallContext.HostContext;
#endif
Expand All @@ -825,7 +830,7 @@ timeoutTokenSource is null
return result;
}

private async Task<TestFailedException?> InvokeGlobalInitializeMethodAsync(MethodInfo methodInfo, TimeoutInfo? timeoutInfo, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask<TestFailedException?> InvokeGlobalInitializeMethodAsync(MethodInfo methodInfo, TimeoutInfo? timeoutInfo, CancellationTokenSource? timeoutTokenSource)
{
TestFailedException? result = await FixtureMethodRunner.RunWithTimeoutAndCancellationAsync(
async () =>
Expand All @@ -840,7 +845,11 @@ timeoutTokenSource is null
await task.ConfigureAwait(false);
}

_executionContext = ExecutionContext.Capture() ?? _executionContext;
if (timeoutInfo?.CooperativeCancellation == false || _executionContext is not null)
{
_executionContext = ExecutionContext.Capture() ?? _executionContext;
}

#if NETFRAMEWORK
_hostContext = CallContext.HostContext;
#endif
Expand All @@ -858,7 +867,7 @@ timeoutTokenSource is null
return result;
}

private async Task<TestFailedException?> InvokeCleanupMethodAsync(MethodInfo methodInfo, object classInstance, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask<TestFailedException?> InvokeCleanupMethodAsync(MethodInfo methodInfo, object classInstance, CancellationTokenSource? timeoutTokenSource)
{
TimeoutInfo? timeout = null;
if (Parent.TestCleanupMethodTimeoutMilliseconds.TryGetValue(methodInfo, out TimeoutInfo localTimeout))
Expand All @@ -879,7 +888,11 @@ timeoutTokenSource is null
await task.ConfigureAwait(false);
}

_executionContext = ExecutionContext.Capture() ?? _executionContext;
if (timeout?.CooperativeCancellation == false || _executionContext is not null)
{
_executionContext = ExecutionContext.Capture() ?? _executionContext;
}

#if NETFRAMEWORK
_hostContext = CallContext.HostContext;
#endif
Expand All @@ -897,7 +910,7 @@ timeoutTokenSource is null
return result;
}

private async Task<TestFailedException?> InvokeGlobalCleanupMethodAsync(MethodInfo methodInfo, TimeoutInfo? timeoutInfo, CancellationTokenSource? timeoutTokenSource)
private async SynchronizationContextPreservingTask<TestFailedException?> InvokeGlobalCleanupMethodAsync(MethodInfo methodInfo, TimeoutInfo? timeoutInfo, CancellationTokenSource? timeoutTokenSource)
{
TestFailedException? result = await FixtureMethodRunner.RunWithTimeoutAndCancellationAsync(
async () =>
Expand All @@ -912,7 +925,11 @@ timeoutTokenSource is null
await task.ConfigureAwait(false);
}

_executionContext = ExecutionContext.Capture() ?? _executionContext;
if (timeoutInfo?.CooperativeCancellation == false || _executionContext is not null)
{
_executionContext = ExecutionContext.Capture() ?? _executionContext;
}

#if NETFRAMEWORK
_hostContext = CallContext.HostContext;
#endif
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Execution;

namespace Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Helpers;

internal static class ExecutionContextHelpers
Expand All @@ -23,7 +25,7 @@ internal static void RunOnContext(ExecutionContext? executionContext, Action act
}
}

internal static async Task RunOnContextAsync(ExecutionContext? executionContext, Func<Task> action)
internal static async SynchronizationContextPreservingTask RunOnContextAsync(ExecutionContext? executionContext, Func<SynchronizationContextPreservingTask> action)
{
if (executionContext is null)
{
Expand All @@ -37,8 +39,8 @@ internal static async Task RunOnContextAsync(ExecutionContext? executionContext,
// Otherwise, it will throw InvalidOperationException with message:
// Cannot apply a context that has been marshaled across AppDomains, that was not acquired through a Capture operation or that has already been the argument to a Set call.
executionContext = executionContext.CreateCopy();
Task? t = null;
ExecutionContext.Run(executionContext, action => t = ((Func<Task>)action!).Invoke(), action);
SynchronizationContextPreservingTask? t = null;
ExecutionContext.Run(executionContext, action => t = ((Func<SynchronizationContextPreservingTask>)action!).Invoke(), action);
if (t is not null)
{
await t.ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Execution;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Execution;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Extensions;

using UnitTestOutcome = Microsoft.VisualStudio.TestTools.UnitTesting.UnitTestOutcome;
Expand All @@ -11,8 +12,8 @@ namespace Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Helpers;

internal static class FixtureMethodRunner
{
internal static async Task<TestFailedException?> RunWithTimeoutAndCancellationAsync(
Func<Task> action, CancellationTokenSource cancellationTokenSource, TimeoutInfo? timeoutInfo, MethodInfo methodInfo,
internal static async SynchronizationContextPreservingTask<TestFailedException?> RunWithTimeoutAndCancellationAsync(
Func<SynchronizationContextPreservingTask> action, CancellationTokenSource cancellationTokenSource, TimeoutInfo? timeoutInfo, MethodInfo methodInfo,
ExecutionContext? executionContext, string methodCanceledMessageFormat, string methodTimedOutMessageFormat,
// When a test method is marked with [Timeout], this timeout is applied from ctor to destructor, so we need to take
// that into account when processing the OCE of the action.
Expand Down Expand Up @@ -66,7 +67,7 @@ internal static class FixtureMethodRunner
: RunWithTimeoutAndCancellationWithThreadPool(action, executionContext, cancellationTokenSource, timeoutInfo.Value.Timeout, methodInfo, methodCanceledMessageFormat, methodTimedOutMessageFormat);
}

private static async Task<TestFailedException?> RunWithCooperativeCancellationAsync(Func<Task> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo, string methodCanceledMessageFormat, string methodTimedOutMessageFormat)
private static async SynchronizationContextPreservingTask<TestFailedException?> RunWithCooperativeCancellationAsync(Func<SynchronizationContextPreservingTask> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo, string methodCanceledMessageFormat, string methodTimedOutMessageFormat)
{
CancellationTokenSource? timeoutTokenSource = null;
try
Expand Down Expand Up @@ -117,7 +118,7 @@ internal static class FixtureMethodRunner
}

private static TestFailedException? RunWithTimeoutAndCancellationWithThreadPool(
Func<Task> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo,
Func<SynchronizationContextPreservingTask> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo,
string methodCanceledMessageFormat, string methodTimedOutMessageFormat)
{
Exception? realException = null;
Expand Down Expand Up @@ -178,7 +179,7 @@ internal static class FixtureMethodRunner

[SupportedOSPlatform("windows")]
private static TestFailedException? RunWithTimeoutAndCancellationWithSTAThread(
Func<Task> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo,
Func<SynchronizationContextPreservingTask> action, ExecutionContext? executionContext, CancellationTokenSource cancellationTokenSource, int timeout, MethodInfo methodInfo,
string methodCanceledMessageFormat, string methodTimedOutMessageFormat)
{
TaskCompletionSource<int> tcs = new();
Expand Down
Loading
Loading