Skip to content

Commit ae49bda

Browse files
committed
inverting the apply-invertedApply
1 parent 8254739 commit ae49bda

File tree

8 files changed

+63
-32
lines changed

8 files changed

+63
-32
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ private AbstractNode<NodeReference> compactNodeFromTuples(@Nonnull final AffineO
220220
@Nonnull final Tuple vectorTuple,
221221
@Nonnull final Tuple neighborsTuple) {
222222
final RealVector vector =
223-
storageTransform.invertedApply(StorageAdapter.vectorFromTuple(getConfig(), vectorTuple));
223+
storageTransform.apply(StorageAdapter.vectorFromTuple(getConfig(), vectorTuple));
224224
final List<NodeReference> nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size());
225225

226226
for (int i = 0; i < neighborsTuple.size(); i ++) {

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) {
275275
final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference();
276276

277277
final AffineOperator storageTransform = storageTransform(accessInfo);
278-
final RealVector transformedQueryVector = storageTransform.invertedApply(queryVector);
278+
final RealVector transformedQueryVector = storageTransform.apply(queryVector);
279279
final Quantizer quantizer = quantizer(accessInfo);
280280
final Estimator estimator = quantizer.estimator();
281281

@@ -356,7 +356,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) {
356356
Objects.requireNonNull(nodeReferenceAndNode).getNodeReferenceWithDistance();
357357
final AbstractNode<N> node = nodeReferenceAndNode.getNode();
358358
@Nullable final RealVector reconstructedVector =
359-
includeVectors ? storageTransform.apply(node.asCompactNode().getVector()) : null;
359+
includeVectors ? storageTransform.invertedApply(node.asCompactNode().getVector()) : null;
360360

361361
resultBuilder.add(
362362
new ResultEntry(node.getPrimaryKey(),
@@ -789,7 +789,7 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
789789
.thenCompose(accessInfo -> {
790790
final AccessInfo currentAccessInfo;
791791
final AffineOperator storageTransform = storageTransform(accessInfo);
792-
final RealVector transformedNewVector = storageTransform.invertedApply(newVector);
792+
final RealVector transformedNewVector = storageTransform.apply(newVector);
793793
final Quantizer quantizer = quantizer(accessInfo);
794794
final Estimator estimator = quantizer.estimator();
795795

@@ -826,7 +826,7 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
826826
currentAccessInfo = accessInfo;
827827
}
828828
}
829-
829+
830830
final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference();
831831
final int lMax = entryNodeReference.getLayer();
832832
if (logger.isTraceEnabled()) {
@@ -849,14 +849,30 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
849849
insertIntoLayers(transaction, storageTransform, quantizer, newPrimaryKey,
850850
transformedNewVector, nodeReference, lMax, insertionLayer))
851851
.thenCompose(ignored ->
852-
addToStats(transaction, currentAccessInfo, transformedNewVector));
852+
addToStatsIfNecessary(transaction, currentAccessInfo, transformedNewVector));
853853
}).thenCompose(ignored -> AsyncUtil.DONE);
854854
}
855855

