split avx from avx2 #252

Merged
merged 1 commit into from
Dec 26, 2024

Conversation

Collaborator

@LHT129 LHT129 commented Dec 24, 2024

  • refactor all simd functions, just like FP32
  • add openblas for ubuntu-aarch deps

@LHT129 LHT129 added the kind/improvement Code improvements (variable/function renaming, refactoring, etc. ) label Dec 24, 2024
@LHT129 LHT129 self-assigned this Dec 24, 2024
@LHT129 LHT129 force-pushed the avx branch 5 times, most recently from 59a72c4 to 17d89c8 Compare December 24, 2024 15:53
@@ -13,18 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "simd/simd.h"
Collaborator

Functional test cases should not depend on any internal definitions or interfaces, so don't include header files from src/.

Collaborator Author

Some basic functions, such as l2sqr and innerproduct, are shared by the functional tests and the internal code. Wouldn't writing them twice be redundant? The correctness of the simd module is already covered by the unit tests.

src/simd/sq8_simd.h (four outdated, resolved threads)
sum = _mm512_add_ps(sum, _mm512_mul_ps(a, b)); // accumulate the product
__m512 a = _mm512_loadu_ps(query + i * 16); // load 16 floats from memory
__m512 b = _mm512_loadu_ps(codes + i * 16); // load 16 floats from memory
sum = _mm512_fmadd_ps(a, b, sum); // accumulate the product
Collaborator

nice opt
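
For context on the change praised above: the diff swaps a separate multiply and add for a fused multiply-add. What follows is a minimal sketch of that pattern, not the code from the PR. `InnerProductSketch` is a hypothetical name, the loop uses an element offset rather than the chunk index `i * 16` seen in the diff, and the AVX-512 path is compiled only when the compiler defines `__AVX512F__`; otherwise the scalar `std::fma` loop does all the work.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

#if defined(__AVX512F__)
#include <immintrin.h>
#endif

// Hypothetical helper (not from the PR): inner product of two float arrays.
inline float
InnerProductSketch(const float* query, const float* codes, std::size_t dim) {
    assert(query != nullptr && codes != nullptr);
    std::size_t i = 0;
    float result = 0.0F;
#if defined(__AVX512F__)
    __m512 sum = _mm512_setzero_ps();
    for (; i + 16 <= dim; i += 16) {
        __m512 a = _mm512_loadu_ps(query + i);  // load 16 floats from memory
        __m512 b = _mm512_loadu_ps(codes + i);  // load 16 floats from memory
        sum = _mm512_fmadd_ps(a, b, sum);       // sum = a * b + sum, rounded once
    }
    result = _mm512_reduce_add_ps(sum);  // horizontal sum of the 16 lanes
#endif
    // Scalar path: handles the tail, or everything when AVX-512 is unavailable.
    for (; i < dim; ++i) {
        result = std::fma(query[i], codes[i], result);
    }
    return result;
}
```

A fused multiply-add rounds once per step instead of twice, so its result can differ in the last bits from the multiply-then-add version when the inputs are not exactly representable.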

@LHT129 LHT129 force-pushed the avx branch 3 times, most recently from 6207ce7 to c80c83b Compare December 25, 2024 03:00
# FIXME(LHT): cause illegal instruction on platform which has avx only
#if (DIST_CONTAINS_AVX2)
# set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
#endif ()
Collaborator Author

what?

@@ -56,7 +57,7 @@ endmacro ()

simd_add_definitions (DIST_CONTAINS_SSE -DENABLE_SSE=1)
simd_add_definitions (DIST_CONTAINS_AVX -DENABLE_AVX=1)
#simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1)
Collaborator Author

ditto

result += avx::SQ8ComputeL2Sqr(query + i, codes + i, lower_bound + i, diff + i, dim - i);
return result;
#else
return vsag::avx::SQ8ComputeL2Sqr(query, codes, lower_bound, diff, dim); // TODO
Collaborator

avx::SQ8ComputeL2Sqr

Collaborator Author

done
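
The hunk above shows the wider kernel handing the leftover `dim - i` elements to `avx::SQ8ComputeL2Sqr`. A minimal sketch of that delegation pattern follows, under stated assumptions: the namespaces `sketch::wide` and `sketch::narrow` are hypothetical, scalar loops stand in for the vector instructions, and the decoding rule `code / 255 * diff + lower_bound` is assumed rather than confirmed by this excerpt.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace sketch {

// Assumed SQ8 decoding rule: map a byte code back to a float value.
inline float
Decode(std::uint8_t code, float lower_bound, float diff) {
    return static_cast<float>(code) / 255.0F * diff + lower_bound;
}

namespace narrow {
// Stand-in for the narrower implementation; handles any dim, including 0.
inline float
SQ8ComputeL2Sqr(const float* query, const std::uint8_t* codes,
                const float* lower_bound, const float* diff, std::size_t dim) {
    float result = 0.0F;
    for (std::size_t i = 0; i < dim; ++i) {
        float d = query[i] - Decode(codes[i], lower_bound[i], diff[i]);
        result += d * d;
    }
    return result;
}
}  // namespace narrow

namespace wide {
constexpr std::size_t kLanes = 16;  // elements per full-width step

// Stand-in for the wider implementation: full blocks here, tail delegated.
inline float
SQ8ComputeL2Sqr(const float* query, const std::uint8_t* codes,
                const float* lower_bound, const float* diff, std::size_t dim) {
    assert(query != nullptr && codes != nullptr);
    float result = 0.0F;
    std::size_t i = 0;
    for (; i + kLanes <= dim; i += kLanes) {
        // A real kernel would handle these kLanes elements with vector ops.
        for (std::size_t j = i; j < i + kLanes; ++j) {
            float d = query[j] - Decode(codes[j], lower_bound[j], diff[j]);
            result += d * d;
        }
    }
    // Fewer than kLanes elements remain: hand them to the narrower version.
    result += narrow::SQ8ComputeL2Sqr(query + i, codes + i, lower_bound + i,
                                      diff + i, dim - i);
    return result;
}
}  // namespace wide

}  // namespace sketch
```

Delegating the tail keeps the remainder handling in one place per width, so each wider kernel only needs its own full-block loop.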

return result;
#else
-return generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
+return generic::SQ8ComputeCodesIP(codes1, codes2, lower_bound, diff, dim);
Collaborator

Why do we use generic here?

Collaborator Author

done

return result;
#else
-return generic::SQ8ComputeCodesL2Sqr(codes1, codes2, lowerBound, diff, dim);
+return generic::SQ8ComputeCodesL2Sqr(codes1, codes2, lower_bound, diff, dim);
Collaborator

ditto

Collaborator Author

done

-} // namespace avx512
-
-} // namespace vsag
+} // namespace vsag::avx512
Collaborator

Add a newline?

Collaborator Author

lint warning here

}
ret.dist_support_avx512f = true;
ret.dist_support_avx512dq = true;
ret.dist_support_avx512bw = true;
ret.dist_support_avx512vl = true;
#endif

return ret;
}

DistanceFunc
GetInnerProductDistanceFunc(size_t dim) {
Collaborator

Can we remove these functions and just use vsag::InnerProductDistance instead?

Collaborator Author

That would be too many changes. This PR only splits the functions; it does not change how they are used.

@@ -10,9 +10,9 @@ endif ()
if (DIST_CONTAINS_AVX)
target_compile_definitions (unittests PRIVATE ENABLE_AVX=1)
endif ()
#if (DIST_CONTAINS_AVX2)
# target_compile_definitions (unittests PRIVATE ENABLE_AVX2=1)
#endif ()
Collaborator

ditto

Collaborator Author

ditto

src/simd/CMakeLists.txt (resolved)
@@ -28,6 +28,10 @@ GetFP32ComputeIP() {
} else if (SimdStatus::SupportAVX2()) {
#if defined(ENABLE_AVX2)
return avx2::FP32ComputeIP;
#endif
} else if (SimdStatus::SupportAVX()) {
#if defined(ENABLE_AVX)
Collaborator

If ENABLE_AVX is not defined, will the body of this if statement be empty?

Collaborator Author

yes
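
To illustrate the exchange above, here is a minimal sketch of a dispatch chain whose branch body is compiled out when its macro is undefined. It rests on assumptions not shown in the hunk: the `sketch` namespace, the `SKETCH_ENABLE_AVX` macro and the local `SupportAVX()` probe are hypothetical stand-ins, the `avx` kernel is a scalar placeholder, and the chain is assumed to be followed by a generic return.

```cpp
#include <cassert>
#include <cstddef>

namespace sketch {

using ComputeFunc = float (*)(const float*, const float*, std::size_t);

namespace generic {
// Portable fallback that is always compiled in.
inline float
FP32ComputeIP(const float* a, const float* b, std::size_t dim) {
    float result = 0.0F;
    for (std::size_t i = 0; i < dim; ++i) {
        result += a[i] * b[i];
    }
    return result;
}
}  // namespace generic

#if defined(SKETCH_ENABLE_AVX)
namespace avx {
// Placeholder: a real build would put an intrinsic kernel in a file
// compiled with -mavx. Here it simply forwards to the generic version.
inline float
FP32ComputeIP(const float* a, const float* b, std::size_t dim) {
    return generic::FP32ComputeIP(a, b, dim);
}
}  // namespace avx
#endif

// Hypothetical runtime probe standing in for SimdStatus::SupportAVX().
inline bool
SupportAVX() {
#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
    return __builtin_cpu_supports("avx") != 0;
#else
    return false;
#endif
}

inline ComputeFunc
GetFP32ComputeIP() {
    if (SupportAVX()) {
#if defined(SKETCH_ENABLE_AVX)
        return avx::FP32ComputeIP;
#endif
        // Without SKETCH_ENABLE_AVX this body is empty and control falls
        // through to the generic return below.
    }
    return generic::FP32ComputeIP;
}

}  // namespace sketch
```

One consequence worth noting for an `else if` chain: when the CPU supports a higher level whose branch was compiled out, the chain ends at that empty branch and the lower levels are not tried, so the result is whatever follows the chain.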

#endif
} else if (SimdStatus::SupportAVX()) {
#if defined(ENABLE_AVX)
return avx::FP32ComputeL2Sqr;
Collaborator

ditto

Collaborator Author

ditto

src/simd/generic.cpp (resolved)
@@ -28,6 +28,10 @@ GetNormalize() {
} else if (SimdStatus::SupportAVX2()) {
#if defined(ENABLE_AVX2)
return avx2::Normalize;
#endif
} else if (SimdStatus::SupportAVX()) {
#if defined(ENABLE_AVX)
Collaborator

ditto

Collaborator Author

ditto

float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

return sum;
InnerProduct(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
Collaborator

ditto

Collaborator Author

ditto

@LHT129 LHT129 force-pushed the avx branch 2 times, most recently from 9c80eb4 to c279e9f Compare December 25, 2024 14:35
Collaborator

@jiaweizone jiaweizone left a comment

LGTM

Collaborator

@wxyucs wxyucs left a comment

lgtm

- refactor all simd functions, just like FP32
- add openblas for ubuntu-aarch deps
- hgraph sq recall reduce to 0.96

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
@LHT129 LHT129 merged commit b55e104 into antgroup:main Dec 26, 2024
8 checks passed
@LHT129 LHT129 deleted the avx branch December 26, 2024 06:10
Labels
kind/improvement Code improvements (variable/function renaming, refactoring, etc. ) size/XXL version/0.12
5 participants