Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,16 @@ Tensor THSTensor_gelu_(const Tensor tensor)
CATCH_TENSOR(torch::gelu_(*tensor));
}

// Out-of-place GELU with an explicit approximation mode; a null pointer
// falls back to the exact ("none") computation.
Tensor THSTensor_gelu_with_approximate(const Tensor tensor, const char* approximate)
{
    const char* mode = (approximate == nullptr) ? "none" : approximate;
    CATCH_TENSOR(torch::gelu(*tensor, mode));
}

// In-place GELU with an explicit approximation mode; a null pointer
// falls back to the exact ("none") computation.
Tensor THSTensor_gelu_with_approximate_(const Tensor tensor, const char* approximate)
{
    const char* mode = (approximate == nullptr) ? "none" : approximate;
    CATCH_TENSOR(torch::gelu_(*tensor, mode));
}

Tensor THSTensor_get1(const Tensor tensor, int64_t index)
{
CATCH_TENSOR((*tensor)[index]);
Expand Down
2 changes: 2 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right);

EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor);
// GELU with an explicit approximation mode ("none" or "tanh"); the
// implementation treats a null 'approximate' as "none". The trailing
// underscore variant is the in-place form.
EXPORT_API(Tensor) THSTensor_gelu_with_approximate(const Tensor tensor, const char* approximate);
EXPORT_API(Tensor) THSTensor_gelu_with_approximate_(const Tensor tensor, const char* approximate);

EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim);

Expand Down
28 changes: 26 additions & 2 deletions src/TorchSharp/NN/Activation/GELU.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@ namespace Modules
/// </summary>
public sealed class GELU : ParameterLessModule<Tensor, Tensor>
{
    // The approximate parameter defaults to GELUApproximate.none so the
    // pre-existing GELU(bool) factory keeps its exact-GELU behavior.
    internal GELU(bool inplace, GELUApproximate approximate = GELUApproximate.none) : base(nameof(GELU))
    {
        this.inplace = inplace;
        this.approximate = approximate;
    }

    /// <summary>
    /// Applies GELU to the input, using the configured approximation
    /// mode and the configured in-place behavior.
    /// </summary>
    public override Tensor forward(Tensor tensor)
    {
        return torch.nn.functional.gelu(tensor, approximate, inplace);
    }

    // Whether the activation is applied in-place on the input tensor.
    public bool inplace { get; set; }

    // The approximation mode forwarded to functional.gelu.
    public GELUApproximate approximate { get; set; }
}
}

Expand All @@ -49,6 +52,16 @@ public static GELU GELU(bool inplace)
return new GELU(inplace);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="approximate">The approximation method to use. Default: none</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static GELU GELU(GELUApproximate approximate, bool inplace = false)
    => new GELU(inplace, approximate);

public static partial class functional
{
/// <summary>
Expand All @@ -61,6 +74,17 @@ public static Tensor gelu(Tensor x, bool inplace)
return inplace ? x.gelu_().alias() : x.gelu();
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="x">The input tensor</param>
/// <param name="approximate">The approximation method to use.</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static Tensor gelu(Tensor x, GELUApproximate approximate, bool inplace = false)
{
    if (inplace) {
        // The alias() matches the existing gelu(x, inplace) overload's
        // handling of in-place results.
        return x.gelu_(approximate).alias();
    }
    return x.gelu(approximate);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_gelu_(IntPtr tensor);

// GELU with an explicit approximation mode; the string is marshaled as an
// ANSI char* for the native 'const char* approximate' parameter, and
// ThrowOnUnmappableChar guards against silent character substitution.
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu_with_approximate(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

// In-place variant of THSTensor_gelu_with_approximate.
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu_with_approximate_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim);

Expand Down
18 changes: 18 additions & 0 deletions src/TorchSharp/Tensor/Enums/GELUApproximate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
namespace TorchSharp
{
    /// <summary>
    /// Specifies the approximation method for the GELU activation function.
    /// The lower-case member names intentionally mirror the "none" / "tanh"
    /// strings that Tensor.gelu maps them to before calling into native code.
    /// </summary>
    public enum GELUApproximate
    {
        /// <summary>
        /// Exact GELU computation.
        /// </summary>
        none,
        /// <summary>
        /// Tanh-based approximation.
        /// </summary>
        tanh
    }
}
26 changes: 26 additions & 0 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2983,6 +2983,19 @@ public Tensor gelu()
return new Tensor(res);
}

/// <summary>
/// Gaussian Error Linear Units, using the given approximation mode.
/// </summary>
/// <param name="approximate">The approximation method to use.</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown for enum values outside none/tanh.</exception>
public Tensor gelu(GELUApproximate approximate)
{
    // Translate the enum into the string form the native layer expects.
    string mode;
    switch (approximate) {
        case GELUApproximate.none:
            mode = "none";
            break;
        case GELUApproximate.tanh:
            mode = "tanh";
            break;
        default:
            throw new ArgumentOutOfRangeException(nameof(approximate), approximate, "Unsupported GELU approximation method.");
    }
    var result = NativeMethods.THSTensor_gelu_with_approximate(Handle, mode);
    if (result == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(result);
}

public Tensor gelu_()
{
var res = NativeMethods.THSTensor_gelu_(Handle);
Expand All @@ -2991,6 +3004,19 @@ public Tensor gelu_()
return new Tensor(res);
}

/// <summary>
/// In-place Gaussian Error Linear Units, using the given approximation mode.
/// </summary>
/// <param name="approximate">The approximation method to use.</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown for enum values outside none/tanh.</exception>
public Tensor gelu_(GELUApproximate approximate)
{
    // Translate the enum into the string form the native layer expects.
    string mode;
    switch (approximate) {
        case GELUApproximate.none:
            mode = "none";
            break;
        case GELUApproximate.tanh:
            mode = "tanh";
            break;
        default:
            throw new ArgumentOutOfRangeException(nameof(approximate), approximate, "Unsupported GELU approximation method.");
    }
    var result = NativeMethods.THSTensor_gelu_with_approximate_(Handle, mode);
    if (result == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(result);
}

public Tensor glu(long dim = -1)
{
var res = NativeMethods.THSTensor_glu(Handle, dim);
Expand Down
27 changes: 27 additions & 0 deletions test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,33 @@ public void EvaluateGELU()
}
}

[Fact]
public void EvaluateGELUWithTanhApproximate()
{
    // Module form: the tanh-approximated GELU preserves device and shape,
    // and its output is bounded below (GELU's minimum is ~ -0.17).
    var module = GELU(GELUApproximate.tanh);

    foreach (var device in TestUtils.AvailableDevices()) {
        var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0;
        var output = module.call(input);
        Assert.Equal(device.type, output.device_type);

        var values = output.data<float>().ToArray();
        Assert.Equal(input.shape, output.shape);
        Assert.All(values, val => Assert.True(val >= -0.2));
    }

    // Functional form: the tanh approximation must differ from exact GELU...
    var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
    var exact = torch.nn.functional.gelu(x);
    var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh);
    Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));

    // ...while the in-place overload must agree with the out-of-place one.
    var mutated = x.clone();
    mutated.gelu_(GELUApproximate.tanh);
    Assert.True(approx.allclose(mutated, rtol: 1e-5, atol: 1e-5));
}

[Fact]
public void EvaluatePReLU()
{
Expand Down