Skip to content

Commit

Permalink
Stop using System.ComponentModel.Composition (dotnet#2569)
Browse files Browse the repository at this point in the history
* Stop using System.ComponentModel.Composition

Replace our MEF usage, which is only used by custom mapping transforms, with the ComponentCatalog class.

Fix dotnet#1595
Fix dotnet#2422

* Rename new class to CustomMappingFactory.
  • Loading branch information
eerhardt authored Feb 21, 2019
1 parent 512493a commit 412e1f9
Show file tree
Hide file tree
Showing 15 changed files with 196 additions and 91 deletions.
1 change: 0 additions & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
<SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion>
</PropertyGroup>

<!-- Other/Non-Core Product Dependencies -->
Expand Down
29 changes: 15 additions & 14 deletions docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function
- It should not have side effects (we may call it arbitrarily at any time, or omit the call)

One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it.
At loading time, you will need to reconstruct the custom transformer and inject it into MLContext.
At loading time, you will need to register the custom transformer with the MLContext.

Here is a complete example that saves and loads a model with a custom mapping.
```csharp
/// <summary>
/// One class that contains all custom mappings that we need for our model.
/// One class that contains the custom mapping functionality that we need for our model.
///
/// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and
/// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>.
/// </summary>
public class CustomMappings
[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
public class CustomMappings : CustomMappingFactory<InputRow, OutputRow>
{
// This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;

// MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
// this property.
[Import]
public MLContext MLContext { get; set; }

// We are exporting the custom transformer by the name 'IncomeMapping'.
[Export(nameof(IncomeMapping))]
public ITransformer MyCustomTransformer
=> MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping));
// This factory method will be called when loading the model to get the mapping operation.
public override Action<InputRow, OutputRow> GetMapping()
{
return IncomeMapping;
}
}
```

Expand All @@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath))

// Now pretend we are in a different process.
// Create a custom composition container for all our custom mapping actions.
newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
// Register the assembly that contains 'CustomMappings' with the ComponentCatalog
// so it can be found when loading the model.
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);

// Now we can load the model.
ITransformer loadedModel;
Expand Down
1 change: 0 additions & 1 deletion pkg/Microsoft.ML/Microsoft.ML.nupkgproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
76 changes: 76 additions & 0 deletions src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ internal ComponentCatalog()
_entryPointMap = new Dictionary<string, EntryPointInfo>();
_componentMap = new Dictionary<string, ComponentInfo>();
_components = new List<ComponentInfo>();

_extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>();
}

/// <summary>
Expand Down Expand Up @@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo
private readonly List<ComponentInfo> _components;
private readonly Dictionary<string, ComponentInfo> _componentMap;

private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap;

private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
{
Expand Down Expand Up @@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)

AddClass(info, attr.LoadNames, throwOnError);
}

LoadExtensions(assembly, throwOnError);
}
}
}
Expand Down Expand Up @@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set
if (errorMsg != null)
throw Contracts.Except(errorMsg);
}

private void LoadExtensions(Assembly assembly, bool throwOnError)
{
// don't waste time looking through all the types of an assembly
// that can't contain extensions
if (CanContainExtensions(assembly))
{
foreach (Type type in assembly.GetTypes())
{
if (type.IsClass)
{
foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute)))
{
var key = (AttributeType: attribute.GetType(), attribute.ContractName);
if (_extensionsMap.TryGetValue(key, out var existingType))
{
if (throwOnError)
{
throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog.");
}
}
else
{
_extensionsMap.Add(key, type);
}
}
}
}
}
}

/// <summary>
/// Gets a value indicating whether <paramref name="assembly"/> can contain extensions.
/// </summary>
/// <remarks>
/// All ML.NET product assemblies won't contain extensions.
/// </remarks>
private static bool CanContainExtensions(Assembly assembly)
{
if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal)
&& HasMLNetPublicKey(assembly))
{
return false;
}

return true;
}

private static bool HasMLNetPublicKey(Assembly assembly)
{
return assembly.GetName().GetPublicKey().SequenceEqual(
typeof(ComponentCatalog).Assembly.GetName().GetPublicKey());
}

[BestFriend]
internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName)
{
object exportedValue = null;
if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType))
{
exportedValue = Activator.CreateInstance(extensionType);
}

if (exportedValue == null)
{
throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'.");
}

return exportedValue;
}
}
}
23 changes: 23 additions & 0 deletions src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML
{
/// <summary>
/// The base attribute type for all attributes used for extensibility purposes.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public abstract class ExtensionBaseAttribute : Attribute
{
public string ContractName { get; }

[BestFriend]
private protected ExtensionBaseAttribute(string contractName)
{
ContractName = contractName;
}
}
}
7 changes: 0 additions & 7 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.ComponentModel.Composition.Hosting;

