Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cancellation mechanism and make it internal, accessible via experimental nuget. #2797

Merged
merged 16 commits into from
Mar 19, 2019
15 changes: 15 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Ensemble", "Mi
pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj = pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Experimental", "src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj", "{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -948,6 +950,18 @@ Global
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.Build.0 = Release|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -1033,6 +1047,7 @@ Global
{31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{AD7058C9-5608-49A8-BE23-58C33A74EE91} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
12 changes: 12 additions & 0 deletions pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>Microsoft.ML.Experimental contains experimental work such extension methods to access internal methods.</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<Project DefaultTargets="Pack">

<Import Project="Microsoft.ML.Experimental.nupkgproj" />

</Project>
22 changes: 13 additions & 9 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
IHost Register(string name, int? seed = null, bool? verbose = null);

/// <summary>
/// Flag which indicate should we stop any code execution in this host.
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
/// </summary>
bool IsCancelled { get; }
codemzs marked this conversation as resolved.
Show resolved Hide resolved
ComponentCatalog ComponentCatalog { get; }
}

[BestFriend]
internal interface ICancelable
{
/// <summary>
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
/// Signal to stop exection in all the hosts.
/// </summary>
ComponentCatalog ComponentCatalog { get; }
void CancelExecution();

/// <summary>
/// Flag which indicates host execution has been stopped.
/// </summary>
bool IsCanceled { get; }
}

/// <summary>
Expand All @@ -85,11 +94,6 @@ public interface IHost : IHostEnvironment
/// generators are NOT thread safe.
/// </summary>
Random Rand { get; }

/// <summary>
/// Signal to stop exection in this host and all its children.
/// </summary>
void StopExecution();
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ private sealed class Host : HostBase
public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, shortName, parentFullName, rand, verbose)
{
IsCancelled = source.IsCancelled;
IsCanceled = source.IsCanceled;
}

protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
Expand Down
58 changes: 33 additions & 25 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,24 @@ internal interface IMessageSource
/// query progress.
/// </summary>
[BestFriend]
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable
where TEnv : HostEnvironmentBase<TEnv>
{
[BestFriend]
codemzs marked this conversation as resolved.
Show resolved Hide resolved
void ICancelable.CancelExecution()
{
lock (_cancelLock)
{
foreach (var child in _children)
if (child.TryGetTarget(out IHost host))
if (host is ICancelable cancelableHost)
cancelableHost.CancelExecution();

_children.Clear();
IsCanceled = true;
}
}

/// <summary>
/// Base class for hosts. Classes derived from <see cref="HostEnvironmentBase{THostEnvironmentBase}"/> may choose
/// to provide their own host class that derives from this class.
Expand All @@ -107,28 +122,10 @@ public abstract class HostBase : HostEnvironmentBase<TEnv>, IHost

public Random Rand => _rand;

// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
private readonly List<WeakReference<IHost>> _children;

public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, rand, verbose, shortName, parentFullName)
{
Depth = source.Depth + 1;
_children = new List<WeakReference<IHost>>();
}

public void StopExecution()
{
lock (_cancelLock)
{
IsCancelled = true;
foreach (var child in _children)
{
if (child.TryGetTarget(out IHost host))
host.StopExecution();
}
_children.Clear();
}
}

