Skip to content

Commit 1f4ea3f

Browse files
committed
fix: add VecType to codegen
1 parent 7f0e0ff commit 1f4ea3f

File tree

6 files changed

+61
-7
lines changed

6 files changed

+61
-7
lines changed

paimon-codegen/src/main/scala/org/apache/paimon/codegen/GenerateUtils.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,13 @@ object GenerateUtils {
127127
s"$sortUtil.compareBinary($leftTerm, $rightTerm)"
128128
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | DATE | TIME_WITHOUT_TIME_ZONE =>
129129
s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
130-
case ARRAY =>
131-
val at = t.asInstanceOf[ArrayType]
130+
case ARRAY | VECTOR =>
131+
val elementType = t.getTypeRoot match {
132+
case ARRAY => t.asInstanceOf[ArrayType].getElementType
133+
case VECTOR => t.asInstanceOf[VecType].getElementType
134+
}
132135
val compareFunc = newName("compareArray")
133-
val compareCode = generateArrayCompare(ctx, nullsIsLast = false, at, "a", "b")
136+
val compareCode = generateArrayCompare(ctx, nullsIsLast = false, elementType, "a", "b")
134137
val funcCode: String =
135138
s"""
136139
public int $compareFunc($ARRAY_DATA a, $ARRAY_DATA b) {
@@ -188,11 +191,10 @@ object GenerateUtils {
188191
def generateArrayCompare(
189192
ctx: CodeGeneratorContext,
190193
nullsIsLast: Boolean,
191-
arrayType: ArrayType,
194+
elementType: DataType,
192195
leftTerm: String,
193196
rightTerm: String): String = {
194197
val nullIsLastRet = if (nullsIsLast) 1 else -1
195-
val elementType = arrayType.getElementType
196198
val fieldA = newName("fieldA")
197199
val isNullA = newName("isNullA")
198200
val lengthA = newName("lengthA")
@@ -379,6 +381,7 @@ object GenerateUtils {
379381
case DOUBLE => className[JDouble]
380382
case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => className[Timestamp]
381383
case ARRAY => className[InternalArray]
384+
case VECTOR => className[InternalVec]
382385
case MULTISET | MAP => className[InternalMap]
383386
case ROW => className[InternalRow]
384387
case VARIANT => className[Variant]
@@ -417,6 +420,8 @@ object GenerateUtils {
417420
s"$rowTerm.getTimestamp($indexTerm, ${getPrecision(t)})"
418421
case ARRAY =>
419422
s"$rowTerm.getArray($indexTerm)"
423+
case VECTOR =>
424+
s"$rowTerm.getVec($indexTerm)"
420425
case MULTISET | MAP =>
421426
s"$rowTerm.getMap($indexTerm)"
422427
case ROW =>

paimon-codegen/src/main/scala/org/apache/paimon/codegen/ScalarOperatorGens.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@ object ScalarOperatorGens {
6060
}
6161
// array types
6262
else if (isArray(left.resultType) && canEqual) {
63-
generateArrayComparison(ctx, left, right, resultType)
63+
val elementType = left.resultType.asInstanceOf[ArrayType].getElementType
64+
generateArrayComparison(ctx, left, right, elementType, resultType)
65+
}
66+
// vector type
67+
else if (isVec(left.resultType) && canEqual) {
68+
val elementType = left.resultType.asInstanceOf[VecType].getElementType
69+
generateArrayComparison(ctx, left, right, elementType, resultType)
6470
}
6571
// map types
6672
else if (isMap(left.resultType) && canEqual) {
@@ -196,6 +202,7 @@ object ScalarOperatorGens {
196202
ctx: CodeGeneratorContext,
197203
left: GeneratedExpression,
198204
right: GeneratedExpression,
205+
elementType: DataType,
199206
resultType: DataType): GeneratedExpression = {
200207
generateCallWithStmtIfArgsNotNull(ctx, resultType, Seq(left, right)) {
201208
args =>
@@ -204,7 +211,6 @@ object ScalarOperatorGens {
204211

205212
val resultTerm = newName("compareResult")
206213

207-
val elementType = left.resultType.asInstanceOf[ArrayType].getElementType
208214
val elementCls = primitiveTypeTermForType(elementType)
209215
val elementDefault = primitiveDefaultValue(elementType)
210216

@@ -225,6 +231,7 @@ object ScalarOperatorGens {
225231
rightElementExpr,
226232
new BooleanType(elementType.isNullable))
227233

234+
// TODO: With BinaryVec available, we can use it here.
228235
val stmt =
229236
s"""
230237
|boolean $resultTerm;

paimon-codegen/src/test/java/org/apache/paimon/codegen/EqualiserCodeGeneratorTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.paimon.data.serializer.InternalArraySerializer;
3131
import org.apache.paimon.data.serializer.InternalMapSerializer;
3232
import org.apache.paimon.data.serializer.InternalRowSerializer;
33+
import org.apache.paimon.data.serializer.InternalVecSerializer;
3334
import org.apache.paimon.data.serializer.Serializer;
3435
import org.apache.paimon.data.variant.GenericVariant;
3536
import org.apache.paimon.types.DataType;
@@ -133,6 +134,16 @@ public class EqualiserCodeGeneratorTest {
133134
castFromString("[1,2,3]", DataTypes.ARRAY(new VarCharType())),
134135
castFromString("[4,5,6]", DataTypes.ARRAY(new VarCharType()))),
135136
new InternalArraySerializer(DataTypes.VARCHAR(1))));
137+
TEST_DATA.put(
138+
DataTypeRoot.VECTOR,
139+
new GeneratedData(
140+
DataTypes.VECTOR(3, DataTypes.FLOAT()),
141+
Pair.of(
142+
castFromString(
143+
"[1.1,2.2,3.3]", DataTypes.VECTOR(3, DataTypes.FLOAT())),
144+
castFromString(
145+
"[4.4,5.5,6.6]", DataTypes.VECTOR(3, DataTypes.FLOAT()))),
146+
new InternalVecSerializer(DataTypes.FLOAT(), 3)));
136147
TEST_DATA.put(
137148
DataTypeRoot.MULTISET,
138149
new GeneratedData(

paimon-common/src/main/java/org/apache/paimon/utils/TypeCheckUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import static org.apache.paimon.types.DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE;
3434
import static org.apache.paimon.types.DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
3535
import static org.apache.paimon.types.DataTypeRoot.VARIANT;
36+
import static org.apache.paimon.types.DataTypeRoot.VECTOR;
3637

3738
/** Utils for type. */
3839
public class TypeCheckUtils {
@@ -85,6 +86,10 @@ public static boolean isArray(DataType type) {
8586
return type.getTypeRoot() == ARRAY;
8687
}
8788

89+
public static boolean isVec(DataType type) {
90+
return type.getTypeRoot() == VECTOR;
91+
}
92+
8893
public static boolean isMap(DataType type) {
8994
return type.getTypeRoot() == MAP;
9095
}
@@ -110,6 +115,7 @@ public static boolean isComparable(DataType type) {
110115
&& !isMultiset(type)
111116
&& !isRow(type)
112117
&& !isArray(type)
118+
&& !isVec(type)
113119
&& !isVariant(type)
114120
&& !isBlob(type);
115121
}
@@ -120,6 +126,7 @@ public static boolean isMutable(DataType type) {
120126
case CHAR:
121127
case VARCHAR: // the internal representation of String is BinaryString which is mutable
122128
case ARRAY:
129+
case VECTOR:
123130
case MULTISET:
124131
case MAP:
125132
case ROW:

paimon-common/src/main/java/org/apache/paimon/utils/TypeUtils.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,26 @@
1818

1919
package org.apache.paimon.utils;
2020

21+
import org.apache.paimon.data.ArrayBasedVec;
2122
import org.apache.paimon.data.BinaryString;
2223
import org.apache.paimon.data.Decimal;
2324
import org.apache.paimon.data.GenericArray;
2425
import org.apache.paimon.data.GenericMap;
2526
import org.apache.paimon.data.GenericRow;
27+
import org.apache.paimon.data.InternalArray;
2628
import org.apache.paimon.types.ArrayType;
2729
import org.apache.paimon.types.DataField;
2830
import org.apache.paimon.types.DataType;
2931
import org.apache.paimon.types.DataTypeChecks;
3032
import org.apache.paimon.types.DataTypeRoot;
33+
import org.apache.paimon.types.DataTypes;
3134
import org.apache.paimon.types.DecimalType;
3235
import org.apache.paimon.types.LocalZonedTimestampType;
3336
import org.apache.paimon.types.MapType;
3437
import org.apache.paimon.types.RowType;
3538
import org.apache.paimon.types.TimestampType;
3639
import org.apache.paimon.types.VarCharType;
40+
import org.apache.paimon.types.VecType;
3741

3842
import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
3943
import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.databind.JsonNode;
@@ -212,6 +216,16 @@ public static Object castFromStringInternal(String s, DataType type, boolean isC
212216
throw new RuntimeException(
213217
String.format("Failed to parse Json String %s", s), e);
214218
}
219+
case VECTOR:
220+
VecType vecType = (VecType) type;
221+
DataType vecElementType = vecType.getElementType();
222+
Object vecBaseArr =
223+
castFromStringInternal(s, DataTypes.ARRAY(vecElementType), isCdcValue);
224+
if (vecBaseArr instanceof InternalArray) {
225+
return ArrayBasedVec.from((InternalArray) vecBaseArr);
226+
} else {
227+
throw new RuntimeException("Failed to make array during building a vector");
228+
}
215229
case MAP:
216230
MapType mapType = (MapType) type;
217231
DataType keyType = mapType.getKeyType();
@@ -333,6 +347,7 @@ public static boolean isInteroperable(DataType t1, DataType t2) {
333347

334348
switch (t1.getTypeRoot()) {
335349
case ARRAY:
350+
case VECTOR:
336351
case MAP:
337352
case MULTISET:
338353
case ROW:

paimon-core/src/test/java/org/apache/paimon/codegen/CodeGenUtilsTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import static org.apache.paimon.types.DataTypes.DOUBLE;
3333
import static org.apache.paimon.types.DataTypes.INT;
3434
import static org.apache.paimon.types.DataTypes.STRING;
35+
import static org.apache.paimon.types.DataTypes.VECTOR;
3536
import static org.assertj.core.api.Assertions.assertThat;
3637

3738
class CodeGenUtilsTest {
@@ -74,6 +75,14 @@ public void testRecordComparatorCodegenCache() {
7475
() -> newRecordComparator(Arrays.asList(STRING(), INT()), new int[] {0, 1}, true));
7576
}
7677

78+
@Test
79+
public void testRecordComparatorCodegenCacheWithVec() {
80+
assertClassEquals(
81+
() ->
82+
newRecordComparator(
83+
Arrays.asList(STRING(), VECTOR(3, INT())), new int[] {0, 1}, true));
84+
}
85+
7786
@Test
7887
public void testRecordComparatorCodegenCacheMiss() {
7988
assertClassNotEquals(

0 commit comments

Comments
 (0)