Skip to content

Commit

Permalink
update Managed NN code
Browse files Browse the repository at this point in the history
  • Loading branch information
daelsepara committed Oct 13, 2018
1 parent 44080f6 commit 1d85cdd
Showing 1 changed file with 55 additions and 8 deletions.
63 changes: 55 additions & 8 deletions DeepLearnMac/ManagedNN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,42 @@ ManagedArray Labels(ManagedArray output, NeuralNetworkOptions opts)
return result;
}

public ManagedArray Predict(ManagedArray test, NeuralNetworkOptions opts)
{
Forward(test);

var prediction = new ManagedArray(test.y);

for (int y = 0; y < test.y; y++)
{
if (opts.Categories > 1)
{
double maxval = Double.MinValue;

for (int x = 0; x < opts.Categories; x++)
{
double val = Yk[x, y];

if (val > maxval)
{
maxval = val;
}
}

prediction[y] = maxval;
}
else
{
prediction[y] = Yk[y];
}
}

// cleanup of arrays allocated in Forward
ManagedOps.Free(A2, Yk, Z2);

return prediction;
}

public ManagedIntList Classify(ManagedArray test, NeuralNetworkOptions opts, double threshold = 0.5)
{
Forward(test);
Expand Down Expand Up @@ -217,17 +253,28 @@ public ManagedIntList Classify(ManagedArray test, NeuralNetworkOptions opts, dou
return classification;
}

public void Setup(ManagedArray output, NeuralNetworkOptions opts)
public void SetupLabels(ManagedArray output, NeuralNetworkOptions opts)
{
Wji = new ManagedArray(opts.Inputs + 1, opts.Nodes);
Wkj = new ManagedArray(opts.Nodes + 1, opts.Categories);

Y_output = Labels(output, opts);
}

public void Setup(ManagedArray output, NeuralNetworkOptions opts, bool Reset = true)
{
if (Reset)
{
Wji = new ManagedArray(opts.Inputs + 1, opts.Nodes);
Wkj = new ManagedArray(opts.Nodes + 1, opts.Categories);
}

SetupLabels(output, opts);

var random = new Random(Guid.NewGuid().GetHashCode());

Rand(Wji, random);
Rand(Wkj, random);
if (Reset)
{
Rand(Wji, random);
Rand(Wkj, random);
}

Cost = 1.0;
L2 = 1.0;
Expand Down Expand Up @@ -337,9 +384,9 @@ public FuncOutput OptimizerCost(double[] X)
return new FuncOutput(Cost, X);
}

public void SetupOptimizer(ManagedArray input, ManagedArray output, NeuralNetworkOptions opts)
public void SetupOptimizer(ManagedArray input, ManagedArray output, NeuralNetworkOptions opts, bool Reset = true)
{
Setup(output, opts);
Setup(output, opts, Reset);

Optimizer.MaxIterations = opts.Epochs;

Expand Down

0 comments on commit 1d85cdd

Please sign in to comment.