diff --git a/src/SharpNeatLib/Network/ActivationFunctions/ActivationFunctionRegistry.cs b/src/SharpNeatLib/Network/ActivationFunctions/ActivationFunctionRegistry.cs
new file mode 100644
index 00000000..139a5e18
--- /dev/null
+++ b/src/SharpNeatLib/Network/ActivationFunctions/ActivationFunctionRegistry.cs
@@ -0,0 +1,78 @@
+/* ***************************************************************************
+ * This file is part of SharpNEAT - Evolution of Neural Networks.
+ *
+ * Copyright 2004-2016 Colin Green (sharpneat@gmail.com)
+ *
+ * Written by Scott DeBoer, 2019 (naigonakoii@gmail.com)
+ *
+ * SharpNEAT is free software; you can redistribute it and/or modify
+ * it under the terms of The MIT License (MIT).
+ *
+ * You should have received a copy of the MIT License
+ * along with SharpNEAT; if not, see https://opensource.org/licenses/MIT.
+ */
+using System;
+using System.Collections.Generic;
+
+namespace SharpNeat.Network
+{
+    public static class ActivationFunctionRegistry
+    {
+        // The default set of activation functions.
+        private static Dictionary<string,IActivationFunction> __activationFunctionTable = new Dictionary<string,IActivationFunction>()
+        {
+            // Bipolar.
+            { "BipolarSigmoid", BipolarSigmoid.__DefaultInstance },
+            { "BipolarGaussian", BipolarGaussian.__DefaultInstance },
+            { "Linear", Linear.__DefaultInstance },
+            { "Sine", Sine.__DefaultInstance },
+
+            // Unipolar.
+            { "ArcSinH", ArcSinH.__DefaultInstance },
+            { "ArcTan", ArcTan.__DefaultInstance },
+            { "Gaussian", Gaussian.__DefaultInstance },
+            { "LeakyReLU", LeakyReLU.__DefaultInstance },
+            { "LeakyReLUShifted", LeakyReLUShifted.__DefaultInstance },
+            { "LogisticFunction", LogisticFunction.__DefaultInstance },
+            { "LogisticFunctionSteep", LogisticFunctionSteep.__DefaultInstance },
+            { "MaxMinusOne", MaxMinusOne.__DefaultInstance },
+            { "PolynomialApproximantSteep", PolynomialApproximantSteep.__DefaultInstance },
+            { "QuadraticSigmoid", QuadraticSigmoid.__DefaultInstance },
+            { "ReLU", ReLU.__DefaultInstance },
+            { "ScaledELU", ScaledELU.__DefaultInstance },
+            { "SoftSignSteep", SoftSignSteep.__DefaultInstance },
+            { "SReLU", SReLU.__DefaultInstance },
+            { "SReLUShifted", SReLUShifted.__DefaultInstance },
+            { "TanH", TanH.__DefaultInstance },
+
+            // Radial Basis.
+            { "RbfGaussian", RbfGaussian.__DefaultInstance },
+        };
+
+        #region Public Static Methods
+
+        /// <summary>
+        /// Registers a custom activation function in addition to those in the default activation function table.
+        /// Allows loading of neural nets from XML that use custom activation functions.
+        /// </summary>
+        public static void RegisterActivationFunction(IActivationFunction function)
+        {
+            if(!__activationFunctionTable.ContainsKey(function.FunctionId)) {
+                __activationFunctionTable.Add(function.FunctionId, function);
+            }
+        }
+
+        /// <summary>
+        /// Gets an IActivationFunction with the given short name.
+        /// </summary>
+        public static IActivationFunction GetActivationFunction(string name)
+        {
+            if(!__activationFunctionTable.ContainsKey(name)) {
+                throw new ArgumentException($"Unexpected activation function [{name}]");
+            }
+            return __activationFunctionTable[name];
+        }
+
+        #endregion
+    }
+}
diff --git a/src/SharpNeatLib/Network/NetworkXmlIO.cs b/src/SharpNeatLib/Network/NetworkXmlIO.cs
index c0536e38..39a553c4 100644
--- a/src/SharpNeatLib/Network/NetworkXmlIO.cs
+++ b/src/SharpNeatLib/Network/NetworkXmlIO.cs
@@ -459,7 +459,7 @@ public static IActivationFunctionLibrary ReadActivationFunctionLibrary(XmlReader
                 string fnName = xrSubtree.GetAttribute(__AttrName);
 
                 // Lookup function name.
-                IActivationFunction activationFn = GetActivationFunction(fnName);
+                IActivationFunction activationFn = ActivationFunctionRegistry.GetActivationFunction(fnName);
 
                 // Add new function to our list of functions.
                 ActivationFunctionInfo fnInfo = new ActivationFunctionInfo(id, selectionProb, activationFn);
@@ -528,44 +528,6 @@ public static string GetNodeTypeString(NodeType nodeType)
             throw new ArgumentException($"Unexpected NodeType [{nodeType}]");
         }
 
-        /// <summary>
-        /// Gets an IActivationFunction from its short name.
-        /// </summary>
-        public static IActivationFunction GetActivationFunction(string name)
-        {
-            switch(name)
-            {
-                // Bipolar.
-                case "BipolarGaussian": return BipolarGaussian.__DefaultInstance;
-                case "BipolarSigmoid": return BipolarSigmoid.__DefaultInstance;
-                case "Linear": return Linear.__DefaultInstance;
-                case "Sine": return Sine.__DefaultInstance;
-
-                // Unipolar.
-                case "ArcSinH": return ArcSinH.__DefaultInstance;
-                case "ArcTan": return ArcTan.__DefaultInstance;
-                case "Gaussian": return Gaussian.__DefaultInstance;
-                case "LeakyReLU": return LeakyReLU.__DefaultInstance;
-                case "LeakyReLUShifted": return LeakyReLUShifted.__DefaultInstance;
-                case "LogisticFunction": return LogisticFunction.__DefaultInstance;
-                case "LogisticFunctionSteep": return LogisticFunctionSteep.__DefaultInstance;
-                case "MaxMinusOne": return MaxMinusOne.__DefaultInstance;
-                case "PolynomialApproximantSteep": return PolynomialApproximantSteep.__DefaultInstance;
-                case "QuadraticSigmoid": return QuadraticSigmoid.__DefaultInstance;
-                case "ReLU": return ReLU.__DefaultInstance;
-                case "ScaledELU": return ScaledELU.__DefaultInstance;
-                case "SoftSignSteep": return SoftSignSteep.__DefaultInstance;
-                case "SReLU": return SReLU.__DefaultInstance;
-                case "SReLUShifted": return SReLUShifted.__DefaultInstance;
-                case "TanH": return TanH.__DefaultInstance;
-
-                // Radial Basis.
-                case "RbfGaussian": return RbfGaussian.__DefaultInstance;
-
-            }
-            throw new ArgumentException($"Unexpected activation function [{name}]");
-        }
-
         #endregion
 
         #region Private Static Methods
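Usage note (not part of the diff): a minimal sketch of how calling code might use the new registry, assuming SharpNeatLib is referenced. The class name `RegistryUsageExample` is a hypothetical stand-in; the registry calls and `TanH.__DefaultInstance` come from the change above.

    using SharpNeat.Network;

    class RegistryUsageExample // hypothetical example class, not part of this change
    {
        static void Main()
        {
            // Resolve a built-in function by its short name, as NetworkXmlIO now
            // does when reading an activation function library from XML.
            IActivationFunction fn = ActivationFunctionRegistry.GetActivationFunction("TanH");

            // Registration is keyed on FunctionId and skips duplicates, so
            // re-registering a default instance is a harmless no-op. A genuinely
            // custom IActivationFunction would be registered the same way, once
            // at startup, before loading any XML that names it.
            ActivationFunctionRegistry.RegisterActivationFunction(TanH.__DefaultInstance);

            System.Console.WriteLine(fn.FunctionId); // prints "TanH"
        }
    }

Because the table is a static mutable dictionary, custom functions registered this way become visible to all subsequent XML reads in the process.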