Skip to content

Commit

Permalink
Make cancellation mechanism internal and accessible via experiemtnal …
Browse files Browse the repository at this point in the history
…nuget.
  • Loading branch information
codemzs committed Mar 12, 2019
1 parent 790da12 commit 225cad9
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 15 deletions.
15 changes: 15 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Ensemble", "Mi
pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj = pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Experimental", "src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj", "{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -948,6 +950,18 @@ Global
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.Build.0 = Release|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -1033,6 +1047,7 @@ Global
{31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{AD7058C9-5608-49A8-BE23-58C33A74EE91} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
12 changes: 12 additions & 0 deletions pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>Microsoft.ML.Experimental contains extension method for MLContext to access internal function to cancel ML pipeline.</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="../Microsoft.ML.Data/Microsoft.ML.Data.nupkgproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<Project DefaultTargets="Pack">

<Import Project="Microsoft.ML.Experimental.nupkgproj" />

</Project>
1 change: 1 addition & 0 deletions src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ private sealed class Host : HostBase
public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, shortName, parentFullName, rand, verbose)
{
IsCancelled = source.IsCancelled;
}

protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ internal interface IMessageSource
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider
where TEnv : HostEnvironmentBase<TEnv>
{
public bool IsCancelled { get; set; }

/// <summary>
/// Base class for hosts. Classes derived from <see cref="HostEnvironmentBase{THostEnvironmentBase}"/> may choose
/// to provide their own host class that derives from this class.
Expand All @@ -110,8 +112,6 @@ public abstract class HostBase : HostEnvironmentBase<TEnv>, IHost, ICancellableE
// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
private readonly List<WeakReference<IHost>> _children;

public bool IsCancelled { get; set; }

public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, rand, verbose, shortName, parentFullName)
{
Expand Down
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Core/Utilities/Contracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -737,22 +737,22 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg)
if (!f)
throw ExceptIO(ctx, msg);
}
public static void CheckIO(this IExceptionContext ctx, bool f, string msg, params object[] args)
{
if (!f)
throw ExceptIO(ctx, msg, args);
}
}
public static void CheckIO(this IExceptionContext ctx, bool f, string msg, params object[] args)
{
if (!f)
throw ExceptIO(ctx, msg, args);
}
#if !CPUMATH_INFRASTRUCTURE
/// <summary>
/// Check state of the host and throw exception if host marked to stop all exection.
/// </summary>
public static void CheckAlive(this IHostEnvironment env)
/// <summary>
/// Check state of the host and throw exception if host marked to stop all exection.
/// </summary>
public static void CheckAlive(this IHostEnvironment env)
{
if (env.IsCancelled)
if ((env is ICancellableEnvironment) && (env as ICancellableEnvironment).IsCancelled)
throw Process(new OperationCanceledException("Operation was cancelled."), env);
}
#endif

/// <summary>
/// This documents that the parameter can legally be null.
/// </summary>
Expand Down
20 changes: 19 additions & 1 deletion src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -72,6 +73,8 @@ public sealed class MLContext : IHostEnvironment
/// </summary>
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;

private List<IHost> _hosts;

/// <summary>
/// Create the ML context.
/// </summary>
Expand All @@ -90,6 +93,7 @@ public MLContext(int? seed = null)
Transforms = new TransformsCatalog(_env);
Model = new ModelOperationsCatalog(_env);
Data = new DataOperationsCatalog(_env);
_hosts = new List<IHost>();
}

private void ProcessMessage(IMessageSource source, ChannelMessage message)
Expand All @@ -106,9 +110,23 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)

string IExceptionContext.ContextDescription => _env.ContextDescription;
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose) => _env.Register(name, seed, verbose);
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose)
{
var host = _env.Register(name, seed, verbose);
_hosts.Add(host);
return host;
}

IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);

[BestFriend]
internal void StopExecution()
{
foreach(var host in _hosts)
if (host is ICancellableHost)
((ICancellableHost)host).StopExecution();
}
}
}
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ private sealed class Host : HostBase
public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, shortName, parentFullName, rand, verbose)
{
IsCancelled = source.IsCancelled;
}

protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.ML.Experimental/Experimental.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Experimental
{
public static class Experimental
{
/// <summary>
/// Stop the exeuction of pipeline in <see cref="MLContext"/>
/// </summary>
/// <param name="ctx"></param>
public static void StopExecution(this MLContext ctx) => ctx.StopExecution();
}
}
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeInPackage>Microsoft.ML.Experimental</IncludeInPackage>
<RootNamespace>Microsoft.ML.Experimental</RootNamespace>
</PropertyGroup>

<ItemGroup>
<Folder Include="Properties\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
</ItemGroup>

</Project>
57 changes: 57 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,63 @@ namespace Microsoft.ML.RunTests
{
public class TestHosts
{
[Fact]
public void TestCancellation()
{
IHostEnvironment env = new MLContext(seed: 42);
for (int z = 0; z < 1000; z++)
{
var mainHost = env.Register("Main");
var children = new ConcurrentDictionary<IHost, List<IHost>>();
var hosts = new BlockingCollection<Tuple<IHost, int>>();
hosts.Add(new Tuple<IHost, int>(mainHost.Register("1"), 1));
hosts.Add(new Tuple<IHost, int>(mainHost.Register("2"), 1));
hosts.Add(new Tuple<IHost, int>(mainHost.Register("3"), 1));
hosts.Add(new Tuple<IHost, int>(mainHost.Register("4"), 1));
hosts.Add(new Tuple<IHost, int>(mainHost.Register("5"), 1));

int iterations = 100;
Random rand = new Random();
var addThread = new Thread(
() =>
{
for (int i = 0; i < iterations; i++)
{
var randHostTuple = hosts.ElementAt(rand.Next(hosts.Count - 1));
var newHost = randHostTuple.Item1.Register((randHostTuple.Item2 + 1).ToString());
hosts.Add(new Tuple<IHost, int>(newHost, randHostTuple.Item2 + 1));
if (!children.ContainsKey(randHostTuple.Item1))
children[randHostTuple.Item1] = new List<IHost>();
else
children[randHostTuple.Item1].Add(newHost);
}
});
addThread.Start();
Queue<IHost> queue = new Queue<IHost>();
for (int i = 0; i < 5; i++)
{
IHost rootHost = null;
var index = 0;
do
{
index = rand.Next(hosts.Count);
} while ((hosts.ElementAt(index).Item1 as ICancellableEnvironment).IsCancelled || hosts.ElementAt(index).Item2 < 3);
(hosts.ElementAt(index).Item1 as ICancellableHost).StopExecution();
rootHost = hosts.ElementAt(index).Item1;
queue.Enqueue(rootHost);
}
addThread.Join();
while (queue.Count > 0)
{
var currentHost = queue.Dequeue();
Assert.True((currentHost as ICancellableEnvironment).IsCancelled);

if (children.ContainsKey(currentHost))
children[currentHost].ForEach(x => queue.Enqueue(x));
}
}
}

/// <summary>
/// Tests that MLContext's Log event intercepts messages properly.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ internal static class DecodeMessageWithLoadContextDiagnostic
private static HashSet<string> _targetSet = new HashSet<string>(new[]
{
"Check", "CheckUserArg", "CheckParam", "CheckParamValue", "CheckRef", "CheckValue",
"CheckNonEmpty", "CheckNonWhiteSpace", "CheckDecode", "CheckIO", "CheckValueOrNull",
"CheckNonEmpty", "CheckNonWhiteSpace", "CheckDecode", "CheckIO", "CheckAlive", "CheckValueOrNull",
"Except", "ExceptUserArg", "ExceptParam", "ExceptParamValue", "ExceptValue", "ExceptEmpty",
"ExceptWhiteSpace", "ExceptDecode", "ExceptIO", "ExceptNotImpl", "ExceptNotSupp", "ExceptSchemaMismatch"
});
Expand Down

0 comments on commit 225cad9

Please sign in to comment.