diff --git a/herddb-core/pom.xml b/herddb-core/pom.xml
index d9a04ecdd..9aa370ab8 100644
--- a/herddb-core/pom.xml
+++ b/herddb-core/pom.xml
@@ -33,6 +33,16 @@
${maven.build.timestamp}
+
+ io.github.jbellis
+ jvector
+ 1.0.2
+
+
+ com.indeed
+ util-mmap
+ 1.0.52-3042601
+
${project.groupId}
herddb-utils
@@ -189,6 +199,12 @@
+
+ org.apache.openjpa
+ openjpa-kernel
+ 3.1.2
+ test
+
diff --git a/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java b/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java
new file mode 100644
index 000000000..69b12e1b2
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java
@@ -0,0 +1,51 @@
+/*
+ Licensed to Diennea S.r.l. under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. Diennea S.r.l. licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+ */
+
+package herddb.index;
+
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import herddb.sql.SQLRecordKeyFunction;
+
+/**
+ * Scan on a secondary index
+ *
+ * @author enrico.olivelli
+ */
+@SuppressFBWarnings("EI_EXPOSE_REP2")
+public class SecondaryIndexVectorSimilarityScan implements IndexOperation {
+
+ public final String indexName;
+ public final String column;
+ public final int topK;
+ public final SQLRecordKeyFunction value;
+
+ public SecondaryIndexVectorSimilarityScan(String indexName, String column, int topK, SQLRecordKeyFunction value) {
+ this.indexName = indexName;
+ this.column = column;
+ this.topK = topK;
+ this.value = value;
+ }
+
+ @Override
+ public String getIndexName() {
+ return indexName;
+ }
+
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java b/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java
new file mode 100644
index 000000000..120f800fd
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java
@@ -0,0 +1,255 @@
+package herddb.index.jvector;
+
+import herddb.codec.RecordSerializer;
+import herddb.core.AbstractIndexManager;
+import herddb.core.AbstractTableManager;
+import herddb.core.PostCheckpointAction;
+import herddb.index.IndexOperation;
+import herddb.index.SecondaryIndexVectorSimilarityScan;
+import herddb.log.CommitLog;
+import herddb.log.LogSequenceNumber;
+import herddb.model.Index;
+import herddb.model.StatementEvaluationContext;
+import herddb.model.StatementExecutionException;
+import herddb.model.Table;
+import herddb.model.TableContext;
+import herddb.storage.DataStorageManager;
+import herddb.storage.DataStorageManagerException;
+import herddb.storage.IndexStatus;
+import herddb.utils.Bytes;
+import herddb.utils.DataAccessor;
+import herddb.utils.ExtendedDataOutputStream;
+import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.NeighborSimilarity;
+import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.SearchResult;
+import io.github.jbellis.jvector.vector.VectorEncoding;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Stream;
+
+/**
+ * Index based on JVector, only for arrays of floats
+ */
+public class JVectorIndexManager extends AbstractIndexManager {
+
+ private static final int DIMENSIONS = 5;
+
+ private static final int M = 8;
+ private static final int beamWidth = 60;
+ private static final float neighborOverflow = 1.2f;
+ private static final float alpha = 1.4f;
+
+ private static VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.COSINE;
+ private static final Logger LOGGER = Logger.getLogger(JVectorIndexManager.class.getName());
+
+
+ private GraphIndexBuilder currentGraphBuilder;
+
+ private GraphSearcher graphSearcher;
+
+ LogSequenceNumber bootSequenceNumber;
+
+
+ private final AtomicLong newPageId = new AtomicLong(1);
+
+
+ private final RandomAccessVectorValuesImpl nodeToVectorMapping = new RandomAccessVectorValuesImpl();
+
+ public JVectorIndexManager(Index index, AbstractTableManager tableManager,
+ DataStorageManager dataStorageManager, String tableSpaceUUID, CommitLog log,
+ long createdInTransaction, int writeLockTimeout, int readLockTimeout) throws DataStorageManagerException {
+ super(index, tableManager, dataStorageManager, tableSpaceUUID, log, createdInTransaction, writeLockTimeout, readLockTimeout);
+ }
+
+ @Override
+ protected boolean doStart(LogSequenceNumber sequenceNumber) throws DataStorageManagerException {
+ LOGGER.log(Level.FINE, " start BRIN index {0} uuid {1}", new Object[]{index.name, index.uuid});
+
+ dataStorageManager.initIndex(tableSpaceUUID, index.uuid);
+
+ bootSequenceNumber = sequenceNumber;
+
+ if (LogSequenceNumber.START_OF_TIME.equals(sequenceNumber)) {
+ /* Empty index (booting from the start) */
+ createNewBuilder();
+ LOGGER.log(Level.FINE, "loaded empty index {0}", new Object[]{index.name});
+
+ return true;
+ } else {
+
+ IndexStatus status;
+ try {
+ status = dataStorageManager.getIndexStatus(tableSpaceUUID, index.uuid, sequenceNumber);
+ } catch (DataStorageManagerException e) {
+ LOGGER.log(Level.SEVERE, "cannot load index {0} due to {1}, it will be rebuilt", new Object[]{index.name, e});
+ return false;
+ }
+ newPageId.set(status.newPageId);
+ return true;
+ }
+ }
+
+ private void createNewBuilder() {
+ currentGraphBuilder = new GraphIndexBuilder<>(this.nodeToVectorMapping, VectorEncoding.FLOAT32,
+ vectorSimilarityFunction, M, beamWidth, neighborOverflow, alpha);
+ graphSearcher = new GraphSearcher.Builder<>(currentGraphBuilder.getGraph().getView())
+ .withConcurrentUpdates()
+ .build();
+ }
+
+ @Override
+ public void rebuild() throws DataStorageManagerException {
+ long _start = System.currentTimeMillis();
+ LOGGER.log(Level.FINE, "building index {0}", index.name);
+ dataStorageManager.initIndex(tableSpaceUUID, index.uuid);
+ createNewBuilder();
+ Table table = tableManager.getTable();
+ AtomicLong count = new AtomicLong();
+ tableManager.scanForIndexRebuild(r -> {
+ DataAccessor values = r.getDataAccessor(table);
+ Bytes key = RecordSerializer.serializeIndexKey(values, table, table.primaryKey);
+ Bytes indexKey = RecordSerializer.serializeIndexKey(values, index, index.columnNames);
+// LOGGER.log(Level.SEVERE, "adding " + key + " -> " + values);
+ recordInserted(key, indexKey);
+ count.incrementAndGet();
+ });
+ long _stop = System.currentTimeMillis();
+ if (count.intValue() > 0) {
+ LOGGER.log(Level.INFO, "building index {0} took {1}, scanned {2} records", new Object[]{index.name, (_stop - _start) + " ms", count});
+ }
+ }
+
+ @Override
+ public List checkpoint(LogSequenceNumber sequenceNumber, boolean pin) throws DataStorageManagerException {
+ OnHeapGraphIndex completed = currentGraphBuilder.build();
+ byte[] storedDisk;
+ try (ByteArrayOutputStream flush = new ByteArrayOutputStream();
+ DataOutputStream dataOutputStream = new ExtendedDataOutputStream(flush)) {
+ OnDiskGraphIndex.write(completed, this.nodeToVectorMapping,dataOutputStream );
+ dataOutputStream.flush();
+ storedDisk = flush.toByteArray();
+ } catch (IOException err) {
+ throw new DataStorageManagerException(err);
+ }
+
+ // TODO:
+
+ LOGGER.log(Level.INFO, "Serialized index takes {0} bytes", storedDisk.length);
+
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void unpinCheckpoint(LogSequenceNumber sequenceNumber) throws DataStorageManagerException {
+
+ }
+
+ @Override
+ protected Stream scanner(IndexOperation operation, StatementEvaluationContext context, TableContext tableContext) throws StatementExecutionException {
+ SecondaryIndexVectorSimilarityScan indexVectorSimilarityScan = (SecondaryIndexVectorSimilarityScan) operation;
+ byte[] bytes = indexVectorSimilarityScan.value.computeNewValue(null, context, tableContext);
+
+ float[] targetVector = Bytes.to_float_array(bytes, 0, bytes.length);
+ int topK = indexVectorSimilarityScan.topK;
+
+ NeighborSimilarity.ExactScoreFunction scoreFunction = (i) -> {
+ return vectorSimilarityFunction.compare(targetVector,
+ this.nodeToVectorMapping.nodeIdToVector.get(i).to_float_array());
+ };
+ SearchResult search = graphSearcher.search(scoreFunction, null, topK, null);
+
+ List result = new ArrayList<>();
+ for (SearchResult.NodeScore node : search.getNodes()) {
+ int nodeId = node.node;
+ Bytes primaryKey = this.nodeToVectorMapping.nodeIdToKey.get(nodeId);
+ result.add(primaryKey);
+ }
+ return result.stream();
+
+ }
+
+ @Override
+ public void recordUpdated(Bytes key, Bytes indexKeyRemoved, Bytes indexKeyAdded) throws DataStorageManagerException {
+ throw new DataStorageManagerException("Update not supported");
+ }
+
+ @Override
+ public void recordInserted(Bytes key, Bytes indexKey) throws DataStorageManagerException {
+ int nodeId = nodeToVectorMapping.registerRecord(key, indexKey);
+ float[] floatArray = indexKey.to_float_array();
+ LOGGER.log(Level.INFO, "Adding {0} as node id {1}", new Object[]{Arrays.toString(floatArray), nodeId});
+ currentGraphBuilder.addGraphNode(nodeId, nodeToVectorMapping);
+ }
+
+ @Override
+ public void recordDeleted(Bytes key, Bytes indexKey) throws DataStorageManagerException {
+ throw new DataStorageManagerException("Delete not supported");
+ }
+
+ @Override
+ public void truncate() throws DataStorageManagerException {
+ throw new DataStorageManagerException("TRUNCATE not supported");
+ }
+
+ @Override
+ public boolean valueAlreadyMapped(Bytes key, Bytes primaryKey) throws DataStorageManagerException {
+ // this method is for UNIQUE indexes
+ return false;
+ }
+
+ private static class RandomAccessVectorValuesImpl implements RandomAccessVectorValues {
+ private AtomicInteger nextNodeId = new AtomicInteger(1);
+
+ ConcurrentHashMap nodeIdToVector = new ConcurrentHashMap<>();
+ ConcurrentHashMap nodeIdToKey = new ConcurrentHashMap<>();
+
+ public int registerRecord(Bytes primaryKey, Bytes vectorValue) {
+ Integer newId = nextNodeId.incrementAndGet();
+ nodeIdToVector.put(newId, vectorValue);
+ nodeIdToKey.put(newId, primaryKey);
+ return newId;
+ }
+
+ @Override
+ public int size() {
+ return nodeIdToVector.size();
+ }
+
+ @Override
+ public int dimension() {
+ return DIMENSIONS;
+ }
+
+ @Override
+ public float[] vectorValue(int i) {
+ Bytes bytes = nodeIdToVector.get(i);
+ return bytes != null ? bytes.to_float_array() : null;
+ }
+
+ @Override
+ public boolean isValueShared() {
+ return true;
+ }
+
+ @Override
+ public RandomAccessVectorValues copy() {
+ return this;
+ }
+ }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java b/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java
new file mode 100644
index 000000000..1e7688df7
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java
@@ -0,0 +1,70 @@
+package herddb.index.jvector;
+
+import com.indeed.util.mmap.MMapBuffer;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class MMapReader implements RandomAccessReader {
+ private final MMapBuffer buffer;
+ private long position;
+ private byte[] floatsScratch = new byte[0];
+ private byte[] intsScratch = new byte[0];
+
+ MMapReader(MMapBuffer buffer) {
+ this.buffer = buffer;
+ }
+
+ @Override
+ public void seek(long offset) {
+ position = offset;
+ }
+
+ public int readInt() {
+ try {
+ return buffer.memory().getInt(position);
+ } finally {
+ position += Integer.BYTES;
+ }
+ }
+
+ public void readFully(byte[] bytes) {
+ read(bytes, 0, bytes.length);
+ }
+
+ private void read(byte[] bytes, int offset, int count) {
+ try {
+ buffer.memory().getBytes(position, bytes, offset, count);
+ } finally {
+ position += count;
+ }
+ }
+
+ @Override
+ public void readFully(float[] floats) {
+ int bytesToRead = floats.length * Float.BYTES;
+ if (floatsScratch.length < bytesToRead) {
+ floatsScratch = new byte[bytesToRead];
+ }
+ read(floatsScratch, 0, bytesToRead);
+ ByteBuffer byteBuffer = ByteBuffer.wrap(floatsScratch).order(ByteOrder.BIG_ENDIAN);
+ byteBuffer.asFloatBuffer().get(floats);
+ }
+
+ @Override
+ public void read(int[] ints, int offset, int count) {
+ int bytesToRead = (count - offset) * Integer.BYTES;
+ if (intsScratch.length < bytesToRead) {
+ intsScratch = new byte[bytesToRead];
+ }
+ read(intsScratch, 0, bytesToRead);
+ ByteBuffer byteBuffer = ByteBuffer.wrap(intsScratch).order(ByteOrder.BIG_ENDIAN);
+ byteBuffer.asIntBuffer().get(ints, offset, count);
+ }
+
+ @Override
+ public void close() {
+ // don't close buffer, let the Supplier handle that
+ }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java b/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java
new file mode 100644
index 000000000..5ed7b623a
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java
@@ -0,0 +1,28 @@
+package herddb.index.jvector;
+
+import com.indeed.util.mmap.MMapBuffer;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+import io.github.jbellis.jvector.disk.ReaderSupplier;
+
+import java.io.IOException;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.nio.file.Path;
+
+public class MMapReaderSupplier implements ReaderSupplier {
+ private final MMapBuffer buffer;
+
+ public MMapReaderSupplier(Path path) throws IOException {
+ buffer = new MMapBuffer(path, FileChannel.MapMode.READ_ONLY, ByteOrder.BIG_ENDIAN);
+ }
+
+ @Override
+ public RandomAccessReader get() {
+ return new MMapReader(buffer);
+ }
+
+ @Override
+ public void close() throws IOException {
+ buffer.close();
+ }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java b/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java
new file mode 100644
index 000000000..81b59a2e4
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java
@@ -0,0 +1,22 @@
+package herddb.index.jvector;
+
+import io.github.jbellis.jvector.disk.ReaderSupplier;
+import io.github.jbellis.jvector.disk.SimpleMappedReaderSupplier;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+public class ReaderSupplierFactory {
+ public static ReaderSupplier open(Path path) throws IOException {
+ try {
+ return new MMapReaderSupplier(path);
+ } catch (UnsatisfiedLinkError|NoClassDefFoundError e) {
+ if (Files.size(path) > Integer.MAX_VALUE) {
+ throw new RuntimeException("File sizes greater than 2GB are not supported on Windows--contributions welcome");
+ }
+
+ return new SimpleMappedReaderSupplier(path);
+ }
+ }
+}
diff --git a/herddb-core/src/main/java/herddb/model/Index.java b/herddb-core/src/main/java/herddb/model/Index.java
index c4d15183c..b8405d0e5 100644
--- a/herddb-core/src/main/java/herddb/model/Index.java
+++ b/herddb-core/src/main/java/herddb/model/Index.java
@@ -44,6 +44,7 @@ public class Index implements ColumnsList {
public static final String TYPE_HASH = "hash";
public static final String TYPE_BRIN = "brin";
+ public static final String TYPE_JVECTOR = "jvector";
private static final int PROPERTY_UNIQUE = 0x01;
@@ -234,7 +235,9 @@ public Index build() {
if (table == null || table.isEmpty()) {
throw new IllegalArgumentException("table is not defined");
}
- if (!TYPE_HASH.equals(type) && !TYPE_BRIN.equals(type)) {
+ if (!TYPE_HASH.equals(type)
+ && !TYPE_BRIN.equals(type)
+ && !TYPE_JVECTOR.equals(type)) {
throw new IllegalArgumentException("only index type " + TYPE_HASH + "," + TYPE_BRIN + " are supported");
}
if (columns.isEmpty()) {
diff --git a/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java b/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java
new file mode 100644
index 000000000..e92aa895e
--- /dev/null
+++ b/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java
@@ -0,0 +1,103 @@
+package herddb.index.jvector;
+
+import herddb.core.AbstractTableManager;
+import herddb.core.MemoryManager;
+import herddb.index.IndexOperation;
+import herddb.index.SecondaryIndexVectorSimilarityScan;
+import herddb.log.LogSequenceNumber;
+import herddb.mem.MemoryDataStorageManager;
+import herddb.model.ColumnTypes;
+import herddb.model.Index;
+import herddb.model.Record;
+import herddb.model.StatementEvaluationContext;
+import herddb.model.StatementExecutionException;
+import herddb.model.Table;
+import herddb.model.TableContext;
+import herddb.sql.SQLRecordKeyFunction;
+import herddb.utils.Bytes;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static herddb.model.Index.TYPE_JVECTOR;
+import static org.junit.Assert.assertEquals;
+
+public class JVectorIndexManagerTest {
+
+ private static final Logger LOGGER = Logger.getLogger(JVectorIndexManagerTest.class.getName());
+
+ @Test
+ public void basicBuildAndSearch() {
+ String column = "embeddings";
+ Table table = Table
+ .builder()
+ .name("table")
+ .column("key", ColumnTypes.INTEGER)
+ .column("embeddings", ColumnTypes.FLOATARRAY)
+ .primaryKey("key")
+ .build();
+ Index index = Index
+ .builder()
+ .onTable(table)
+ .type(TYPE_JVECTOR)
+ .column(column, ColumnTypes.FLOATARRAY)
+ .build();
+ AbstractTableManager abstractTableManager = null;
+ MemoryDataStorageManager memoryDataStorageManager = new MemoryDataStorageManager();
+ JVectorIndexManager indexManager = new JVectorIndexManager(index, abstractTableManager, memoryDataStorageManager, "xxxx", null, -1, 10000, 1000);
+
+ indexManager.start(LogSequenceNumber.START_OF_TIME);
+ Bytes vector1 = Bytes.from_float_array(new float[] {1, 2, 3, 4, 5});
+
+ Bytes vectorToSearch = vector1;
+
+ Map data = new HashMap<>();
+
+ for (int i = 0; i < 100; i++) {
+ Bytes pk = Bytes.from_int(i);
+ double angle = i * Math.PI / 100;
+ float sin = (float) Math.sin( angle);
+ float cos = (float) Math.cos( angle);
+ Bytes vector = Bytes.from_float_array(new float[] {0, 0, 0, sin, cos});
+ indexManager.recordInserted(pk, vector);
+
+ data.put(pk, vector);
+ }
+ int topK = 10;
+ SQLRecordKeyFunction keyFunction = new DummyConstantValueFunction(table, vectorToSearch);
+ IndexOperation indexOperation = new SecondaryIndexVectorSimilarityScan(index.name, column, topK, keyFunction);
+ Stream scanner = indexManager.scanner(indexOperation, null, null);
+ List collect = scanner.collect(Collectors.toList());
+ collect.forEach(k -> {
+ Bytes vector = data.get(k);
+ float compare = VectorSimilarityFunction.COSINE.compare(vector1.to_float_array(), vector.to_float_array());
+ LOGGER.log(Level.INFO, "Found record with key {0} and value {1} compare {2}", new Object[] {k.to_int(), vector.to_float_array(), compare});
+ });
+ assertEquals(10, collect.size());
+
+ }
+
+ private static class DummyConstantValueFunction extends SQLRecordKeyFunction {
+ private final Bytes vectorToSearch;
+
+ public DummyConstantValueFunction(Table table, Bytes vectorToSearch) {
+ super(Collections.emptyList(), Collections.emptyList(), table);
+ this.vectorToSearch = vectorToSearch;
+ }
+
+ @Override
+ public byte[] computeNewValue(Record previous,
+ StatementEvaluationContext context,
+ TableContext tableContext) throws StatementExecutionException {
+ return vectorToSearch.to_array();
+ }
+ }
+}
diff --git a/herddb-docker/pom.xml b/herddb-docker/pom.xml
index 3b10a1836..8f978aeff 100644
--- a/herddb-docker/pom.xml
+++ b/herddb-docker/pom.xml
@@ -78,7 +78,7 @@
com.google.cloud.tools
jib-maven-plugin
- 2.5.2
+ 3.4.0