Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
colgreen committed Jul 23, 2019
2 parents 4d1c5b6 + c9e5f8b commit 0b81901
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* ***************************************************************************
* This file is part of SharpNEAT - Evolution of Neural Networks.
*
* Copyright 2004-2016 Colin Green (sharpneat@gmail.com)
*
* Written by Scott DeBoer, 2019 (naigonakoii@gmail.com)
*
* SharpNEAT is free software; you can redistribute it and/or modify
* it under the terms of The MIT License (MIT).
*
* You should have received a copy of the MIT License
* along with SharpNEAT; if not, see https://opensource.org/licenses/MIT.
*/
using System;
using System.Collections.Generic;

namespace SharpNeat.Network
{
public static class ActivationFunctionRegistry
{
    // Lookup table of activation functions, keyed by function short name / FunctionId.
    // Pre-populated with the default set; additional functions can be added via
    // RegisterActivationFunction(). The reference is fixed, hence readonly; only the
    // table's contents change.
    // NOTE(review): the table is mutated without synchronization — assumes registration
    // occurs only during single-threaded startup; confirm against callers.
    private static readonly Dictionary<string, IActivationFunction> __activationFunctionTable = new Dictionary<string, IActivationFunction>()
    {
        // Bipolar.
        { "BipolarSigmoid", BipolarSigmoid.__DefaultInstance },
        { "BipolarGaussian", BipolarGaussian.__DefaultInstance },
        { "Linear", Linear.__DefaultInstance },
        { "Sine", Sine.__DefaultInstance },

        // Unipolar.
        { "ArcSinH", ArcSinH.__DefaultInstance },
        { "ArcTan", ArcTan.__DefaultInstance },
        { "Gaussian", Gaussian.__DefaultInstance },
        { "LeakyReLU", LeakyReLU.__DefaultInstance },
        { "LeakyReLUShifted", LeakyReLUShifted.__DefaultInstance },
        { "LogisticFunction", LogisticFunction.__DefaultInstance },
        { "LogisticFunctionSteep", LogisticFunctionSteep.__DefaultInstance },
        { "MaxMinusOne", MaxMinusOne.__DefaultInstance },
        { "PolynomialApproximantSteep", PolynomialApproximantSteep.__DefaultInstance },
        { "QuadraticSigmoid", QuadraticSigmoid.__DefaultInstance },
        { "ReLU", ReLU.__DefaultInstance },
        { "ScaledELU", ScaledELU.__DefaultInstance },
        { "SoftSignSteep", SoftSignSteep.__DefaultInstance },
        { "SReLU", SReLU.__DefaultInstance },
        { "SReLUShifted", SReLUShifted.__DefaultInstance },
        { "TanH", TanH.__DefaultInstance },

        // Radial Basis.
        { "RbfGaussian", RbfGaussian.__DefaultInstance },
    };

    #region Public Static Methods

    /// <summary>
    /// Registers a custom activation function in addition to those in the default activation function table.
    /// Allows loading of neural nets from XML that use custom activation functions.
    /// </summary>
    /// <param name="function">The activation function to register; keyed by its FunctionId.</param>
    public static void RegisterActivationFunction(IActivationFunction function)
    {
        // Idempotent: if a function with the same FunctionId is already registered,
        // the existing registration is kept and the new instance is ignored.
        if(!__activationFunctionTable.ContainsKey(function.FunctionId)) {
            __activationFunctionTable.Add(function.FunctionId, function);
        }
    }

    /// <summary>
    /// Gets an IActivationFunction with the given short name.
    /// </summary>
    /// <param name="name">The short name (FunctionId) of the activation function to look up.</param>
    /// <returns>The registered activation function instance.</returns>
    /// <exception cref="ArgumentException">Thrown if no activation function with the given name is registered.</exception>
    public static IActivationFunction GetActivationFunction(string name)
    {
        // Single dictionary lookup via TryGetValue, rather than the double lookup
        // incurred by ContainsKey followed by the indexer.
        IActivationFunction fn;
        if(!__activationFunctionTable.TryGetValue(name, out fn)) {
            throw new ArgumentException($"Unexpected activation function [{name}]");
        }
        return fn;
    }

    #endregion
}
}
40 changes: 1 addition & 39 deletions src/SharpNeatLib/Network/NetworkXmlIO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ public static IActivationFunctionLibrary ReadActivationFunctionLibrary(XmlReader
string fnName = xrSubtree.GetAttribute(__AttrName);

// Lookup function name.
IActivationFunction activationFn = GetActivationFunction(fnName);
IActivationFunction activationFn = ActivationFunctionRegistry.GetActivationFunction(fnName);

// Add new function to our list of functions.
ActivationFunctionInfo fnInfo = new ActivationFunctionInfo(id, selectionProb, activationFn);
Expand Down Expand Up @@ -528,44 +528,6 @@ public static string GetNodeTypeString(NodeType nodeType)
throw new ArgumentException($"Unexpected NodeType [{nodeType}]");
}

/// <summary>
/// Gets an IActivationFunction from its short name.
/// </summary>
/// <param name="name">The short name of the activation function to resolve.</param>
/// <returns>The default singleton instance for the named function.</returns>
/// <exception cref="ArgumentException">Thrown if the name does not match any known activation function.</exception>
public static IActivationFunction GetActivationFunction(string name)
{
    // Guard-style dispatch: return as soon as the name matches. String equality
    // here is ordinal, matching switch-on-string semantics; a null name matches
    // nothing and falls through to the throw below.

    // Bipolar.
    if(name == "BipolarGaussian") { return BipolarGaussian.__DefaultInstance; }
    if(name == "BipolarSigmoid") { return BipolarSigmoid.__DefaultInstance; }
    if(name == "Linear") { return Linear.__DefaultInstance; }
    if(name == "Sine") { return Sine.__DefaultInstance; }

    // Unipolar.
    if(name == "ArcSinH") { return ArcSinH.__DefaultInstance; }
    if(name == "ArcTan") { return ArcTan.__DefaultInstance; }
    if(name == "Gaussian") { return Gaussian.__DefaultInstance; }
    if(name == "LeakyReLU") { return LeakyReLU.__DefaultInstance; }
    if(name == "LeakyReLUShifted") { return LeakyReLUShifted.__DefaultInstance; }
    if(name == "LogisticFunction") { return LogisticFunction.__DefaultInstance; }
    if(name == "LogisticFunctionSteep") { return LogisticFunctionSteep.__DefaultInstance; }
    if(name == "MaxMinusOne") { return MaxMinusOne.__DefaultInstance; }
    if(name == "PolynomialApproximantSteep") { return PolynomialApproximantSteep.__DefaultInstance; }
    if(name == "QuadraticSigmoid") { return QuadraticSigmoid.__DefaultInstance; }
    if(name == "ReLU") { return ReLU.__DefaultInstance; }
    if(name == "ScaledELU") { return ScaledELU.__DefaultInstance; }
    if(name == "SoftSignSteep") { return SoftSignSteep.__DefaultInstance; }
    if(name == "SReLU") { return SReLU.__DefaultInstance; }
    if(name == "SReLUShifted") { return SReLUShifted.__DefaultInstance; }
    if(name == "TanH") { return TanH.__DefaultInstance; }

    // Radial Basis.
    if(name == "RbfGaussian") { return RbfGaussian.__DefaultInstance; }

    throw new ArgumentException($"Unexpected activation function [{name}]");
}

#endregion

#region Private Static Methods
Expand Down

0 comments on commit 0b81901

Please sign in to comment.