Skip to content

Commit be02fde

Browse files
committed
🔨 version 1.0.3 - Add GetMaxElements, GetCurrentElementCount, GetDeleteCount, GetVectorByLabel APIs
1 parent 7a3eb65 commit be02fde

File tree

7 files changed

+142
-28
lines changed

7 files changed

+142
-28
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ Hnswlib to go. Golang interface to hnswlib(https://github.com/nmslib/hnswlib). T
66

77
### Version
88

9+
* version 1.0.3
10+
* Add `GetMaxElements`, `GetCurrentElementCount`, `GetDeleteCount`, `GetVectorByLabel` APIs
11+
912
* version 1.0.2
1013
* Update hnswlib compatible version to 0.7.0
1114
* Add `AddBatchPoints`, `SearchBatchKNN`, `SetNormalize`, `ResizeIndex`, `MarkDelete`, `UnmarkDelete`, `GetLabelIsMarkedDeleted` API

example/demo.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"fmt"
55
"math/rand"
6+
"reflect"
67
"runtime"
78
"time"
89

@@ -43,26 +44,35 @@ func randVector(dim int) []float32 {
4344
}
4445

4546
// 单个写入
46-
func exampleAddPoint(indexFileName string) {
47+
func exampleAddPoint(indexFileName string) []float32 {
4748
var dim, M, ef = 128, 32, 300
4849
// 最大的 elements 数
49-
var maxElements uint32 = 10000
50+
var maxElements uint32 = 100
5051
// 定义距离 cosine
5152
var spaceType = "cosine"
52-
var randomSeed = 100
53+
var randomSeed = 2000
5354
fmt.Println("Before Create HNSW")
5455
traceMemStats()
5556
// Init new index
5657
h := hnswgo.New(dim, M, ef, randomSeed, maxElements, spaceType)
58+
59+
// randomIndex to test the api GetVectorByLabel
60+
var randomIndex []float32
61+
5762
// Insert 1000 vectors to index. Label Type is uint32
5863
var i uint32
5964
for ; i < maxElements; i++ {
6065
if i%1000 == 0 {
6166
fmt.Println(i)
6267
}
63-
h.AddPoint(randVector(dim), i)
68+
randVec := randVector(dim)
69+
h.AddPoint(randVec, i)
70+
if i == 0 {
71+
randomIndex = randVec
72+
}
6473
}
6574
h.Save(indexFileName)
75+
return randomIndex
6676
}
6777

6878
// 批量写入
@@ -97,7 +107,7 @@ func exampleBatchAddPoint(indexFileName string) {
97107
}
98108

99109
// 读取
100-
func exampleLoadIndex(indexFileName, spaceType string, dim int) {
110+
func exampleLoadIndex(indexFileName, spaceType string, dim int) []float32 {
101111
h := hnswgo.Load(indexFileName, dim, spaceType)
102112
// Search vector with maximum 5 NN
103113
h.SetEf(15)
@@ -109,36 +119,63 @@ func exampleLoadIndex(indexFileName, spaceType string, dim int) {
109119
fmt.Println(endTime - startTime)
110120
fmt.Println(labels, vectors)
111121

122+
// Test GetMaxElements API Before Resize
123+
maxElementsBeforeResize := h.GetMaxElements()
124+
currentElementsBeforeResize := h.GetCurrentElementCount()
125+
fmt.Println("maxElements, currentElements(before resize): ", maxElementsBeforeResize, currentElementsBeforeResize)
126+
112127
// Test ResizeIndex API
113128
isResize := h.ResizeIndex(12000)
114129
fmt.Println("Size flag: ", isResize)
115130

131+
// Test GetMaxElements API After Resize
132+
maxElementsAfterResize := h.GetMaxElements()
133+
currentElementsAfterResize := h.GetCurrentElementCount()
134+
fmt.Println("maxElements, currentElements(after resize): ", maxElementsAfterResize, currentElementsAfterResize)
135+
136+
// Test GetDeleteCount API
137+
deleteCountBeforeDelete := h.GetDeleteCount()
138+
fmt.Println("GetDeleteCount(before): ", deleteCountBeforeDelete)
139+
116140
// Test Mark API
117141
isMarkDelete := h.MarkDelete(10)
118142
fmt.Println("isMarkDelete: ", isMarkDelete)
119143

120144
labelIsDelete := h.GetLabelIsMarkedDeleted(10)
121145
fmt.Println("labelIsDelete: ", labelIsDelete)
122146

147+
// Test GetDeleteCount API
148+
deleteCountBeforeAfter := h.GetDeleteCount()
149+
fmt.Println("GetDeleteCount(after): ", deleteCountBeforeAfter)
150+
123151
isUnmarkDelete := h.UnmarkDelete(10)
124152
fmt.Println("isUnmarkDelete: ", isUnmarkDelete)
125153

154+
// Test GetVectorByLabel API
155+
getVectorByIdRes := h.GetVectorByLabel(0, dim)
156+
fmt.Println("Vector: ", getVectorByIdRes)
157+
126158
// Test Unload API
127159
fmt.Println("Before Unload")
128160
traceMemStats()
129161
h.Unload()
130162
fmt.Println("After Unload")
131163
traceMemStats()
164+
165+
return getVectorByIdRes
132166
}
133167

134168
func main() {
135169
// 单条写入 add index point by point
136-
exampleAddPoint("hnsw_demo_single.bin")
170+
demoVector := exampleAddPoint("hnsw_demo_single.bin")
137171
// 测试读取 test loading
138-
exampleLoadIndex("hnsw_demo_single.bin", "cosine", 128)
172+
demoSearchVector := exampleLoadIndex("hnsw_demo_single.bin", "cosine", 128)
173+
// test GetVectorByLabel API
174+
isEqual := reflect.DeepEqual(demoVector, demoSearchVector)
175+
fmt.Println("GetVectorByLabel return data is equal: ", isEqual)
139176

140177
// 批量写入 add index with batch mode
141-
//exampleBatchAddPoint("hnsw_demo_multiple.bin")
178+
exampleBatchAddPoint("hnsw_demo_multiple.bin")
142179
// 测试读取 test loading
143-
//exampleLoadIndex("hnsw_demo_multiple.bin", "cosine", 128)
180+
exampleLoadIndex("hnsw_demo_multiple.bin", "cosine", 128)
144181
}

hnsw.go

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
package hnswgo
22

3-
// #cgo LDFLAGS: -L${SRCDIR} -lhnsw -lm
4-
// #include <stdlib.h>
5-
// #include <stdbool.h>
6-
// #include "hnsw_wrapper.h"
7-
// HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_construction, int rand_seed, char stype);
8-
// HNSW loadHNSW(char *location, int dim, char stype);
9-
// void addPoint(HNSW index, float *vec, unsigned long int label);
10-
// int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist);
11-
// void setEf(HNSW index, int ef);
12-
// bool resizeIndex(HNSW index, unsigned long int new_max_elements);
13-
// bool markDelete(HNSW index, unsigned long int label);
14-
// bool unmarkDelete(HNSW index, unsigned long int label);
15-
// bool isMarkedDeleted(HNSW index, unsigned long int label);
16-
// bool updatePoint(HNSW index, float *vec, unsigned long int label);
3+
import "C"
4+
5+
/*
6+
#cgo CXXFLAGS: -std=c++11
7+
#cgo LDFLAGS: -L${SRCDIR} -lhnsw -lm
8+
#include <stdlib.h>
9+
#include <stdbool.h>
10+
#include "hnsw_wrapper.h"
11+
12+
HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_construction, int rand_seed, char stype);
13+
HNSW loadHNSW(char *location, int dim, char stype);
14+
void addPoint(HNSW index, float *vec, unsigned long int label);
15+
int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist);
16+
void setEf(HNSW index, int ef);
17+
bool resizeIndex(HNSW index, unsigned long int new_max_elements);
18+
bool markDelete(HNSW index, unsigned long int label);
19+
bool unmarkDelete(HNSW index, unsigned long int label);
20+
bool isMarkedDeleted(HNSW index, unsigned long int label);
21+
bool updatePoint(HNSW index, float *vec, unsigned long int label);
22+
23+
void getDataByLabel(HNSW index, unsigned long int label, float* out_data);
24+
*/
1725
import "C"
1826
import (
1927
"math"
@@ -22,6 +30,13 @@ import (
2230
"unsafe"
2331
)
2432

33+
func toSlice(v *C.float, len int) []float32 {
34+
// 创建一个指向C数组的slice
35+
slice := (*[1 << 30]float32)(unsafe.Pointer(v))[:len:len]
36+
// 复制slice的值,将其转换为一个新的Go切片
37+
return append([]float32(nil), slice...)
38+
}
39+
2540
type HNSW struct {
2641
index C.HNSW
2742
spaceType string
@@ -224,3 +239,32 @@ func (h *HNSW) GetLabelIsMarkedDeleted(label uint32) bool {
224239
isDelete := bool(C.isMarkedDeleted(h.index, C.ulong(label)))
225240
return isDelete
226241
}
242+
243+
// GetMaxElements get index max elements
244+
func (h *HNSW) GetMaxElements() int {
245+
maxElements := int(C.getMaxElements(h.index))
246+
return maxElements
247+
}
248+
249+
// GetCurrentElementCount get index current elements
250+
func (h *HNSW) GetCurrentElementCount() int {
251+
elementCnt := int(C.getCurrentElementCount(h.index))
252+
return elementCnt
253+
}
254+
255+
// GetDeleteCount get index count which mark deleted
256+
func (h *HNSW) GetDeleteCount() int {
257+
deleteElementCnt := int(C.getDeleteCount(h.index))
258+
return deleteElementCnt
259+
}
260+
261+
// GetVectorByLabel get index by label
262+
func (h *HNSW) GetVectorByLabel(label uint32, dim int) []float32 {
263+
var outDataPtr C.float
264+
C.getDataByLabel(h.index, C.ulong(label), &outDataPtr)
265+
outData := make([]float32, dim)
266+
for i := 0; i < dim; i++ {
267+
outData[i] = float32(*(*C.float)(unsafe.Pointer(uintptr(unsafe.Pointer(&outDataPtr)) + uintptr(i)*unsafe.Sizeof(C.float(0)))))
268+
}
269+
return outData
270+
}

hnsw_wrapper.cc

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,30 @@ bool updatePoint(HNSW index, float *vec, unsigned long int label) {
118118
return false;
119119
}
120120

121-
// TODO
122-
//std::vector<float> getDataByLabel(HNSW index, unsigned long int label) {
123-
// return ((hnswlib::HierarchicalNSW<float>*)index)->getDataByLabel<float>(label);
124-
//}
121+
void getDataByLabel(HNSW index, unsigned long int label, float* out_data) {
122+
auto data = ((hnswlib::HierarchicalNSW<float>*)index)->getDataByLabel<float>(label);
123+
std::vector<float>* vec = new std::vector<float>(data.begin(), data.end());
124+
if (vec == nullptr) {
125+
return;
126+
}
127+
128+
size_t size = vec->size();
129+
for (size_t i = 0; i < size; i++) {
130+
out_data[i] = (*vec)[i];
131+
}
132+
133+
delete vec;
134+
}
135+
136+
int getMaxElements(HNSW index) {
137+
return ((hnswlib::HierarchicalNSW<float> *) index)->getMaxElements();
138+
}
139+
140+
int getCurrentElementCount(HNSW index) {
141+
return ((hnswlib::HierarchicalNSW<float> *) index)->getCurrentElementCount();
142+
}
143+
144+
int getDeleteCount(HNSW index) {
145+
return ((hnswlib::HierarchicalNSW<float> *) index)->getDeletedCount();
146+
}
147+

hnsw_wrapper.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ bool isMarkedDeleted(HNSW index, unsigned long int label);
2626

2727
bool updatePoint(HNSW index, float *vec, unsigned long int label);
2828

29+
int getMaxElements(HNSW index);
30+
31+
int getCurrentElementCount(HNSW index);
32+
33+
int getDeleteCount(HNSW index);
34+
35+
void getDataByLabel(HNSW index, unsigned long int label, float* out_data);
2936
#ifdef __cplusplus
3037
}
31-
#endif
38+
#endif

hnsw_wrapper.o

2.95 KB
Binary file not shown.

libhnsw.a

3.15 KB
Binary file not shown.

0 commit comments

Comments
 (0)