Skip to content

Commit

Permalink
Ability to mock protected methods with and without return types (void)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason31569 committed Nov 20, 2024
1 parent ab73157 commit ff8cffc
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/NSubstitute/Core/IThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ public interface IThreadLocalContext
void EnqueueArgumentSpecification(IArgumentSpecification spec);
IList<IArgumentSpecification> DequeueAllArgumentSpecifications();

/// <summary>
/// Peeks into the argument specifications
/// </summary>
/// <returns>Enqueued argument specifications</returns>
IList<IArgumentSpecification> PeekAllArgumentSpecifications();

void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments);
/// <summary>
/// Returns the previously set arguments factory and resets the stored value.
Expand Down
18 changes: 18 additions & 0 deletions src/NSubstitute/Core/ThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
return queue;
}

/// <inheritdoc/>
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
{
var queue = _argumentSpecifications.Value;
if (queue == null) { throw new SubstituteInternalException("Argument specification queue is null."); }

if (queue.Count > 0)
{
var items = new IArgumentSpecification[queue.Count];

queue.CopyTo(items, 0);

return items;
}

return EmptySpecifications;
}

public void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments)
{
_getArgumentsForRaisingEvent.Value = getArguments;
Expand Down
59 changes: 59 additions & 0 deletions src/NSubstitute/Extensions/ProtectedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System.Reflection;
using NSubstitute.Core;
using NSubstitute.Core.Arguments;

// Disable nullability for client API, so it does not affect clients.
#nullable disable annotations

namespace NSubstitute.Extensions;

public static class ProtectedExtensions
{
/// <summary>
/// Configure behavior for a protected method with return value
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The arguments.</param>
/// <returns>Result object from the mehtod invocation.</returns>
/// <exception cref="System.ArgumentNullException">obj - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }

return mthdInfo.Invoke(obj, args);
}

/// <summary>
/// Configure behavior for a protected method with no return vlaue
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The arguments.</param>
/// <returns>WhenCalled&lt;T&gt;.</returns>
/// <exception cref="System.ArgumentNullException">obj - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }

return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
}
}
46 changes: 46 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
namespace NSubstitute.Acceptance.Specs.Infrastructure;

public abstract class AnotherClass
{
protected abstract string ProtectedMethod();

protected abstract string ProtectedMethod(int i);

protected abstract string ProtectedMethod(string msg, int i, char j);

protected abstract void ProtectedMethodWithNoReturn();

protected abstract void ProtectedMethodWithNoReturn(int i);

protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);

public string DoWork()
{
return ProtectedMethod();
}

public string DoWork(int i)
{
return ProtectedMethod(i);
}

public string DoWork(string msg, int i, char j)
{
return ProtectedMethod(msg, i, j);
}

public void DoVoidWork()
{
ProtectedMethodWithNoReturn();
}

public void DoVoidWork(int i)
{
ProtectedMethodWithNoReturn(i);
}

public void DoVoidWork(string msg, int i, char j)
{
ProtectedMethodWithNoReturn(msg, i, j);
}
}
123 changes: 123 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using NSubstitute.Acceptance.Specs.Infrastructure;
using NSubstitute.Extensions;
using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs;

public class ProtectedExtensionsTests
{
[Test]
public void Should_mock_and_verify_protected_method_with_no_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod").Returns(expectedMsg);

Assert.That(worker.DoWork(sub), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod");
}

[Test]
public void Should_mock_and_verify_protected_method_with_arg()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<int>()).Returns(expectedMsg);

Assert.That(worker.DoMoreWork(sub, 5), Is.EqualTo(expectedMsg));
var a = sub.Received(1);
a.Protected("ProtectedMethod", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_protected_method_with_multiple_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Returns(expectedMsg);

Assert.That(worker.DoEvenMoreWork(sub, 3, 'x'), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_and_no_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn").Do(x => count++);

worker.DoVoidWork(sub);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn");
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_arg()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<int>()).Do(x => count++);

worker.DoVoidWork(sub, 5);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_multiple_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Do(x => count++);

worker.DoVoidWork(sub, 5, 'x');
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

private class Worker
{
internal string DoWork(AnotherClass worker)
{
return worker.DoWork();
}

internal string DoMoreWork(AnotherClass worker, int i)
{
return worker.DoWork(i);
}

internal string DoEvenMoreWork(AnotherClass worker, int i, char j)
{
return worker.DoWork("worker", i, j);
}

internal void DoVoidWork(AnotherClass worker)
{
worker.DoVoidWork();
}

internal void DoVoidWork(AnotherClass worker, int i)
{
worker.DoVoidWork(i);
}

internal void DoVoidWork(AnotherClass worker, int i, char j)
{
worker.DoVoidWork("void worker", i, j);
}
}
}

0 comments on commit ff8cffc

Please sign in to comment.