mgravell/PooledAwait

ValueTaskCompletionSource: ValidateOptimized called for every T

RichardD2 opened this issue · 0 comments

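With the current implementation, the `ValidateOptimized` method runs once for every distinct type argument `T` used with `ValueTaskCompletionSource<T>`, because static members of a generic type are initialized separately for each constructed type. I suspect it only needs to run once.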

Moving most of the static members out of the generic `ValueTaskCompletionSource<T>` struct and into a non-generic helper class would allow this, for example:

internal static class TaskHelper
{
#if NETSTANDARD1_3
    // No reflection-based fast path on this target.
    // NOTE(review): this branch defines none of the helper methods of the other branch
    // (CreateUninitializedTask / TrySetResult / TrySetException / TrySetCanceled), so
    // callers presumably guard those calls with the same #if — confirm.
    public static readonly bool UseOptimizedPath = false;
#else
    /// <summary>
    /// Per-<typeparamref name="T"/> cache of open-instance delegates bound to the
    /// non-public single-argument completion methods reachable from <see cref="Task{TResult}"/>.
    /// </summary>
    /// <remarks>
    /// A field is <c>null</c> when the corresponding method could not be found or bound on the
    /// current runtime. The reflection lookup runs once per closed type <c>TaskInternals&lt;T&gt;</c>.
    /// </remarks>
    private static class TaskInternals<T>
    {
        public static readonly Func<Task<T>, T, bool> TrySetResult = TryCreate<T>(nameof(TrySetResult));
        public static readonly Func<Task<T>, Exception, bool> TrySetException = TryCreate<Exception>(nameof(TrySetException));
        public static readonly Func<Task<T>, CancellationToken, bool> TrySetCanceled = TryCreate<CancellationToken>(nameof(TrySetCanceled));

        /// <summary>
        /// Looks up the public or non-public instance method <paramref name="methodName"/> taking a
        /// single <typeparamref name="TArg"/> on <see cref="Task{TResult}"/>, and binds it as a
        /// <c>Func&lt;Task&lt;T&gt;, TArg, bool&gt;</c> whose first argument is the target task.
        /// </summary>
        /// <param name="methodName">Name of the method to bind.</param>
        /// <returns>The bound delegate, or <c>null</c> if lookup or binding failed.</returns>
        [MethodImpl(MethodImplOptions.NoInlining)]
        private static Func<Task<T>, TArg, bool> TryCreate<TArg>(string methodName)
        {
            try
            {
                var method = typeof(Task<T>).GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(TArg) }, null);
                return method is null ? null : (Func<Task<T>, TArg, bool>)Delegate.CreateDelegate(typeof(Func<Task<T>, TArg, bool>), method);
            }
            catch
            {
                // Deliberate best-effort: any reflection or binding failure disables the fast
                // path (via a null delegate) instead of failing type initialization.
                return null;
            }
        }
    }

    /// <summary>
    /// <c>true</c> when all three completion delegates were bound and passed a runtime
    /// self-test (see <c>ValidateOptimized</c>). Callers must check this before using
    /// <c>TrySetResult</c>, <c>TrySetException</c> or <c>TrySetCanceled</c>; when it is
    /// <c>false</c> the underlying delegates may be <c>null</c>.
    /// </summary>
    /// <remarks>
    /// Evaluated once for this non-generic class, using <c>object</c> as the probe type.
    /// NOTE(review): this assumes that if binding works for <c>Task&lt;object&gt;</c> it works
    /// for every <c>Task&lt;T&gt;</c>.
    /// </remarks>
    public static readonly bool UseOptimizedPath = ValidateOptimized();

    /// <summary>
    /// Feature-tests the reflection-based completion path. It binds the delegates for
    /// <c>object</c>, then completes fresh uninitialized tasks via the result, fault and
    /// cancel paths, checking the resulting task state each time.
    /// </summary>
    /// <returns><c>true</c> if every check passed; <c>false</c> on any mismatch or exception.</returns>
    [MethodImpl(MethodImplOptions.NoInlining)]
    private static bool ValidateOptimized()
    {
        try
        {
            if (TaskInternals<object>.TrySetResult is null) return false;
            if (TaskInternals<object>.TrySetException is null) return false;
            if (TaskInternals<object>.TrySetCanceled is null) return false;

            // Result path: a fresh uninitialized task must start incomplete and move to RanToCompletion.
            var task = CreateUninitializedTask<object>();
            if (task is null) return false;
            if (task.IsCompleted) return false;

            if (!TaskInternals<object>.TrySetResult(task, default)) return false;
            if (task.Status != TaskStatus.RanToCompletion) return false;

            // Fault path: the task must become Faulted and surface the recorded exception.
            task = CreateUninitializedTask<object>();
            if (!TaskInternals<object>.TrySetException(task, new InvalidOperationException())) return false;
            if (!task.IsCompleted) return false;
            if (!task.IsFaulted) return false;

            try
            {
                _ = task.Result;
                return false;
            }
            catch (AggregateException ex) when (ex.InnerException is InvalidOperationException)
            {
                // Expected: Result rethrows the recorded exception wrapped in an AggregateException.
            }

            if (!(task.Exception?.InnerException is InvalidOperationException)) return false;

            // Cancel path: previously only the delegate's presence was checked, so a runtime
            // whose bound cancel method misbehaved would still have enabled the fast path.
            task = CreateUninitializedTask<object>();
            if (!TaskInternals<object>.TrySetCanceled(task, CancellationToken.None)) return false;
            return task.Status == TaskStatus.Canceled;
        }
        catch
        {
            // Deliberate best-effort: any failure during the self-test means "don't use the fast path".
            return false;
        }
    }

    /// <summary>
    /// Spins until <paramref name="task"/> reports <see cref="Task.IsCompleted"/>.
    /// </summary>
    /// <remarks>
    /// Used after a failed TrySet* call on a task that is not yet marked completed. Presumably
    /// a concurrent completer won the race but has not finished publishing the final state.
    /// </remarks>
    [MethodImpl(MethodImplOptions.NoInlining)]
    private static void SpinUntilCompleted([NotNull] Task task)
    {
        var sw = new SpinWait();
        while (!task.IsCompleted)
        {
            sw.SpinOnce();
        }
    }

    /// <summary>
    /// Creates a <see cref="Task{TResult}"/> instance without running any constructor.
    /// </summary>
    /// <remarks>
    /// <c>ValidateOptimized</c> checks (for <c>object</c>) that such a task starts out incomplete and
    /// can be completed through the TrySet* helpers.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Task<T> CreateUninitializedTask<T>() => (Task<T>)System.Runtime.Serialization.FormatterServices.GetUninitializedObject(typeof(Task<T>));

    /// <summary>
    /// Attempts to complete <paramref name="task"/> successfully with <paramref name="value"/>
    /// through the bound runtime <c>TrySetResult</c> method.
    /// </summary>
    /// <returns>The value returned by the runtime method: <c>true</c> if this call completed the task.</returns>
    /// <remarks>
    /// When the call returns <c>false</c> and the task is not yet completed, this spins until it is.
    /// Only valid when <see cref="UseOptimizedPath"/> is <c>true</c>.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool TrySetResult<T>(this Task<T> task, T value)
    {
        bool result = TaskInternals<T>.TrySetResult(task, value);
        if (!result && !task.IsCompleted) SpinUntilCompleted(task);
        return result;
    }

    /// <summary>
    /// Attempts to fault <paramref name="task"/> with <paramref name="error"/> through the bound
    /// runtime <c>TrySetException</c> method.
    /// </summary>
    /// <returns>The value returned by the runtime method: <c>true</c> if this call completed the task.</returns>
    /// <remarks>
    /// When the call returns <c>false</c> and the task is not yet completed, this spins until it is.
    /// Only valid when <see cref="UseOptimizedPath"/> is <c>true</c>.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool TrySetException<T>(this Task<T> task, Exception error)
    {
        bool result = TaskInternals<T>.TrySetException(task, error);
        if (!result && !task.IsCompleted) SpinUntilCompleted(task);
        return result;
    }

    /// <summary>
    /// Attempts to cancel <paramref name="task"/>, recording <paramref name="cancellationToken"/>,
    /// through the bound runtime <c>TrySetCanceled</c> method.
    /// </summary>
    /// <returns>The value returned by the runtime method: <c>true</c> if this call completed the task.</returns>
    /// <remarks>
    /// When the call returns <c>false</c> and the task is not yet completed, this spins until it is.
    /// Only valid when <see cref="UseOptimizedPath"/> is <c>true</c>.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool TrySetCanceled<T>(this Task<T> task, CancellationToken cancellationToken)
    {
        bool result = TaskInternals<T>.TrySetCanceled(task, cancellationToken);
        if (!result && !task.IsCompleted) SpinUntilCompleted(task);
        return result;
    }
#endif
}