Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 56 additions & 50 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,75 +31,81 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
else()
option(HNSWLIB_EXAMPLES "Build examples and tests." OFF)
endif()

# Option to treat warnings as errors (useful for CI)
option(HNSWLIB_WERROR "Treat warnings as errors" OFF)

if(HNSWLIB_EXAMPLES)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14)

# Common warning flags for GCC/Clang
set(HNSWLIB_WARNING_FLAGS
-Wall
-Wextra
-Wno-unknown-pragmas
)

if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
SET( CMAKE_CXX_FLAGS "-Ofast -std=c++11 -DHAVE_CXX0X -openmp -fpic -ftree-vectorize" )
set(CMAKE_CXX_FLAGS "-O3 -DHAVE_CXX0X -fpic -ftree-vectorize")
check_cxx_compiler_flag("-march=native" COMPILER_SUPPORT_NATIVE_FLAG)
if(COMPILER_SUPPORT_NATIVE_FLAG)
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native" )
message("set -march=native flag")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
message(STATUS "set -march=native flag")
else()
check_cxx_compiler_flag("-mcpu=apple-m1" COMPILER_SUPPORT_M1_FLAG)
if(COMPILER_SUPPORT_M1_FLAG)
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1" )
message("set -mcpu=apple-m1 flag")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1")
message(STATUS "set -mcpu=apple-m1 flag")
endif()
endif()
# Add OpenMP if available
find_package(OpenMP QUIET)
if(OpenMP_CXX_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" )
set(CMAKE_CXX_FLAGS "-O3 -lrt -DHAVE_CXX0X -march=native -fpic -fopenmp -ftree-vectorize")
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" )
set(CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W3 /openmp /EHsc")
set(HNSWLIB_WARNING_FLAGS /W3)
endif()

# examples
add_executable(example_search examples/cpp/example_search.cpp)
target_link_libraries(example_search hnswlib)

add_executable(example_epsilon_search examples/cpp/example_epsilon_search.cpp)
target_link_libraries(example_epsilon_search hnswlib)

add_executable(example_multivector_search examples/cpp/example_multivector_search.cpp)
target_link_libraries(example_multivector_search hnswlib)

add_executable(example_filter examples/cpp/example_filter.cpp)
target_link_libraries(example_filter hnswlib)

add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp)
target_link_libraries(example_replace_deleted hnswlib)

add_executable(example_mt_search examples/cpp/example_mt_search.cpp)
target_link_libraries(example_mt_search hnswlib)
# Add warning-as-error flag if requested
if(HNSWLIB_WERROR)
if(MSVC)
list(APPEND HNSWLIB_WARNING_FLAGS /WX)
else()
list(APPEND HNSWLIB_WARNING_FLAGS -Werror)
endif()
endif()

add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp)
target_link_libraries(example_mt_filter hnswlib)
# Helper function to add warning flags to targets
function(hnswlib_add_executable target_name)
add_executable(${target_name} ${ARGN})
target_link_libraries(${target_name} hnswlib)
target_compile_options(${target_name} PRIVATE ${HNSWLIB_WARNING_FLAGS})
endfunction()

add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp)
target_link_libraries(example_mt_replace_deleted hnswlib)
# examples
hnswlib_add_executable(example_search examples/cpp/example_search.cpp)
hnswlib_add_executable(example_epsilon_search examples/cpp/example_epsilon_search.cpp)
hnswlib_add_executable(example_multivector_search examples/cpp/example_multivector_search.cpp)
hnswlib_add_executable(example_filter examples/cpp/example_filter.cpp)
hnswlib_add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp)
hnswlib_add_executable(example_mt_search examples/cpp/example_mt_search.cpp)
hnswlib_add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp)
hnswlib_add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp)

# tests
add_executable(multivector_search_test tests/cpp/multivector_search_test.cpp)
target_link_libraries(multivector_search_test hnswlib)

add_executable(epsilon_search_test tests/cpp/epsilon_search_test.cpp)
target_link_libraries(epsilon_search_test hnswlib)

add_executable(test_updates tests/cpp/updates_test.cpp)
target_link_libraries(test_updates hnswlib)

add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp)
target_link_libraries(searchKnnCloserFirst_test hnswlib)

add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp)
target_link_libraries(searchKnnWithFilter_test hnswlib)

add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp)
target_link_libraries(multiThreadLoad_test hnswlib)

add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp)
target_link_libraries(multiThread_replace_test hnswlib)
hnswlib_add_executable(multivector_search_test tests/cpp/multivector_search_test.cpp)
hnswlib_add_executable(epsilon_search_test tests/cpp/epsilon_search_test.cpp)
hnswlib_add_executable(test_updates tests/cpp/updates_test.cpp)
hnswlib_add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp)
hnswlib_add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp)
hnswlib_add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp)
hnswlib_add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp)

add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
target_link_libraries(main hnswlib)
target_compile_options(main PRIVATE ${HNSWLIB_WARNING_FLAGS})
endif()
8 changes: 4 additions & 4 deletions examples/cpp/example_mt_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ int main() {
}

