diff --git a/CHANGELOG.md b/CHANGELOG.md index 91ff59ba89..07e668086a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Fixed - Fixed parsing of date histogram buckets ([#131](https://github.com/opensearch-project/opensearch-net/pull/131)) +- Allow passing both boolean and integer values to `TrackTotalHits` ([#121](https://github.com/opensearch-project/opensearch-net/pull/121)) ### Security - CVE-2019-0820: Removed transitive dependencies on `System.Text.RegularExpressions` from internal packages; **Client Not Impacted** ([#137](https://github.com/opensearch-project/opensearch-net/pull/137)) diff --git a/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHits.cs b/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHits.cs new file mode 100644 index 0000000000..5c176f7cf4 --- /dev/null +++ b/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHits.cs @@ -0,0 +1,95 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Diagnostics; +using System.Globalization; +using OpenSearch.Net; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client +{ + [JsonFormatter(typeof(TrackTotalHitsFormatter))] + [DebuggerDisplay("{DebugDisplay,nq}")] + public class TrackTotalHits : IEquatable, IUrlParameter + { + public TrackTotalHits(bool trackTotalHits) + { + Tag = 0; + BoolValue = trackTotalHits; + } + + public TrackTotalHits(long trackTotalHitsUpTo) + { + Tag = 1; + LongValue = trackTotalHitsUpTo; + } + + public bool Equals(TrackTotalHits other) + { + if (Tag != other?.Tag) return false; + + return Tag switch + { + 0 => BoolValue == other.BoolValue, + 1 => LongValue == other.LongValue, + _ => false + }; + } + + private byte Tag { get; } + + internal bool? BoolValue { get; } + + internal long? LongValue { get; } + + private string BoolOrLongValue => BoolValue?.ToString(CultureInfo.InvariantCulture) ?? LongValue?.ToString(CultureInfo.InvariantCulture); + + private string DebugDisplay => BoolOrLongValue; + + public override string ToString() => BoolOrLongValue; + + public string GetString(IConnectionConfigurationValues settings) => BoolOrLongValue; + + public static implicit operator TrackTotalHits(bool trackTotalHits) => new(trackTotalHits); + public static implicit operator TrackTotalHits(bool? trackTotalHits) => trackTotalHits is {} b ? new TrackTotalHits(b) : null; + + public static implicit operator TrackTotalHits(long trackTotalHitsUpTo) => new(trackTotalHitsUpTo); + public static implicit operator TrackTotalHits(long? trackTotalHitsUpTo) => trackTotalHitsUpTo is {} l ? new TrackTotalHits(l) : null; + + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) return false; + if (ReferenceEquals(this, obj)) return true; + + return obj switch + { + TrackTotalHits t => Equals(t), + bool b => Equals(b), + long l => Equals(l), + _ => false + }; + } + + private static int TypeHashCode { get; } = typeof(TrackTotalHits).GetHashCode(); + + public override int GetHashCode() + { + unchecked + { + var result = TypeHashCode; + result = (result * 397) ^ (BoolValue?.GetHashCode() ?? 0); + result = (result * 397) ^ (LongValue?.GetHashCode() ?? 0); + return result; + } + } + + public static bool operator ==(TrackTotalHits left, TrackTotalHits right) => Equals(left, right); + + public static bool operator !=(TrackTotalHits left, TrackTotalHits right) => !Equals(left, right); + } +} diff --git a/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHitsFormatter.cs b/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHitsFormatter.cs new file mode 100644 index 0000000000..ac27e26f2c --- /dev/null +++ b/src/OpenSearch.Client/CommonAbstractions/Infer/TrackTotalHits/TrackTotalHitsFormatter.cs @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client +{ + internal class TrackTotalHitsFormatter : IJsonFormatter + { + public void Serialize(ref JsonWriter writer, TrackTotalHits value, IJsonFormatterResolver formatterResolver) + { + if (value?.LongValue is { } l) + writer.WriteInt64(l); + else if (value?.BoolValue is { } b) + writer.WriteBoolean(b); + else + writer.WriteNull(); + } + + public TrackTotalHits Deserialize(ref JsonReader reader, IJsonFormatterResolver formatterResolver) => + reader.GetCurrentJsonToken() == JsonToken.Number + ? new TrackTotalHits(reader.ReadInt64()) + : new TrackTotalHits(reader.ReadBoolean()); + } +} diff --git a/src/OpenSearch.Client/Requests.NoNamespace.cs b/src/OpenSearch.Client/Requests.NoNamespace.cs index d861152255..517d0f5017 100644 --- a/src/OpenSearch.Client/Requests.NoNamespace.cs +++ b/src/OpenSearch.Client/Requests.NoNamespace.cs @@ -2654,7 +2654,7 @@ Fields StoredFields } [DataMember(Name = "track_total_hits")] - bool? TrackTotalHits + TrackTotalHits TrackTotalHits { get; set; diff --git a/src/OpenSearch.Client/Search/Search/SearchRequest.cs b/src/OpenSearch.Client/Search/Search/SearchRequest.cs index cdb1408c30..e5863fed8d 100644 --- a/src/OpenSearch.Client/Search/Search/SearchRequest.cs +++ b/src/OpenSearch.Client/Search/Search/SearchRequest.cs @@ -270,7 +270,7 @@ public partial class SearchRequest /// public bool? TrackScores { get; set; } /// - public bool? TrackTotalHits { get; set; } + public TrackTotalHits TrackTotalHits { get; set; } /// public bool? Version { get; set; } /// @@ -333,7 +333,7 @@ public partial class SearchDescriptor where TInferDocument : cla long? ISearchRequest.TerminateAfter { get; set; } string ISearchRequest.Timeout { get; set; } bool? ISearchRequest.TrackScores { get; set; } - bool? ISearchRequest.TrackTotalHits { get; set; } + TrackTotalHits ISearchRequest.TrackTotalHits { get; set; } bool? ISearchRequest.Version { get; set; } IRuntimeFields ISearchRequest.RuntimeFields { get; set; } @@ -506,7 +506,10 @@ public SearchDescriptor Rescore(Func a.Rescore = v?.Invoke(new RescoringDescriptor()).Value); /// - public SearchDescriptor TrackTotalHits(bool? trackTotalHits = true) => Assign(trackTotalHits, (a, v) => a.TrackTotalHits = v); + public SearchDescriptor TrackTotalHits() => Assign(true, (a, v) => a.TrackTotalHits = v); + + /// + public SearchDescriptor TrackTotalHits(TrackTotalHits trackTotalHits) => Assign(trackTotalHits, (a, v) => a.TrackTotalHits = v); /// public SearchDescriptor RuntimeFields(Func, IPromise> runtimeFieldsSelector) => diff --git a/src/OpenSearch.Net/Api/RequestParameters/RequestParameters.NoNamespace.cs b/src/OpenSearch.Net/Api/RequestParameters/RequestParameters.NoNamespace.cs index e668d7bc64..1a52c7bcae 100644 --- a/src/OpenSearch.Net/Api/RequestParameters/RequestParameters.NoNamespace.cs +++ b/src/OpenSearch.Net/Api/RequestParameters/RequestParameters.NoNamespace.cs @@ -1878,9 +1878,9 @@ public bool? TotalHitsAsInteger } ///Indicate if the number of documents that match the query should be tracked - public bool? TrackTotalHits + public string TrackTotalHits { - get => Q("track_total_hits"); + get => Q("track_total_hits"); set => Q("track_total_hits", value); }