856+
/**
857+
* Method to keep stats if necessary. Stats need to be kept and maintained when the client would like to use
858+
* e.g. RaBitQ, as RaBitQ needs a stable, reasonably accurate centroid in order to function properly.
859+
* <p>
860+
* Specifically for RaBitQ, we add vectors to a set of sampled vectors in a designated subspace of the HNSW
861+
* structure. The parameter {@link Config#getSampleVectorStatsProbability()} governs when we do sample. Another
862+
* parameter, {@link Config#getMaintainStatsProbability()}, determines how often we add up/replace (consume)
863+
* vectors from this sampled-vector space and aggregate them in the typical running count/running sum scheme
864+
* in order to finally compute the centroid once {@link Config#getStatsThreshold()} vectors have been
865+
* sampled and aggregated. That centroid is then used to update the access info.
866+
*
867+
* @param transaction the transaction
868+
* @param currentAccessInfo the current access info that was fetched as part of an insert
869+
* @param transformedNewVector the new vector (in the transformed coordinate system) that may be added
870+
* @return a future that returns {@code null} when completed
871+
*/
856872
@Nonnull
857-
private CompletableFuture<Void> addToStats(@Nonnull final Transaction transaction,
858-
@Nonnull final AccessInfo currentAccessInfo,
859-
@Nonnull final RealVector transformedNewVector) {
873+
private CompletableFuture<Void> addToStatsIfNecessary(@Nonnull final Transaction transaction,
874+
@Nonnull final AccessInfo currentAccessInfo,
875+
@Nonnull final RealVector transformedNewVector) {
860876
if (getConfig().isUseRaBitQ() && !currentAccessInfo.canUseRaBitQ()) {
861877
if (shouldSampleVector()) {
862878
StorageAdapter.appendSampledVector(transaction, getSubspace(),
@@ -885,12 +901,12 @@ private CompletableFuture<Void> addToStats(@Nonnull final Transaction transactio
885901
new FhtKacRotator(rotatorSeed, getConfig().getNumDimensions(), 10);
886902

887903
final RealVector centroid =
888-
partialVector.multiply(1.0d / partialCount);
889-
final RealVector transformedCentroid = rotator.invertedApply(centroid);
904+
partialVector.multiply(-1.0d / partialCount);
905+
final RealVector transformedCentroid = rotator.apply(centroid);
890906

891907
final var transformedEntryNodeVector =
892-
rotator.invertedApply(currentAccessInfo.getEntryNodeReference()
893-
.getVector()).subtract(transformedCentroid);
908+
rotator.apply(currentAccessInfo.getEntryNodeReference()
909+
.getVector()).add(transformedCentroid);
894910

895911
final AccessInfo newAccessInfo =
896912
new AccessInfo(currentAccessInfo.getEntryNodeReference().withVector(transformedEntryNodeVector),

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ private NodeReferenceWithVector neighborFromTuples(@Nonnull final AffineOperator
212212
@Nonnull final Tuple keyTuple, @Nonnull final Tuple valueTuple) {
213213
final Tuple neighborPrimaryKey = keyTuple.getNestedTuple(2); // neighbor primary key
214214
final RealVector neighborVector =
215-
storageTransform.invertedApply(
215+
storageTransform.apply(
216216
StorageAdapter.vectorFromTuple(getConfig(), valueTuple)); // the entire value is the vector
217217
return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector);
218218
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,15 @@ public StorageTransform(final long seed, final int numDimensions, @Nonnull final
4040
@Nonnull
4141
@Override
4242
public RealVector apply(@Nonnull final RealVector vector) {
43-
if (!(vector instanceof EncodedRealVector)) {
43+
//
44+
// Only transform the vector if it is needed. We make the decision based on whether the vector is encoded or
45+
// not. When we switch on encoding, we apply the new coordinate system from that point onwards, meaning that all
46+
// vectors inserted before use the client coordinate system. Therefore, we must transform all regular vectors
47+
// and ignore all encoded vectors.
48+
//
49+
// TODO This could be done better in the future by keeping something like a generation id with the vector
50+
// so we would know in what coordinate system the vector is.
51+
if (vector instanceof EncodedRealVector) {
4452
return vector;
4553
}
4654
return super.apply(vector);
@@ -49,9 +57,17 @@ public RealVector apply(@Nonnull final RealVector vector) {
4957
@Nonnull
5058
@Override
5159
public RealVector invertedApply(@Nonnull final RealVector vector) {
52-
if (vector instanceof EncodedRealVector) {
60+
//
61+
// Only transform the vector if it is needed. We make the decision based on whether the vector is encoded or
62+
// not. When we switch on encoding, we apply the new coordinate system from that point onwards, meaning that all
63+
// vectors inserted before use the client coordinate system. For the inverted case, we only have to transform
64+
// the encoded vectors as they are expressed in the internal coordinate system, while regular non-encoded
65+
// vectors are already expressed in the client system.
66+
//
67+
if (!(vector instanceof EncodedRealVector)) {
5368
return vector;
5469
}
70+
5571
return super.invertedApply(vector);
5672
}
5773
}

fdb-extensions/src/main/java/com/apple/foundationdb/linear/AffineOperator.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ public int getNumDimensions() {
5959
public RealVector apply(@Nonnull final RealVector vector) {
6060
RealVector result = vector;
6161

62-
if (translationVector != null) {
63-
result = result.add(translationVector);
64-
}
65-
6662
if (linearOperator != null) {
6763
result = linearOperator.apply(result);
6864
}
6965

66+
if (translationVector != null) {
67+
result = result.add(translationVector);
68+
}
69+
7070
return result;
7171
}
7272

@@ -75,14 +75,14 @@ public RealVector apply(@Nonnull final RealVector vector) {
7575
public RealVector invertedApply(@Nonnull final RealVector vector) {
7676
RealVector result = vector;
7777

78-
if (linearOperator != null) {
79-
result = linearOperator.transposedApply(result);
80-
}
81-
8278
if (translationVector != null) {
8379
result = result.subtract(translationVector);
8480
}
8581

82+
if (linearOperator != null) {
83+
result = linearOperator.transposedApply(result);
84+
}
85+
8686
return result;
8787
}
8888

fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public RealVector transposedApply(@Nonnull final RealVector x) {
147147
}
148148

149149
@Nonnull
150-
public double[] operateTranspose(@Nonnull final double[] x) {
150+
private double[] operateTranspose(@Nonnull final double[] x) {
151151
if (x.length != numDimensions) {
152152
throw new IllegalArgumentException("dimensionality of x != n");
153153
}

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import com.apple.foundationdb.tuple.Tuple;
3737
import com.apple.test.RandomSeedSource;
3838
import com.apple.test.RandomizedTestUtils;
39-
import com.apple.test.SuperSlow;
4039
import com.apple.test.Tags;
4140
import com.google.common.base.Verify;
4241
import com.google.common.collect.ImmutableList;
@@ -303,7 +302,7 @@ private int basicInsertBatch(final HNSW hnsw, final int batchSize,
303302
}
304303

305304
@Test
306-
@SuperSlow
305+
//@SuperSlow
307306
void testSIFTInsertSmall() throws Exception {
308307
final Metric metric = Metric.EUCLIDEAN_METRIC;
309308
final int k = 100;

fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ void encodeManyWithEstimationsTest(final long seed, final int numDimensions, fin
166166
logger.trace("v = {}", v);
167167
logger.trace("centroid = {}", centroid);
168168

169-
final RealVector centroidRot = rotator.transposedApply(centroid);
170-
final RealVector qTrans = rotator.transposedApply(q).subtract(centroidRot);
171-
final RealVector vTrans = rotator.transposedApply(v).subtract(centroidRot);
169+
final RealVector centroidRot = rotator.apply(centroid);
170+
final RealVector qTrans = rotator.apply(q).subtract(centroidRot);
171+
final RealVector vTrans = rotator.apply(v).subtract(centroidRot);
172172

173173
logger.trace("qTrans = {}", qTrans);
174174
logger.trace("vTrans = {}", vTrans);
@@ -183,8 +183,8 @@ void encodeManyWithEstimationsTest(final long seed, final int numDimensions, fin
183183

184184
final EncodedRealVector encodedQ = quantizer.encode(qTrans);
185185
final RaBitEstimator estimator = quantizer.estimator();
186-
final RealVector reconstructedQ = rotator.apply(encodedQ.add(centroidRot));
187-
final RealVector reconstructedV = rotator.apply(encodedV.add(centroidRot));
186+
final RealVector reconstructedQ = rotator.transposedApply(encodedQ.add(centroidRot));
187+
final RealVector reconstructedV = rotator.transposedApply(encodedV.add(centroidRot));
188188
final RaBitEstimator.Result estimatedDistance = estimator.estimateDistanceAndErrorBound(qTrans, encodedV);
189189
logger.trace("estimated ||qRot - vRot||^2 = {}", estimatedDistance);
190190
final double trueDistance = Metric.EUCLIDEAN_SQUARE_METRIC.distance(vTrans, qTrans);

0 commit comments

Comments
 (0)