// Add data to index
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t /*threadId*/) {
alg_hnsw->addPoint((void*)(data + dim * row), row);
});

Expand All @@ -104,13 +104,13 @@ int main() {

// Query the elements for themselves with filter and check returned labels
int k = 10;
std::vector<hnswlib::labeltype> neighbors(max_elements * k);
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
std::vector<hnswlib::labeltype> neighbors(static_cast<size_t>(max_elements * k));
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t /*threadId*/) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + dim * row, k, &pickIdsDivisibleByTwo);
for (int i = 0; i < k; i++) {
hnswlib::labeltype label = result.top().second;
result.pop();
neighbors[row * k + i] = label;
neighbors[row * static_cast<size_t>(k) + static_cast<size_t>(i)] = label;
}
});

Expand Down
8 changes: 4 additions & 4 deletions examples/cpp/example_mt_replace_deleted.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ int main() {
}

// Add data to index
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t /*threadId*/) {
alg_hnsw->addPoint((void*)(data + dim * row), row);
});

// Mark first half of elements as deleted
int num_deleted = max_elements / 2;
ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) {
ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t /*threadId*/) {
alg_hnsw->markDelete(row);
});

Expand All @@ -102,8 +102,8 @@ int main() {
// Replace deleted data with new elements
// Maximum number of elements is reached therefore we cannot add new items,
// but we can replace the deleted ones by using replace_deleted=true
ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) {
hnswlib::labeltype label = max_elements + row;
ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t /*threadId*/) {
hnswlib::labeltype label = static_cast<hnswlib::labeltype>(max_elements) + row;
alg_hnsw->addPoint((void*)(add_data + dim * row), label, true);
});

Expand Down
8 changes: 4 additions & 4 deletions examples/cpp/example_mt_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,23 @@ int main() {
}

// Add data to index
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t /*threadId*/) {
alg_hnsw->addPoint((void*)(data + dim * row), row);
});

// Query the elements for themselves and measure recall
std::vector<hnswlib::labeltype> neighbors(max_elements);
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t /*threadId*/) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + dim * row, 1);
hnswlib::labeltype label = result.top().second;
neighbors[row] = label;
});
float correct = 0;
for (int i = 0; i < max_elements; i++) {
hnswlib::labeltype label = neighbors[i];
if (label == i) correct++;
if (label == static_cast<hnswlib::labeltype>(i)) correct++;
}
float recall = correct / max_elements;
float recall = correct / static_cast<float>(max_elements);
std::cout << "Recall: " << recall << "\n";

delete[] data;
Expand Down
8 changes: 4 additions & 4 deletions examples/cpp/example_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ int main() {
for (int i = 0; i < max_elements; i++) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
hnswlib::labeltype label = result.top().second;
if (label == i) correct++;
if (label == static_cast<hnswlib::labeltype>(i)) correct++;
}
float recall = correct / max_elements;
float recall = correct / static_cast<float>(max_elements);
std::cout << "Recall: " << recall << "\n";

// Serialize index
Expand All @@ -47,9 +47,9 @@ int main() {
for (int i = 0; i < max_elements; i++) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
hnswlib::labeltype label = result.top().second;
if (label == i) correct++;
if (label == static_cast<hnswlib::labeltype>(i)) correct++;
}
recall = (float)correct / max_elements;
recall = correct / static_cast<float>(max_elements);
std::cout << "Recall of deserialized index: " << recall << "\n";

delete[] data;
Expand Down
20 changes: 10 additions & 10 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
std::unordered_map<labeltype, size_t > dict_external_to_internal;


BruteforceSearch(SpaceInterface <dist_t> *s)
BruteforceSearch(SpaceInterface <dist_t>* /*s*/)
: data_(nullptr),
maxelements_(0),
cur_element_count(0),
Expand All @@ -49,7 +49,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxElements * size_per_element_);
data_ = static_cast<char*>(malloc(maxElements * size_per_element_));
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
cur_element_count = 0;
Expand All @@ -61,8 +61,8 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
}


void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
int idx;
void addPoint(const void *datapoint, labeltype label, bool /*replace_deleted*/ = false) {
size_t idx;
{
std::unique_lock<std::mutex> lock(index_lock);

Expand Down Expand Up @@ -94,7 +94,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
dict_external_to_internal.erase(found);

size_t cur_c = found->second;
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
labeltype label = *reinterpret_cast<labeltype*>(data_ + size_per_element_ * (cur_element_count-1) + data_size_);
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
data_ + size_per_element_ * (cur_element_count-1),
Expand All @@ -108,18 +108,18 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
for (size_t i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
labeltype label = *reinterpret_cast<labeltype*>(data_ + size_per_element_ * i + data_size_);
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
}
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
for (size_t i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
labeltype label = *reinterpret_cast<labeltype*>(data_ + size_per_element_ * i + data_size_);
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
Expand Down Expand Up @@ -161,7 +161,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxelements_ * size_per_element_);
data_ = static_cast<char*>(malloc(maxelements_ * size_per_element_));
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");

Expand Down
Loading