Skip to content

Commit

Permalink
enhance string util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Aug 28, 2024
1 parent d2a1b7a commit 3fa5115
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 42 deletions.
22 changes: 12 additions & 10 deletions include/onnxruntime/core/common/make_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,29 @@
#include <sstream>
#include <type_traits>

#include "core/util/force_inline.h"

namespace onnxruntime {

namespace detail {

inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
}

template <typename T>
inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept {
ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept {
ss << t;
}

template <typename T, typename... Args>
inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
MakeStringImpl(ss, t);
MakeStringImpl(ss, args...);
}

// see MakeString comments for explanation of why this is necessary
template <typename... Args>
inline std::string MakeStringImpl(const Args&... args) noexcept {
ORT_FORCEINLINE std::string MakeStringImpl(const Args&... args) noexcept {
std::ostringstream ss;
MakeStringImpl(ss, args...);
return ss.str();
Expand Down Expand Up @@ -78,7 +80,7 @@ using if_char_array_make_ptr_t = typename if_char_array_make_ptr<T>::type;
* This version uses the current locale.
*/
template <typename... Args>
std::string MakeString(const Args&... args) {
ORT_FORCEINLINE std::string MakeString(const Args&... args) {
// We need to update the types from the MakeString template instantiation to decay any char[n] to char*.
// e.g. MakeString("in", "out") goes from MakeString<char[2], char[3]> to MakeStringImpl<char*, char*>
// so that MakeString("out", "in") will also match MakeStringImpl<char*, char*> instead of requiring
Expand All @@ -98,7 +100,7 @@ std::string MakeString(const Args&... args) {
* This version uses std::locale::classic().
*/
template <typename... Args>
std::string MakeStringWithClassicLocale(const Args&... args) {
ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const Args&... args) {
std::ostringstream ss;
ss.imbue(std::locale::classic());
detail::MakeStringImpl(ss, args...);
Expand All @@ -107,19 +109,19 @@ std::string MakeStringWithClassicLocale(const Args&... args) {

// MakeString versions for already-a-string types.

inline std::string MakeString(const std::string& str) {
ORT_FORCEINLINE std::string MakeString(const std::string& str) {
return str;
}

inline std::string MakeString(const char* cstr) {
ORT_FORCEINLINE std::string MakeString(const char* cstr) {
return cstr;
}

inline std::string MakeStringWithClassicLocale(const std::string& str) {
ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const std::string& str) {
return str;
}

inline std::string MakeStringWithClassicLocale(const char* cstr) {
ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const char* cstr) {

Check warning on line 124 in include/onnxruntime/core/common/make_string.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/common/make_string.h:124: Add #include <string> for string [build/include_what_you_use] [4]
return cstr;
}

Expand Down
71 changes: 71 additions & 0 deletions include/onnxruntime/core/common/string_join.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <locale>
#include <sstream>
#include <type_traits>

#include "core/common/make_string.h"

namespace onnxruntime {

namespace detail {

template <typename Separator>
ORT_FORCEINLINE void StringJoinImpl(Separator&& separator, std::ostringstream& ss) noexcept {
}

template <typename Separator, typename T>
ORT_FORCEINLINE void StringJoinImpl(Separator&& separator, std::ostringstream& ss, T&& t) noexcept {
ss << std::forward<Separator>(separator) << std::forward<T>(t);
}

template <typename Separator, typename T, typename... Args>
ORT_FORCEINLINE void StringJoinImpl(Separator&& separator, std::ostringstream& ss, T&& t, Args&&... args) noexcept {
StringJoinImpl(std::forward<Separator>(separator), ss, std::forward<T>(t));
StringJoinImpl(std::forward<Separator>(separator), ss, std::forward<Args>(args)...);
}

template <typename Separator, typename... Args>
ORT_FORCEINLINE std::string StringJoinImpl(Separator&& separator, Args&&... args) noexcept {
std::ostringstream ss;
StringJoinImpl(std::forward<Separator>(separator), ss, std::forward<Args>(args)...);
return ss.str();
}

/**
* Makes a string by concatenating string representations of the arguments using the specified separator.
* Uses std::locale::classic()
*/
template <typename Separator, typename... Args>
ORT_FORCEINLINE std::string StringJoinImplWithClassicLocale(Separator&& separator, Args&&... args) noexcept {
std::ostringstream ss;
ss.imbue(std::locale::classic());
StringJoinImpl(std::forward<Separator>(separator), ss, std::forward<Args>(args)...);
return ss.str();
}
} // namespace detail

/**
* Makes a string by concatenating string representations of the arguments using the specified separator.
*/
template <typename Separator, typename... Args>
ORT_FORCEINLINE std::string StringJoin(Separator&& separator, Args&&... args) {
return detail::StringJoinImpl(separator, std::forward<Args>(args)...);

Check warning on line 56 in include/onnxruntime/core/common/string_join.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for forward [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/common/string_join.h:56: Add #include <utility> for forward [build/include_what_you_use] [4]
}

// StringJoin versions for already-a-string types.

template <typename Separator>
ORT_FORCEINLINE std::string StringJoin(Separator&& /* separator */, const std::string& str) {
return str;
}

template <typename Separator>
ORT_FORCEINLINE std::string StringJoin(Separator&& /* separator */, const char* cstr) {

Check warning on line 67 in include/onnxruntime/core/common/string_join.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/common/string_join.h:67: Add #include <string> for string [build/include_what_you_use] [4]
return cstr;
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@
#include <algorithm>
#include <cmath>

#include "core/util/force_inline.h"

namespace onnxruntime {
namespace contrib {

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

typedef enum Bnb_DataType_t {
FP4 = 0,
NF4 = 1,
} Bnb_DataType_t;

FORCEINLINE uint8_t QuantizeOneFP4(float x) {
ORT_FORCEINLINE uint8_t QuantizeOneFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
// subnormals
Expand Down Expand Up @@ -69,7 +65,7 @@ FORCEINLINE uint8_t QuantizeOneFP4(float x) {
}
}

FORCEINLINE uint8_t QuantizeOneNF4(float x) {
ORT_FORCEINLINE uint8_t QuantizeOneNF4(float x) {
if (x > 0.03979014977812767f) {
if (x > 0.3893125355243683f) { // 1
if (x > 0.6427869200706482f) { // 11
Expand Down Expand Up @@ -120,15 +116,15 @@ FORCEINLINE uint8_t QuantizeOneNF4(float x) {
}

template <int32_t DATA_TYPE>
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
ORT_FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
if constexpr (DATA_TYPE == FP4)
return QuantizeOneFP4(x);
else
return QuantizeOneNF4(x);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
ORT_FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
float local_absmax = 0.0f;

int32_t block_len = std::min(block_size, numel - block_idx * block_size);
Expand Down Expand Up @@ -177,15 +173,15 @@ static float nf4_qaunt_map[16] = {-1.0f,
1.0f};

template <typename T, int32_t DATA_TYPE>
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
ORT_FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
if constexpr (DATA_TYPE == FP4)
return static_cast<T>(fp4_qaunt_map[x]);
else
return static_cast<T>(nf4_qaunt_map[x]);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
ORT_FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size / 2;
int32_t dst_offset = block_idx * block_size;
Expand Down
14 changes: 6 additions & 8 deletions onnxruntime/core/framework/murmurhash3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

#include "core/framework/endian.h"

#include "core/util/force_inline.h"

//-----------------------------------------------------------------------------
// Platform-specific functions and macros

// Microsoft Visual Studio

#if defined(_MSC_VER)

#define FORCE_INLINE __forceinline

#include <stdlib.h>

#define ROTL32(x, y) _rotl(x, y)
Expand All @@ -37,8 +37,6 @@

#else // defined(_MSC_VER)

#define FORCE_INLINE inline __attribute__((always_inline))

inline uint32_t rotl32(uint32_t x, int8_t r) {
return (x << r) | (x >> (32 - r));
}
Expand All @@ -61,7 +59,7 @@ inline uint64_t rotl64(uint64_t x, int8_t r) {
//
// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/
// were manually applied to original murmurhash3 source code.
FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) {
ORT_FORCEINLINE uint32_t getblock32(const uint32_t* p, int i) {
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
Expand All @@ -73,7 +71,7 @@ FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) {
}
}

FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) {
ORT_FORCEINLINE uint64_t getblock64(const uint64_t* p, int i) {
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
Expand All @@ -92,7 +90,7 @@ FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) {
//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche

FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) {
ORT_FORCEINLINE constexpr uint32_t fmix32(uint32_t h) {
h ^= h >> 16;
h *= 0x85ebca6b;
h ^= h >> 13;
Expand All @@ -104,7 +102,7 @@ FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) {

//----------

FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) {
ORT_FORCEINLINE constexpr uint64_t fmix64(uint64_t k) {
k ^= k >> 33;
k *= BIG_CONSTANT(0xff51afd7ed558ccd);
k ^= k >> 33;
Expand Down
10 changes: 3 additions & 7 deletions onnxruntime/core/providers/cpu/tensor/gather_elements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "gather_elements.h"
#include "onnxruntime_config.h"

#include "core/util/force_inline.h"

namespace onnxruntime {

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Expand Down Expand Up @@ -66,14 +68,8 @@ static inline size_t CalculateOffset(size_t inner_dim, const TensorPitches& inpu
return base_offset;
}

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

template <typename T>
FORCEINLINE int64_t GetIndex(size_t i, const T* indices, int64_t axis_size) {
ORT_FORCEINLINE int64_t GetIndex(size_t i, const T* indices, int64_t axis_size) {
int64_t index = indices[i];
if (index < 0) // Handle negative indices
index += axis_size;
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/util/force_inline.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#if defined(_MSC_VER)
#define ORT_FORCEINLINE __forceinline
#else
#define ORT_FORCEINLINE __attribute__((always_inline)) inline
#endif
6 changes: 1 addition & 5 deletions onnxruntime/core/util/matrix_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
#include <cstdint>
#include <gsl/gsl>

#if defined(_MSC_VER)
#define ORT_FORCEINLINE __forceinline
#else
#define ORT_FORCEINLINE __attribute__((always_inline)) inline
#endif
#include "core/util/force_inline.h"

namespace onnxruntime {

Expand Down

0 comments on commit 3fa5115

Please sign in to comment.