Skip to content

Commit

Permalink
split allocator_wrapper and default Allocator (#145)
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
LHT129 authored Nov 20, 2024
1 parent e8cabd7 commit f7846e3
Showing 10 changed files with 85 additions and 56 deletions.
74 changes: 74 additions & 0 deletions src/allocator_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "vsag/allocator.h"

namespace vsag {

template <class T>
class AllocatorWrapper {
public:
using value_type = T;
using pointer = T*;
using void_pointer = void*;
using const_void_pointer = const void*;
using size_type = size_t;
using difference_type = std::ptrdiff_t;

AllocatorWrapper(Allocator* allocator) {
this->allocator_ = allocator;
}

template <class U>
AllocatorWrapper(const AllocatorWrapper<U>& other) : allocator_(other.allocator_) {
}

bool
operator==(const AllocatorWrapper& other) const noexcept {
return allocator_ == other.allocator_;
}

inline pointer
allocate(size_type n, const_void_pointer hint = 0) {
return static_cast<pointer>(allocator_->Allocate(n * sizeof(value_type)));
}

inline void
deallocate(pointer p, size_type n) {
allocator_->Deallocate(static_cast<void_pointer>(p));
}

template <class U, class... Args>
inline void
construct(U* p, Args&&... args) {
::new ((void_pointer)p) U(std::forward<Args>(args)...);
}

template <class U>
inline void
destroy(U* p) {
p->~U();
}

template <class U>
struct rebind {
using other = AllocatorWrapper<U>;
};

Allocator* allocator_{};
};
} // namespace vsag
1 change: 1 addition & 0 deletions src/data_cell/graph_interface_test.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
#include <fstream>

#include "catch2/catch_test_macros.hpp"
#include "default_allocator.h"
#include "fixtures.h"

using namespace vsag;
1 change: 1 addition & 0 deletions src/data_cell/sparse_graph_datacell_test.cpp
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
#include "sparse_graph_datacell.h"

#include "catch2/catch_template_test_macros.hpp"
#include "default_allocator.h"
#include "fmt/format-inl.h"
#include "graph_interface_test.h"

2 changes: 1 addition & 1 deletion src/dataset_impl_test.cpp
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "./dataset_impl.h"
#include "dataset_impl.h"

#include <catch2/catch_test_macros.hpp>

54 changes: 0 additions & 54 deletions src/default_allocator.h
Original file line number Diff line number Diff line change
@@ -15,7 +15,6 @@

#pragma once

#include <map>
#include <memory>
#include <unordered_set>
#include <vector>
@@ -70,57 +69,4 @@ class DefaultAllocator : public Allocator {
#endif
};

template <class T>
class AllocatorWrapper {
public:
using value_type = T;
using pointer = T*;
using void_pointer = void*;
using const_void_pointer = const void*;
using size_type = size_t;
using difference_type = ptrdiff_t;

AllocatorWrapper(Allocator* allocator) {
this->allocator_ = allocator;
}

template <class U>
AllocatorWrapper(const AllocatorWrapper<U>& other) : allocator_(other.allocator_) {
}

bool
operator==(const AllocatorWrapper& other) const noexcept {
return allocator_ == other.allocator_;
}

pointer
allocate(size_type n, const_void_pointer hint = 0) {
return static_cast<pointer>(allocator_->Allocate(n * sizeof(value_type)));
}

void
deallocate(pointer p, size_type n) {
allocator_->Deallocate(static_cast<void_pointer>(p));
}

template <class U, class... Args>
void
construct(U* p, Args&&... args) {
::new ((void_pointer)p) U(std::forward<Args>(args)...);
}

template <class U>
void
destroy(U* p) {
p->~U();
}

template <class U>
struct rebind {
using other = AllocatorWrapper<U>;
};

Allocator* allocator_{};
};

} // namespace vsag
1 change: 1 addition & 0 deletions src/quantization/fp32_quantizer_test.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
#include <catch2/catch_test_macros.hpp>
#include <memory>

#include "default_allocator.h"
#include "fixtures.h"
#include "quantizer_test.h"

2 changes: 2 additions & 0 deletions src/quantization/sq4_quantizer_test.cpp
Original file line number Diff line number Diff line change
@@ -18,8 +18,10 @@
#include <catch2/catch_test_macros.hpp>
#include <vector>

#include "default_allocator.h"
#include "fixtures.h"
#include "quantizer_test.h"

using namespace vsag;

const auto dims = fixtures::get_common_used_dims();
1 change: 1 addition & 0 deletions src/quantization/sq4_uniform_quantizer_test.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
#include <catch2/catch_test_macros.hpp>
#include <vector>

#include "default_allocator.h"
#include "fixtures.h"
#include "quantizer_test.h"

1 change: 1 addition & 0 deletions src/quantization/sq8_quantizer_test.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
#include <catch2/catch_test_macros.hpp>
#include <memory>

#include "default_allocator.h"
#include "fixtures.h"
#include "quantizer_test.h"

4 changes: 3 additions & 1 deletion src/utils.h
Original file line number Diff line number Diff line change
@@ -18,9 +18,11 @@
#include <chrono>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

#include "default_allocator.h"
#include "allocator_wrapper.h"
#include "logger.h"
#include "spdlog/spdlog.h"
#include "vsag/errors.h"
#include "vsag/expected.hpp"

0 comments on commit f7846e3

Please sign in to comment.