public new IHost Register(string name, int? seed = null, bool? verbose = null)
Expand All @@ -139,7 +136,7 @@ public void StopExecution()
{
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
if (!IsCancelled)
if (host is ICancelable cancelableHost && !cancelableHost.IsCanceled)
codemzs marked this conversation as resolved.
Show resolved Hide resolved
_children.Add(new WeakReference<IHost>(host));
}
return host;
Expand Down Expand Up @@ -175,7 +172,7 @@ protected PipeBase(ChannelProviderBase parent, string shortName,

public void Dispose()
{
if(!_disposed)
if (!_disposed)
{
Dispose(true);
_disposed = true;
Expand Down Expand Up @@ -339,12 +336,15 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)

protected readonly ProgressReporting.ProgressTracker ProgressTracker;

public bool IsCancelled { get; protected set; }

public ComponentCatalog ComponentCatalog { get; }

public override int Depth => 0;

public bool IsCanceled { get; protected set; }

// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
private readonly List<WeakReference<IHost>> _children;

/// <summary>
/// The main constructor.
/// </summary>
Expand All @@ -359,6 +359,7 @@ protected HostEnvironmentBase(Random rand, bool verbose,
_cancelLock = new object();
Root = this as TEnv;
ComponentCatalog = new ComponentCatalog();
_children = new List<WeakReference<IHost>>();
}

/// <summary>
Expand All @@ -379,13 +380,20 @@ protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, boo
ListenerDict = source.ListenerDict;
ProgressTracker = source.ProgressTracker;
ComponentCatalog = source.ComponentCatalog;
_children = new List<WeakReference<IHost>>();
}

public IHost Register(string name, int? seed = null, bool? verbose = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
return RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
IHost host;
lock (_cancelLock)
{
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
_children.Add(new WeakReference<IHost>(host));
}
return host;
}

protected abstract IHost RegisterCore(HostEnvironmentBase<TEnv> source, string shortName,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Utilities/Contracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg)
if (!f)
throw ExceptIO(ctx, msg);
}

public static void CheckIO(this IExceptionContext ctx, bool f, string msg, params object[] args)
{
if (!f)
Expand All @@ -748,11 +749,10 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg, param
/// </summary>
public static void CheckAlive(this IHostEnvironment env)
{
if (env.IsCancelled)
if (env is ICancelable cancelableEnv && cancelableEnv.IsCanceled)
throw Process(new OperationCanceledException("Operation was cancelled."), env);
}
#endif

/// <summary>
/// This documents that the parameter can legally be null.
/// </summary>
Expand Down
5 changes: 4 additions & 1 deletion src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -104,12 +105,14 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
log(this, new LoggingEventArgs(msg));
}

bool IHostEnvironment.IsCancelled => _env.IsCancelled;
string IExceptionContext.ContextDescription => _env.ContextDescription;
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose) => _env.Register(name, seed, verbose);
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);

[BestFriend]
internal void CancelExecution() => (_env as ICancelable).CancelExecution();
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}
}
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private sealed class Host : HostBase
public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
: base(source, shortName, parentFullName, rand, verbose)
{
IsCancelled = source.IsCancelled;
IsCanceled = source.IsCanceled;
}

protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.ML.Experimental/MLContextExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Experimental
{
public static class MLContextExtensions
{
/// <summary>
/// Stop the execution of pipeline in <see cref="MLContext"/>
/// </summary>
/// <param name="ctx"><see cref="MLContext"/> reference.</param>
public static void CancelExecution(this MLContext ctx) => ctx.CancelExecution();
}
}
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeInPackage>Microsoft.ML.Experimental</IncludeInPackage>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
</ItemGroup>

</Project>
10 changes: 8 additions & 2 deletions test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public async Task ContractsCheck()
VerifyCS.Diagnostic(ContractsCheckAnalyzer.SimpleMessageDiagnostic.Rule).WithLocation(basis + 32, 35).WithArguments("Check", "\"Less fine: \" + env.GetType().Name"),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(basis + 34, 17).WithArguments("CheckUserArg", "name", "\"p\""),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.DecodeMessageWithLoadContextDiagnostic.Rule).WithLocation(basis + 39, 41).WithArguments("CheckDecode", "\"This message is suspicious\""),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"),
};

var test = new VerifyCS.Test
Expand Down Expand Up @@ -125,7 +127,9 @@ public async Task ContractsCheckFix()
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(24, 53).WithArguments("CheckUserArg", "name", "\"chumble\""),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(25, 53).WithArguments("CheckUserArg", "name", "\"sp\""),
new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"),
new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"),
},
AdditionalReferences = { AdditionalMetadataReferences.RefFromType<Memory<int>>() },
},
Expand All @@ -144,7 +148,9 @@ public async Task ContractsCheckFix()
{
VerifyCS.Diagnostic(ContractsCheckAnalyzer.ExceptionDiagnostic.Rule).WithLocation(9, 43).WithArguments("ExceptParam"),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""),
new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"),
new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"),
new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"),
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ internal enum MessageSensitivity
}
internal interface IHostEnvironment : IExceptionContext
{
bool IsCancelled { get; }
}
}

Expand Down
Loading