namespace Microsoft.ML
{
Expand Down Expand Up @@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
[Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
"Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
IFileHandle CreateTempFile(string suffix = null, string prefix = null);

/// <summary>
/// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model
/// is being loaded, or the components are otherwise created via dependency injection.
/// </summary>
CompositionContainer GetCompositionContainer();
}

/// <summary>
Expand Down
3 changes: 0 additions & 3 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel.Composition.Hosting;
using System.IO;

namespace Microsoft.ML.Data
Expand Down Expand Up @@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
else if (!removeLastNewLine)
writer.WriteLine();
}

public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer();
}
}
1 change: 0 additions & 1 deletion src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
<ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" />

<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
</ItemGroup>

Expand Down
22 changes: 3 additions & 19 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.ComponentModel.Composition;
using System.ComponentModel.Composition.Hosting;
using Microsoft.ML.Data;

namespace Microsoft.ML
Expand Down Expand Up @@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment
public event EventHandler<LoggingEventArgs> Log;

/// <summary>
/// This is a MEF composition container catalog to be used for model loading.
/// This is a catalog of components that will be used for model loading.
/// </summary>
public CompositionContainer CompositionContainer { get; set; }
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;

/// <summary>
/// Create the ML context.
Expand All @@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
public MLContext(int? seed = null, int conc = 0)
{
_env = new LocalEnvironment(seed, conc, MakeCompositionContainer);
_env = new LocalEnvironment(seed, conc);
_env.AddListener(ProcessMessage);

BinaryClassification = new BinaryClassificationCatalog(_env);
Expand All @@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0)
Data = new DataOperationsCatalog(_env);
}

private CompositionContainer MakeCompositionContainer()
{
if (CompositionContainer == null)
return null;

var mlContext = CompositionContainer.GetExportedValueOrDefault<MLContext>();
if (mlContext == null)
CompositionContainer.ComposeExportedValue<MLContext>(this);

return CompositionContainer;
}

private void ProcessMessage(IMessageSource source, ChannelMessage message)
{
var log = Log;
Expand All @@ -120,14 +106,12 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)

int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
string IExceptionContext.ContextDescription => _env.ContextDescription;
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer();
}
}
14 changes: 1 addition & 13 deletions src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.ComponentModel.Composition.Hosting;

namespace Microsoft.ML.Data
{
Expand All @@ -14,8 +13,6 @@ namespace Microsoft.ML.Data
/// </summary>
internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
{
private readonly Func<CompositionContainer> _compositionContainerFactory;

private sealed class Channel : ChannelBase
{
public readonly Stopwatch Watch;
Expand Down Expand Up @@ -49,11 +46,9 @@ protected override void Dispose(bool disposing)
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
/// <param name="compositionContainerFactory">The function to retrieve the composition container</param>
public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null)
public LocalEnvironment(int? seed = null, int conc = 0)
: base(RandomUtils.Create(seed), verbose: false, conc)
{
_compositionContainerFactory = compositionContainerFactory;
}

/// <summary>
Expand Down Expand Up @@ -96,13 +91,6 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
}

public override CompositionContainer GetCompositionContainer()
{
if (_compositionContainerFactory != null)
return _compositionContainerFactory();
return base.GetCompositionContainer();
}

private sealed class Host : HostBase
{
public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
Expand Down
47 changes: 47 additions & 0 deletions src/Microsoft.ML.Transforms/CustomMappingFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.Data.DataView;

namespace Microsoft.ML.Transforms
{
/// <summary>
/// Place this attribute onto a type to cause it to be considered a custom mapping factory.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute
{
public CustomMappingFactoryAttributeAttribute(string contractName)
: base(contractName)
{
}
}

internal interface ICustomMappingFactory
{
ITransformer CreateTransformer(IHostEnvironment env, string contractName);
}

/// <summary>
/// The base type for custom mapping factories.
/// </summary>
/// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam>
/// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
public abstract class CustomMappingFactory<TSrc, TDst> : ICustomMappingFactory
where TSrc : class, new()
where TDst : class, new()
{
/// <summary>
/// Returns the mapping delegate that maps from <typeparamref name="TSrc"/> inputs to <typeparamref name="TDst"/> outputs.
/// </summary>
public abstract Action<TSrc, TDst> GetMapping();

ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName)
{
Action<TSrc, TDst> mapAction = GetMapping();
return new CustomMappingTransformer<TSrc, TDst>(env, mapAction, contractName);
}
}
}
Loading

0 comments on commit 412e1f9

Please sign in to comment.