Added the new Mish activation module and huber_loss()
NiklasGustafsson committed Jun 22, 2021
1 parent 4d79dc1 commit 2b1068e
Showing 6 changed files with 130 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/Native/LibTorchSharp/THSActivation.cpp
@@ -67,6 +67,18 @@ Tensor THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor)
    CATCH_TENSOR((*module)->as<torch::nn::LogSoftmax>()->forward(*tensor));
}

NNModule THSNN_Mish_ctor(NNAnyModule* outAsAnyModule)
{
    CATCH_RETURN_NNModule(
        res = create_module<torch::nn::MishImpl>(outAsAnyModule);
    );
}

Tensor THSNN_Mish_forward(const NNModule module, const Tensor tensor)
{
    CATCH_TENSOR((*module)->as<torch::nn::Mish>()->forward(*tensor));
}

NNModule THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule)
{
    CATCH_RETURN_NNModule(
10 changes: 10 additions & 0 deletions src/Native/LibTorchSharp/THSLoss.cpp
@@ -70,6 +70,16 @@ Tensor THSNN_hinge_embedding_loss(const Tensor input, const Tensor target, const
    )
}

Tensor THSNN_huber_loss(const Tensor input, const Tensor target, const double delta, const int64_t reduction)
{
    CATCH_RETURN_Tensor(
        auto opts = torch::nn::functional::HuberLossFuncOptions().delta(delta);
        ApplyReduction(opts, reduction);
        res = ResultTensor(torch::nn::functional::huber_loss(*input, *target, opts));
    )
}

Tensor THSNN_kl_div_loss(const Tensor input, const Tensor target, const int64_t reduction, const bool log_target)
{
    CATCH_RETURN_Tensor(
3 changes: 3 additions & 0 deletions src/Native/LibTorchSharp/THSNN.h
@@ -213,6 +213,8 @@ EXPORT_API(NNModule) THSNN_GELU_ctor(NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_GELU_forward(const NNModule module, const Tensor tensor);
EXPORT_API(NNModule) THSNN_LeakyReLU_ctor(const double negative_slope, const bool inplace, NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_LeakyReLU_forward(const NNModule module, const Tensor tensor);
EXPORT_API(NNModule) THSNN_Mish_ctor(NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_Mish_forward(const NNModule module, const Tensor tensor);
EXPORT_API(NNModule) THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_ReLU_forward(const NNModule module, const Tensor tensor);
EXPORT_API(NNModule) THSNN_ReLU6_ctor(bool inplace, NNAnyModule* outAsAnyModule);
@@ -294,6 +296,7 @@ EXPORT_API(Tensor) THSNN_cosine_embedding_loss(const Tensor input1, const Tensor
EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction);
EXPORT_API(Tensor) THSNN_ctc_loss(const Tensor log_probs, const Tensor targets, const Tensor input_lengths, const Tensor target_lengths, int64_t blank, bool zero_infinity, const int64_t reduction);
EXPORT_API(Tensor) THSNN_hinge_embedding_loss(const Tensor input, const Tensor target, const double margin, const int64_t reduction);
EXPORT_API(Tensor) THSNN_huber_loss(const Tensor input, const Tensor target, const double delta, const int64_t reduction);
EXPORT_API(Tensor) THSNN_l1_loss(const Tensor input, const Tensor target, const int64_t reduction);
EXPORT_API(Tensor) THSNN_margin_ranking_loss(const Tensor input1, const Tensor input2, const Tensor target, const double margin, const int64_t reduction);
EXPORT_API(Tensor) THSNN_mse_loss(const Tensor input, const Tensor target, const int64_t reduction);
62 changes: 62 additions & 0 deletions src/TorchSharp/NN/Activation/Mish.cs
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
using System;
using System.Runtime.InteropServices;
using TorchSharp.Tensor;

namespace TorchSharp.NN
{
    /// <summary>
    /// This class is used to represent a Mish module.
    /// </summary>
    public class Mish : Module
    {
        internal Mish (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) { }

        [DllImport ("LibTorchSharp")]
        private static extern IntPtr THSNN_Mish_forward (Module.HType module, IntPtr tensor);

        public override TorchTensor forward (TorchTensor tensor)
        {
            var res = THSNN_Mish_forward (handle, tensor.Handle);
            if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
            return new TorchTensor (res);
        }

        public override string GetName ()
        {
            return typeof (Mish).Name;
        }
    }

    public static partial class Modules
    {
        [DllImport ("LibTorchSharp")]
        extern static IntPtr THSNN_Mish_ctor(out IntPtr pBoxedModule);

        /// <summary>
        /// A Self Regularized Non-Monotonic Neural Activation Function.
        /// </summary>
        /// <returns></returns>
        static public Mish Mish ()
        {
            var handle = THSNN_Mish_ctor (out var boxedHandle);
            if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
            return new Mish (handle, boxedHandle);
        }
    }

    public static partial class Functions
    {
        /// <summary>
        /// A Self Regularized Non-Monotonic Neural Activation Function.
        /// </summary>
        /// <param name="x">The input tensor</param>
        /// <returns></returns>
        static public TorchTensor Mish (TorchTensor x)
        {
            using (var m = Modules.Mish()) {
                return m.forward (x);
            }
        }
    }
}
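
For reference, Mish (Misra, 2019) is defined as mish(x) = x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x). The sketch below is plain C# and not part of this commit; it only illustrates the scalar formula that the native torch::nn::Mish module applies elementwise:

using System;

static class MishReference
{
    // mish(x) = x * tanh(softplus(x)),  softplus(x) = ln(1 + e^x)
    public static double Mish (double x) => x * Math.Tanh (Math.Log (1.0 + Math.Exp (x)));
}

For large positive x, softplus(x) grows like x and tanh saturates at 1, so Mish approaches the identity; for large negative x it decays to zero. That smooth, non-monotonic shape is what the "Self Regularized Non-Monotonic" summary above refers to.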
21 changes: 21 additions & 0 deletions src/TorchSharp/NN/Losses.cs
@@ -139,6 +139,27 @@ public static Loss hinge_embedding_loss(double margin = 0.0, Reduction reduction
            };
        }

        [DllImport("LibTorchSharp")]
        private static extern IntPtr THSNN_huber_loss(IntPtr input, IntPtr trgt, double delta, long reduction);

        /// <summary>
        /// Creates a criterion that uses a squared term if the absolute element-wise error falls below delta and a delta-scaled L1 term otherwise.
        ///
        /// See: https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html#torch.nn.HuberLoss
        /// </summary>
        /// <param name="delta">Specifies the threshold at which to change between delta-scaled L1 and L2 loss. The value must be positive. Default: 1.0</param>
        /// <param name="reduction">Specifies the reduction to apply to the output</param>
        /// <returns></returns>
        public static Loss huber_loss(double delta = 1.0, Reduction reduction = Reduction.Mean)
        {
            return (TorchTensor input, TorchTensor target) => {
                var res = THSNN_huber_loss(input.Handle, target.Handle, delta, (long)reduction);
                if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
                return new TorchTensor(res);
            };
        }

        [DllImport("LibTorchSharp")]
        private static extern IntPtr THSNN_margin_ranking_loss(IntPtr input1, IntPtr input2, IntPtr target, double margin, long reduction);

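For reference, the criterion returned by huber_loss computes a per-element loss that is quadratic for small residuals and linear beyond delta. A minimal scalar sketch in plain C# (not part of this commit) of the formula before the reduction is applied:

using System;

static class HuberReference
{
    // Per-element Huber loss, before the mean/sum reduction:
    //   0.5 * (x - y)^2                   if |x - y| <= delta
    //   delta * (|x - y| - 0.5 * delta)   otherwise
    public static double Huber (double input, double target, double delta = 1.0)
    {
        double diff = Math.Abs (input - target);
        return diff <= delta ? 0.5 * diff * diff
                             : delta * (diff - 0.5 * delta);
    }
}

The two branches agree in both value and slope at |x - y| = delta, so the loss stays continuously differentiable while growing only linearly on outliers, which is what makes it less sensitive to outliers than plain MSE.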
22 changes: 22 additions & 0 deletions test/TorchSharpTest/NN.cs
@@ -250,6 +250,16 @@ public void EvaluateLeakyRelu()
            Assert.Equal(input.shape, output.shape);
        }

        [Fact]
        public void EvaluateMish()
        {
            var rel = Mish();
            var input = Float32Tensor.randn(new long[] { 64, 8 });
            var output = rel.forward(input);
            var values = output.Data<float>().ToArray();
            Assert.Equal(input.shape, output.shape);
        }

        [Fact]
        public void EvaluateRRelu()
        {
@@ -633,6 +643,18 @@ public void TestHingeEmbeddingLoss()
            }
        }

        [Fact]
        public void TestHuberLoss()
        {
            using (TorchTensor input = Float32Tensor.randn(new long[] { 15, 5 }, requiresGrad: true))
            using (TorchTensor target = Float32Tensor.randn(new long[] { 15, 5 }).sign()) {
                var outTensor = huber_loss()(input, target);
                outTensor.backward();
                outTensor = huber_loss(1.5)(input, target);
                outTensor.backward();
            }
        }

        [Fact]
        public void TestMarginRankingLoss()
        {
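Taken together, the two additions compose naturally. A hypothetical end-to-end sketch, built only from APIs that appear in this commit and its tests, feeding a Mish activation into the new Huber criterion and backpropagating:

var mish = Mish();                                   // Modules.Mish()
var input = Float32Tensor.randn(new long[] { 15, 5 }, requiresGrad: true);
var target = Float32Tensor.randn(new long[] { 15, 5 });

var loss = huber_loss(delta: 0.75)(mish.forward(input), target);
loss.backward();                                     // gradients flow through both Mish and the loss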
