diff --git a/herddb-core/pom.xml b/herddb-core/pom.xml
index d9a04ecdd..9aa370ab8 100644
--- a/herddb-core/pom.xml
+++ b/herddb-core/pom.xml
@@ -33,6 +33,16 @@ ${maven.build.timestamp} + + io.github.jbellis + jvector + 1.0.2 + + + com.indeed + util-mmap + 1.0.52-3042601 + ${project.groupId} herddb-utils
@@ -189,6 +199,12 @@ + + org.apache.openjpa + openjpa-kernel + 3.1.2 + test +
diff --git a/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java b/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java
new file mode 100644
index 000000000..69b12e1b2
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/SecondaryIndexVectorSimilarityScan.java
@@ -0,0 +1,59 @@
+/*
+ Licensed to Diennea S.r.l. under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. Diennea S.r.l. licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+ */
+
+package herddb.index;
+
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import herddb.sql.SQLRecordKeyFunction;
+
+/**
+ * Top-K similarity lookup on a vector secondary index.
+ *
+ * The descriptor only carries the parameters of the search: the name of the
+ * index, the indexed column, how many results are wanted and the function that
+ * computes the serialized query vector at execution time.
+ *
+ * @author enrico.olivelli
+ */
+@SuppressFBWarnings("EI_EXPOSE_REP2")
+public class SecondaryIndexVectorSimilarityScan implements IndexOperation {
+
+    /** Name of the index to query. */
+    public final String indexName;
+    /** Name of the indexed column. */
+    public final String column;
+    /** Number of results requested from the search. */
+    public final int topK;
+    /** Computes the serialized query vector. */
+    public final SQLRecordKeyFunction value;
+
+    public SecondaryIndexVectorSimilarityScan(String indexName, String column, int topK, SQLRecordKeyFunction value) {
+        this.indexName = indexName;
+        this.column = column;
+        this.topK = topK;
+        this.value = value;
+    }
+
+    @Override
+    public String getIndexName() {
+        return indexName;
+    }
+
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java b/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java
new file mode 100644
index 000000000..120f800fd
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/JVectorIndexManager.java
@@ -0,0 +1,290 @@
+package herddb.index.jvector;
+
+import herddb.codec.RecordSerializer;
+import herddb.core.AbstractIndexManager;
+import herddb.core.AbstractTableManager;
+import herddb.core.PostCheckpointAction;
+import herddb.index.IndexOperation;
+import herddb.index.SecondaryIndexVectorSimilarityScan;
+import herddb.log.CommitLog;
+import herddb.log.LogSequenceNumber;
+import herddb.model.Index;
+import herddb.model.StatementEvaluationContext;
+import herddb.model.StatementExecutionException;
+import herddb.model.Table;
+import herddb.model.TableContext;
+import herddb.storage.DataStorageManager;
+import herddb.storage.DataStorageManagerException;
+import herddb.storage.IndexStatus;
+import herddb.utils.Bytes;
+import herddb.utils.DataAccessor;
+import herddb.utils.ExtendedDataOutputStream;
+import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.NeighborSimilarity;
+import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.SearchResult;
+import io.github.jbellis.jvector.vector.VectorEncoding;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Stream;
+
+/**
+ * Index based on JVector, only for arrays of floats.
+ *
+ * The graph is kept on the heap: inserts are appended to it, while updates,
+ * deletes and truncate are rejected. Checkpoints do not persist the graph yet,
+ * so booting from a checkpoint reports that a rebuild of the index is needed.
+ */
+public class JVectorIndexManager extends AbstractIndexManager {
+
+    /** Dimension reported to JVector. TODO: derive it from the indexed vectors. */
+    private static final int DIMENSIONS = 5;
+
+    // graph construction parameters handed to GraphIndexBuilder
+    private static final int M = 8;
+    private static final int BEAM_WIDTH = 60;
+    private static final float NEIGHBOR_OVERFLOW = 1.2f;
+    private static final float ALPHA = 1.4f;
+
+    private static final VectorSimilarityFunction SIMILARITY = VectorSimilarityFunction.COSINE;
+    private static final Logger LOGGER = Logger.getLogger(JVectorIndexManager.class.getName());
+
+    private GraphIndexBuilder<float[]> currentGraphBuilder;
+
+    private GraphSearcher<float[]> graphSearcher;
+
+    LogSequenceNumber bootSequenceNumber;
+
+    private final AtomicLong newPageId = new AtomicLong(1);
+
+    private final RandomAccessVectorValuesImpl nodeToVectorMapping = new RandomAccessVectorValuesImpl();
+
+    public JVectorIndexManager(Index index, AbstractTableManager tableManager,
+                               DataStorageManager dataStorageManager, String tableSpaceUUID, CommitLog log,
+                               long createdInTransaction, int writeLockTimeout, int readLockTimeout) throws DataStorageManagerException {
+        super(index, tableManager, dataStorageManager, tableSpaceUUID, log, createdInTransaction, writeLockTimeout, readLockTimeout);
+    }
+
+    /**
+     * Boots the index.
+     *
+     * @param sequenceNumber position in the log the index is booted at
+     * @return true if the index is ready, false if it has to be rebuilt
+     */
+    @Override
+    protected boolean doStart(LogSequenceNumber sequenceNumber) throws DataStorageManagerException {
+        LOGGER.log(Level.FINE, " start JVector index {0} uuid {1}", new Object[]{index.name, index.uuid});
+
+        dataStorageManager.initIndex(tableSpaceUUID, index.uuid);
+
+        bootSequenceNumber = sequenceNumber;
+
+        if (LogSequenceNumber.START_OF_TIME.equals(sequenceNumber)) {
+            /* Empty index (booting from the start) */
+            createNewBuilder();
+            LOGGER.log(Level.FINE, "loaded empty index {0}", new Object[]{index.name});
+            return true;
+        }
+
+        try {
+            IndexStatus status = dataStorageManager.getIndexStatus(tableSpaceUUID, index.uuid, sequenceNumber);
+            newPageId.set(status.newPageId);
+        } catch (DataStorageManagerException e) {
+            LOGGER.log(Level.SEVERE, "cannot load index {0} due to {1}, it will be rebuilt", new Object[]{index.name, e});
+            return false;
+        }
+        // checkpoint() does not store the graph yet: returning true here would leave
+        // currentGraphBuilder and graphSearcher null and every later operation would fail
+        LOGGER.log(Level.INFO, "index {0} has no persisted graph, it will be rebuilt", new Object[]{index.name});
+        return false;
+    }
+
+    /** Discards the in-memory state and creates an empty graph with its searcher. */
+    private void createNewBuilder() {
+        // node ids restart from zero together with the graph
+        nodeToVectorMapping.clear();
+        currentGraphBuilder = new GraphIndexBuilder<>(nodeToVectorMapping, VectorEncoding.FLOAT32,
+                SIMILARITY, M, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA);
+        graphSearcher = new GraphSearcher.Builder<>(currentGraphBuilder.getGraph().getView())
+                .withConcurrentUpdates()
+                .build();
+    }
+
+    @Override
+    public void rebuild() throws DataStorageManagerException {
+        long _start = System.currentTimeMillis();
+        LOGGER.log(Level.FINE, "building index {0}", index.name);
+        dataStorageManager.initIndex(tableSpaceUUID, index.uuid);
+        createNewBuilder();
+        Table table = tableManager.getTable();
+        AtomicLong count = new AtomicLong();
+        tableManager.scanForIndexRebuild(r -> {
+            DataAccessor values = r.getDataAccessor(table);
+            Bytes key = RecordSerializer.serializeIndexKey(values, table, table.primaryKey);
+            Bytes indexKey = RecordSerializer.serializeIndexKey(values, index, index.columnNames);
+            recordInserted(key, indexKey);
+            count.incrementAndGet();
+        });
+        long _stop = System.currentTimeMillis();
+        if (count.get() > 0) {
+            LOGGER.log(Level.INFO, "building index {0} took {1}, scanned {2} records", new Object[]{index.name, (_stop - _start) + " ms", count});
+        }
+    }
+
+    /**
+     * Serializes the current graph in memory and logs its size.
+     * Nothing is written to the DataStorageManager yet (see the TODO below).
+     *
+     * @return always an empty list
+     */
+    @Override
+    public List<PostCheckpointAction> checkpoint(LogSequenceNumber sequenceNumber, boolean pin) throws DataStorageManagerException {
+        OnHeapGraphIndex<float[]> completed = currentGraphBuilder.build();
+        byte[] storedDisk;
+        try (ByteArrayOutputStream flush = new ByteArrayOutputStream();
+             DataOutputStream dataOutputStream = new ExtendedDataOutputStream(flush)) {
+            OnDiskGraphIndex.write(completed, nodeToVectorMapping, dataOutputStream);
+            dataOutputStream.flush();
+            storedDisk = flush.toByteArray();
+        } catch (IOException err) {
+            throw new DataStorageManagerException(err);
+        }
+
+        // TODO: store the serialized graph and return the related post checkpoint actions
+
+        LOGGER.log(Level.INFO, "Serialized index takes {0} bytes", storedDisk.length);
+
+        return Collections.emptyList();
+    }
+
+    @Override
+    public void unpinCheckpoint(LogSequenceNumber sequenceNumber) throws DataStorageManagerException {
+        // nothing to do: checkpoint() does not pin anything
+    }
+
+    /**
+     * Runs a top-K similarity search.
+     *
+     * @param operation must be a {@link SecondaryIndexVectorSimilarityScan}
+     * @return the primary keys of the nodes returned by the graph search, in the same order
+     * @throws UnsupportedOperationException if the operation is of a different kind
+     */
+    @Override
+    protected Stream<Bytes> scanner(IndexOperation operation, StatementEvaluationContext context, TableContext tableContext) throws StatementExecutionException {
+        if (!(operation instanceof SecondaryIndexVectorSimilarityScan)) {
+            throw new UnsupportedOperationException("unsupported index access type " + operation);
+        }
+        SecondaryIndexVectorSimilarityScan similarityScan = (SecondaryIndexVectorSimilarityScan) operation;
+        byte[] bytes = similarityScan.value.computeNewValue(null, context, tableContext);
+        float[] targetVector = Bytes.to_float_array(bytes, 0, bytes.length);
+
+        NeighborSimilarity.ExactScoreFunction scoreFunction =
+                nodeId -> SIMILARITY.compare(targetVector, nodeToVectorMapping.vectorValue(nodeId));
+        SearchResult search = graphSearcher.search(scoreFunction, null, similarityScan.topK, null);
+
+        List<Bytes> result = new ArrayList<>();
+        for (SearchResult.NodeScore node : search.getNodes()) {
+            result.add(nodeToVectorMapping.nodeIdToKey.get(node.node));
+        }
+        return result.stream();
+    }
+
+    @Override
+    public void recordUpdated(Bytes key, Bytes indexKeyRemoved, Bytes indexKeyAdded) throws DataStorageManagerException {
+        throw new DataStorageManagerException("Update not supported");
+    }
+
+    @Override
+    public void recordInserted(Bytes key, Bytes indexKey) throws DataStorageManagerException {
+        int nodeId = nodeToVectorMapping.registerRecord(key, indexKey);
+        if (LOGGER.isLoggable(Level.FINE)) {
+            // decoding and formatting the vector is paid only when the message is really logged
+            LOGGER.log(Level.FINE, "Adding {0} as node id {1}", new Object[]{Arrays.toString(indexKey.to_float_array()), nodeId});
+        }
+        currentGraphBuilder.addGraphNode(nodeId, nodeToVectorMapping);
+    }
+
+    @Override
+    public void recordDeleted(Bytes key, Bytes indexKey) throws DataStorageManagerException {
+        throw new DataStorageManagerException("Delete not supported");
+    }
+
+    @Override
+    public void truncate() throws DataStorageManagerException {
+        throw new DataStorageManagerException("TRUNCATE not supported");
+    }
+
+    @Override
+    public boolean valueAlreadyMapped(Bytes key, Bytes primaryKey) throws DataStorageManagerException {
+        // this method is for UNIQUE indexes
+        return false;
+    }
+
+    /**
+     * Maps node ids to the indexed vectors and to the primary keys of the records.
+     */
+    private static class RandomAccessVectorValuesImpl implements RandomAccessVectorValues<float[]> {
+        // ids are dense and start from zero so that every id below size() has a vector
+        // NOTE(review): JVector presumably expects such ordinals when writing the graph, confirm
+        private final AtomicInteger nextNodeId = new AtomicInteger();
+
+        final ConcurrentHashMap<Integer, Bytes> nodeIdToVector = new ConcurrentHashMap<>();
+        final ConcurrentHashMap<Integer, Bytes> nodeIdToKey = new ConcurrentHashMap<>();
+
+        /** Stores a record and returns the node id assigned to it. */
+        public int registerRecord(Bytes primaryKey, Bytes vectorValue) {
+            Integer newId = nextNodeId.getAndIncrement();
+            nodeIdToVector.put(newId, vectorValue);
+            nodeIdToKey.put(newId, primaryKey);
+            return newId;
+        }
+
+        /** Forgets every record and restarts the numbering of the nodes. */
+        void clear() {
+            nodeIdToVector.clear();
+            nodeIdToKey.clear();
+            nextNodeId.set(0);
+        }
+
+        @Override
+        public int size() {
+            return nodeIdToVector.size();
+        }
+
+        @Override
+        public int dimension() {
+            return DIMENSIONS;
+        }
+
+        @Override
+        public float[] vectorValue(int i) {
+            Bytes bytes = nodeIdToVector.get(i);
+            return bytes != null ? bytes.to_float_array() : null;
+        }
+
+        @Override
+        public boolean isValueShared() {
+            return true;
+        }
+
+        @Override
+        public RandomAccessVectorValues<float[]> copy() {
+            return this;
+        }
+    }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java b/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java
new file mode 100644
index 000000000..1e7688df7
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/MMapReader.java
@@ -0,0 +1,80 @@
+package herddb.index.jvector;
+
+import com.indeed.util.mmap.MMapBuffer;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * RandomAccessReader over an MMapBuffer.
+ * Every instance tracks its own position and owns its scratch arrays, the
+ * buffer itself is only borrowed and is never closed here.
+ */
+public class MMapReader implements RandomAccessReader {
+    private final MMapBuffer buffer;
+    private long position;
+    private byte[] floatsScratch = new byte[0];
+    private byte[] intsScratch = new byte[0];
+
+    MMapReader(MMapBuffer buffer) {
+        this.buffer = buffer;
+    }
+
+    @Override
+    public void seek(long offset) {
+        position = offset;
+    }
+
+    public int readInt() {
+        try {
+            return buffer.memory().getInt(position);
+        } finally {
+            position += Integer.BYTES;
+        }
+    }
+
+    public void readFully(byte[] bytes) {
+        read(bytes, 0, bytes.length);
+    }
+
+    private void read(byte[] bytes, int offset, int count) {
+        try {
+            buffer.memory().getBytes(position, bytes, offset, count);
+        } finally {
+            position += count;
+        }
+    }
+
+    @Override
+    public void readFully(float[] floats) {
+        int bytesToRead = floats.length * Float.BYTES;
+        if (floatsScratch.length < bytesToRead) {
+            floatsScratch = new byte[bytesToRead];
+        }
+        read(floatsScratch, 0, bytesToRead);
+        ByteBuffer byteBuffer = ByteBuffer.wrap(floatsScratch).order(ByteOrder.BIG_ENDIAN);
+        byteBuffer.asFloatBuffer().get(floats);
+    }
+
+    /**
+     * Reads {@code count} ints from the current position into {@code ints},
+     * starting at index {@code offset} of the array.
+     */
+    @Override
+    public void read(int[] ints, int offset, int count) {
+        // offset is a position in the destination array, it does not reduce the amount of data to read
+        int bytesToRead = count * Integer.BYTES;
+        if (intsScratch.length < bytesToRead) {
+            intsScratch = new byte[bytesToRead];
+        }
+        read(intsScratch, 0, bytesToRead);
+        ByteBuffer byteBuffer = ByteBuffer.wrap(intsScratch).order(ByteOrder.BIG_ENDIAN);
+        byteBuffer.asIntBuffer().get(ints, offset, count);
+    }
+
+    @Override
+    public void close() {
+        // don't close buffer, let the Supplier handle that
+    }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java b/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java
new file mode 100644
index 000000000..5ed7b623a
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/MMapReaderSupplier.java
@@ -0,0 +1,32 @@
+package herddb.index.jvector;
+
+import com.indeed.util.mmap.MMapBuffer;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+import io.github.jbellis.jvector.disk.ReaderSupplier;
+
+import java.io.IOException;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.nio.file.Path;
+
+/**
+ * Opens a file as a read-only MMapBuffer and hands out readers sharing that buffer.
+ * Closing the supplier closes the buffer.
+ */
+public class MMapReaderSupplier implements ReaderSupplier {
+    private final MMapBuffer buffer;
+
+    public MMapReaderSupplier(Path path) throws IOException {
+        buffer = new MMapBuffer(path, FileChannel.MapMode.READ_ONLY, ByteOrder.BIG_ENDIAN);
+    }
+
+    @Override
+    public RandomAccessReader get() {
+        return new MMapReader(buffer);
+    }
+
+    @Override
+    public void close() throws IOException {
+        buffer.close();
+    }
+}
diff --git a/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java b/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java
new file mode 100644
index 000000000..81b59a2e4
--- /dev/null
+++ b/herddb-core/src/main/java/herddb/index/jvector/ReaderSupplierFactory.java
@@ -0,0 +1,33 @@
+package herddb.index.jvector;
+
+import io.github.jbellis.jvector.disk.ReaderSupplier;
+import io.github.jbellis.jvector.disk.SimpleMappedReaderSupplier;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+/**
+ * Chooses the ReaderSupplier implementation to use for a file.
+ */
+public class ReaderSupplierFactory {
+
+    /**
+     * Opens the file with MMapReaderSupplier, falling back to SimpleMappedReaderSupplier
+     * when creating it fails with a linkage error.
+     *
+     * @throws RuntimeException if the fallback is needed and the file is larger than Integer.MAX_VALUE bytes
+     */
+    public static ReaderSupplier open(Path path) throws IOException {
+        try {
+            return new MMapReaderSupplier(path);
+        } catch (UnsatisfiedLinkError | NoClassDefFoundError e) {
+            if (Files.size(path) > Integer.MAX_VALUE) {
+                // keep the linkage error as cause, it explains why the fallback was attempted
+                throw new RuntimeException("File sizes greater than 2GB are not supported on Windows--contributions welcome", e);
+            }
+
+            return new SimpleMappedReaderSupplier(path);
+        }
+    }
+}
diff --git a/herddb-core/src/main/java/herddb/model/Index.java b/herddb-core/src/main/java/herddb/model/Index.java
index c4d15183c..b8405d0e5 100644
--- a/herddb-core/src/main/java/herddb/model/Index.java
+++ b/herddb-core/src/main/java/herddb/model/Index.java
@@ -44,6 +44,7 @@ public class Index implements ColumnsList {
 
     public static final String TYPE_HASH = "hash";
     public static final String TYPE_BRIN = "brin";
+    public static final String TYPE_JVECTOR = "jvector";
 
     private static final int PROPERTY_UNIQUE = 0x01;
 
@@ -234,7 +235,9 @@ public Index build() {
             if (table == null || table.isEmpty()) {
                 throw new IllegalArgumentException("table is not defined");
             }
-            if (!TYPE_HASH.equals(type) && !TYPE_BRIN.equals(type)) {
-                throw new IllegalArgumentException("only index type " + TYPE_HASH + "," + TYPE_BRIN + " are supported");
+            if (!TYPE_HASH.equals(type)
+                    && !TYPE_BRIN.equals(type)
+                    && !TYPE_JVECTOR.equals(type)) {
+                throw new IllegalArgumentException("only index type " + TYPE_HASH + "," + TYPE_BRIN + "," + TYPE_JVECTOR + " are supported");
             }
             if (columns.isEmpty()) {
diff --git a/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java b/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java
new file mode 100644
index 000000000..e92aa895e
--- /dev/null
+++ b/herddb-core/src/test/java/herddb/index/jvector/JVectorIndexManagerTest.java
@@ -0,0 +1,104 @@
+package herddb.index.jvector;
+
+import herddb.core.AbstractTableManager;
+import herddb.core.MemoryManager;
+import herddb.index.IndexOperation;
+import herddb.index.SecondaryIndexVectorSimilarityScan;
+import herddb.log.LogSequenceNumber;
+import herddb.mem.MemoryDataStorageManager;
+import herddb.model.ColumnTypes;
+import herddb.model.Index;
+import herddb.model.Record;
+import herddb.model.StatementEvaluationContext;
+import herddb.model.StatementExecutionException;
+import herddb.model.Table;
+import herddb.model.TableContext;
+import herddb.sql.SQLRecordKeyFunction;
+import herddb.utils.Bytes;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static herddb.model.Index.TYPE_JVECTOR;
+import static org.junit.Assert.assertEquals;
+
+public class JVectorIndexManagerTest {
+
+    private static final Logger LOGGER = Logger.getLogger(JVectorIndexManagerTest.class.getName());
+
+    @Test
+    public void basicBuildAndSearch() {
+        String column = "embeddings";
+        Table table = Table
+                .builder()
+                .name("table")
+                .column("key", ColumnTypes.INTEGER)
+                .column("embeddings", ColumnTypes.FLOATARRAY)
+                .primaryKey("key")
+                .build();
+        Index index = Index
+                .builder()
+                .onTable(table)
+                .type(TYPE_JVECTOR)
+                .column(column, ColumnTypes.FLOATARRAY)
+                .build();
+        AbstractTableManager abstractTableManager = null;
+        MemoryDataStorageManager memoryDataStorageManager = new MemoryDataStorageManager();
+        JVectorIndexManager indexManager = new JVectorIndexManager(index, abstractTableManager, memoryDataStorageManager, "xxxx", null, -1, 10000, 1000);
+
+        indexManager.start(LogSequenceNumber.START_OF_TIME);
+        Bytes vector1 = Bytes.from_float_array(new float[] {1, 2, 3, 4, 5});
+
+        Bytes vectorToSearch = vector1;
+
+        Map<Bytes, Bytes> data = new HashMap<>();
+
+        for (int i = 0; i < 100; i++) {
+            Bytes pk = Bytes.from_int(i);
+            double angle = i * Math.PI / 100;
+            float sin = (float) Math.sin(angle);
+            float cos = (float) Math.cos(angle);
+            Bytes vector = Bytes.from_float_array(new float[] {0, 0, 0, sin, cos});
+            indexManager.recordInserted(pk, vector);
+
+            data.put(pk, vector);
+        }
+        int topK = 10;
+        SQLRecordKeyFunction keyFunction = new DummyConstantValueFunction(table, vectorToSearch);
+        IndexOperation indexOperation = new SecondaryIndexVectorSimilarityScan(index.name, column, topK, keyFunction);
+        Stream<Bytes> scanner = indexManager.scanner(indexOperation, null, null);
+        List<Bytes> collect = scanner.collect(Collectors.toList());
+        collect.forEach(k -> {
+            Bytes vector = data.get(k);
+            float compare = VectorSimilarityFunction.COSINE.compare(vector1.to_float_array(), vector.to_float_array());
+            LOGGER.log(Level.INFO, "Found record with key {0} and value {1} compare {2}", new Object[] {k.to_int(), Arrays.toString(vector.to_float_array()), compare});
+        });
+        assertEquals(topK, collect.size());
+
+    }
+
+    private static class DummyConstantValueFunction extends SQLRecordKeyFunction {
+        private final Bytes vectorToSearch;
+
+        public DummyConstantValueFunction(Table table, Bytes vectorToSearch) {
+            super(Collections.emptyList(), Collections.emptyList(), table);
+            this.vectorToSearch = vectorToSearch;
+        }
+
+        @Override
+        public byte[] computeNewValue(Record previous,
+                                      StatementEvaluationContext context,
+                                      TableContext tableContext) throws StatementExecutionException {
+            return vectorToSearch.to_array();
+        }
+    }
+}
diff --git a/herddb-docker/pom.xml b/herddb-docker/pom.xml
index 3b10a1836..8f978aeff 100644
--- a/herddb-docker/pom.xml
+++ b/herddb-docker/pom.xml
@@ -78,7 +78,7 @@ com.google.cloud.tools jib-maven-plugin - 2.5.2 + 3.4.0