Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public abstract class AbstractThriftReader extends DorisReader {
private int readCount = 0;

private final Boolean datetimeJava8ApiEnabled;
private final Boolean enableArrayTypeInference;

protected AbstractThriftReader(DorisReaderPartition partition) throws Exception {
super(partition);
Expand Down Expand Up @@ -112,6 +113,7 @@ protected AbstractThriftReader(DorisReaderPartition partition) throws Exception
this.asyncThread = null;
}
this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
this.enableArrayTypeInference = config.getValue(DorisOptions.DORIS_READ_ARRAY_TYPE_INFERENCE);
}

private void runAsync() throws DorisException, InterruptedException {
Expand All @@ -128,7 +130,7 @@ private void runAsync() throws DorisException, InterruptedException {
});
endOfStream.set(nextResult.isEos());
if (!endOfStream.get()) {
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled, enableArrayTypeInference);
offset += rowBatch.getReadRowCount();
rowBatch.close();
rowBatchQueue.put(rowBatch);
Expand Down Expand Up @@ -182,7 +184,7 @@ public boolean hasNext() throws DorisException {
});
endOfStream.set(nextResult.isEos());
if (!endOfStream.get()) {
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled, enableArrayTypeInference);
}
}
hasNext = !endOfStream.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class DorisFlightSqlReader extends DorisReader {
private AdbcConnection connection;
private final ArrowReader arrowReader;
private final Boolean datetimeJava8ApiEnabled;
private final Boolean enableArrayTypeInference;

public DorisFlightSqlReader(DorisReaderPartition partition) throws Exception {
super(partition);
Expand All @@ -85,6 +86,7 @@ public DorisFlightSqlReader(DorisReaderPartition partition) throws Exception {
this.schema = processDorisSchema(partition);
this.arrowReader = executeQuery();
this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
this.enableArrayTypeInference = partition.getConfig().getValue(DorisOptions.DORIS_READ_ARRAY_TYPE_INFERENCE);
}

@Override
Expand All @@ -96,7 +98,7 @@ public boolean hasNext() throws DorisException {
throw new DorisException(e);
}
if (!endOfStream.get()) {
rowBatch = new RowBatch(arrowReader, schema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(arrowReader, schema, datetimeJava8ApiEnabled, enableArrayTypeInference);
}
}
return !endOfStream.get();
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ public class DorisOptions {

public static final ConfigOption<Boolean> DORIS_READ_BITMAP_TO_BASE64 = ConfigOptions.name("doris.read.bitmap-to-base64").booleanType().defaultValue(false).withDescription("");

/**
* Enable Arrow type inference for ARRAY elements.
* When enabled, the connector will infer precise element types from Arrow schema instead of defaulting to StringType.
* Default: false (backward compatible, maintains existing behavior)
*/
public static final ConfigOption<Boolean> DORIS_READ_ARRAY_TYPE_INFERENCE = ConfigOptions.name("doris.read.array.type.inference").booleanType().defaultValue(false).withDescription("Enable Arrow type inference for ARRAY elements. When enabled, precise element types are inferred from Arrow schema instead of defaulting to StringType.");

public static final ConfigOption<Integer> DORIS_SINK_NET_BUFFER_SIZE = ConfigOptions.name("doris.sink.net.buffer.size").intType().defaultValue(1024 * 1024).withDescription("");


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}
import java.util
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter}
import scala.collection.mutable

object RowConvertors {
Expand Down Expand Up @@ -120,19 +120,135 @@ object RowConvertors {
}

/**
 * Converts a single raw value read from Doris into Spark's internal (Catalyst)
 * representation for the given [[DataType]].
 *
 * @param v                       raw value (JDK types, JSON strings, or already-internal values)
 * @param dataType                target Catalyst type declared by the schema
 * @param datetimeJava8ApiEnabled when true, date/time values arrive as java.time types
 *                                (Instant/LocalDate) instead of java.sql types
 * @return a value suitable for storing in an InternalRow
 * @throws Exception if the Spark type is not supported by this converter
 */
def convertValue(v: Any, dataType: DataType, datetimeJava8ApiEnabled: Boolean): Any = {
  dataType match {
    case StringType => UTF8String.fromString(v.asInstanceOf[String])
    case TimestampType if datetimeJava8ApiEnabled => DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
    case TimestampType => DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
    case DateType if datetimeJava8ApiEnabled => v.asInstanceOf[LocalDate].toEpochDay.toInt
    case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
    case _: MapType =>
      // Doris MAP values arrive as a Java map of string -> string.
      val map = v.asInstanceOf[java.util.Map[String, String]].asScala
      val keys = map.keys.toArray.map(UTF8String.fromString)
      val values = map.values.toArray.map(UTF8String.fromString)
      ArrayBasedMapData(keys, values)
    case at: ArrayType =>
      // ARRAY values reach this converter in three shapes:
      // 1. ArrayData  - already internal (e.g. DataFrame -> RDD round trip); pass through.
      // 2. String     - JSON produced by RowBatch, e.g. "[\"Alice\",\"Bob\"]" or "[1,2,3]".
      // 3. List       - direct conversion path.
      v match {
        case arrayData: ArrayData =>
          arrayData
        case s: String =>
          val parsed = try {
            MAPPER.readValue(s, classOf[java.util.List[Any]])
          } catch {
            // Deliberate best-effort: malformed JSON degrades to an empty array
            // rather than failing the whole read.
            case _: Exception => new java.util.ArrayList[Any]()
          }
          convertListToArrayData(parsed, at, datetimeJava8ApiEnabled)
        case list: java.util.List[Any @unchecked] =>
          // @unchecked: the element type parameter is erased at runtime;
          // convertListToArrayData handles per-element type coercion.
          convertListToArrayData(list, at, datetimeJava8ApiEnabled)
        case _ =>
          // Unexpected runtime type: degrade to an empty ArrayData.
          ArrayData.toArrayData(new Array[Any](0))
      }
    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: DecimalType => v
    case _ => throw new Exception(s"Unsupported spark type: ${dataType.typeName}")
  }
}

/**
 * Converts a Java List (from JSON parsing or the direct conversion path) into
 * Spark's [[ArrayData]], coercing each element to the declared element type.
 *
 * Null elements are preserved as nulls (the schema declares containsNull = true);
 * previously a null element caused an NPE in the StringType branch and was
 * silently unboxed to 0/false in the primitive branches.
 *
 * @param javaList                the Java List to convert
 * @param arrayType               the ArrayType schema definition
 * @param datetimeJava8ApiEnabled whether to use the Java 8 date/time API
 * @return Spark ArrayData object
 */
private def convertListToArrayData(javaList: java.util.List[Any], arrayType: ArrayType, datetimeJava8ApiEnabled: Boolean): ArrayData = {
  // Pre-allocate with the known size; fill via the iterator.
  val elements = new Array[Any](javaList.size())
  var i = 0
  val it = javaList.iterator()
  while (it.hasNext) {
    val element = it.next()
    elements(i) =
      if (element == null) null
      else convertArrayElement(element, arrayType.elementType, datetimeJava8ApiEnabled)
    i += 1
  }

  // Apply a cast/conversion element-wise, keeping nulls as nulls.
  def mapNonNull(f: Any => Any): Array[Any] = elements.map(e => if (e == null) null else f(e))

  // Coerce to the declared element type. If the schema declares one type but the
  // actual data is another, the ClassCastException fallback renders everything
  // as strings (common when the schema defaults to StringType).
  try {
    arrayType.elementType match {
      case BooleanType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Boolean]))
      case ByteType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Byte]))
      case ShortType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Short]))
      case IntegerType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Int]))
      case LongType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Long]))
      case FloatType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Float]))
      case DoubleType => ArrayData.toArrayData(mapNonNull(_.asInstanceOf[Double]))
      case StringType => ArrayData.toArrayData(mapNonNull(e => UTF8String.fromString(e.toString)))
      case _ => ArrayData.toArrayData(elements)
    }
  } catch {
    case _: ClassCastException =>
      // Fallback: declared type did not match the runtime data; render as strings.
      ArrayData.toArrayData(elements.map(e => if (e == null) null else UTF8String.fromString(e.toString)))
  }
}

/**
 * Recursively converts a single array element to its Catalyst representation
 * based on the declared element type. Supports nested arrays, strings, and
 * date/time types; any other value is passed through unchanged.
 *
 * Null-safe: a null element yields null. Previously, nulls inside a *nested*
 * array reached the StringType branch unguarded and caused an NPE
 * (the outer caller checked null, but the nested-array loop did not).
 *
 * @param element                 the element value to convert (may be null)
 * @param elementType             the expected Spark DataType for the element
 * @param datetimeJava8ApiEnabled whether to use the Java 8 date/time API
 * @return converted element value suitable for a Spark InternalRow
 */
private def convertArrayElement(element: Any, elementType: DataType, datetimeJava8ApiEnabled: Boolean): Any = {
  if (element == null) {
    null
  } else elementType match {
    case StringType => UTF8String.fromString(element.toString)
    case TimestampType if datetimeJava8ApiEnabled && element.isInstanceOf[LocalDateTime] =>
      // NOTE(review): uses the JVM default zone to interpret the wall-clock value —
      // confirm this matches the session time zone expected by Spark.
      val localDateTime = element.asInstanceOf[LocalDateTime]
      val instant = localDateTime.atZone(ZoneId.systemDefault()).toInstant
      DateTimeUtils.instantToMicros(instant)
    case TimestampType if element.isInstanceOf[Timestamp] =>
      DateTimeUtils.fromJavaTimestamp(element.asInstanceOf[Timestamp])
    case DateType if datetimeJava8ApiEnabled && element.isInstanceOf[LocalDate] =>
      element.asInstanceOf[LocalDate].toEpochDay.toInt
    case DateType if element.isInstanceOf[Date] =>
      DateTimeUtils.fromJavaDate(element.asInstanceOf[Date])
    case nestedArray: ArrayType =>
      // Nested arrays: recurse per element; the null guard above keeps
      // nulls inside the nested list as nulls.
      val nestedJavaList = element.asInstanceOf[java.util.List[Any]]
      val nestedElements = new Array[Any](nestedJavaList.size())
      var j = 0
      val nestedIterator = nestedJavaList.iterator()
      while (nestedIterator.hasNext) {
        nestedElements(j) = convertArrayElement(nestedIterator.next(), nestedArray.elementType, datetimeJava8ApiEnabled)
        j += 1
      }
      ArrayData.toArrayData(nestedElements)
    case _ => element
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,49 @@ package org.apache.doris.spark.util

import org.apache.doris.sdk.thrift.TScanColumnDesc
import org.apache.doris.spark.rest.models.{Field, Schema}
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, MapType}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, MapType}

object SchemaConvertors {

/**
 * Converts an inferred element-type string to the corresponding Spark DataType.
 * Used for ARRAY type precision inference.
 *
 * Nested arrays are supported recursively (e.g. "ARRAY&lt;INT&gt;" yields
 * ArrayType(IntegerType)). Unknown, null, or empty inputs fall back to StringType.
 *
 * @param elementTypeString the inferred element type (e.g. "INT", "STRING", "ARRAY&lt;INT&gt;")
 * @param precision         precision for decimal types (used when &gt; 0)
 * @param scale             scale for decimal types (used when &gt;= 0)
 * @return Spark DataType
 */
def elementTypeStringToDataType(elementTypeString: String, precision: Int = -1, scale: Int = -1): DataType = {
  if (elementTypeString == null || elementTypeString.isEmpty) {
    DataTypes.StringType // default fallback
  } else if (elementTypeString.startsWith("ARRAY<") && elementTypeString.endsWith(">")) {
    // Nested array: strip the "ARRAY<...>" wrapper and recurse on the inner type.
    val innerType = elementTypeString.substring(6, elementTypeString.length - 1)
    ArrayType(elementTypeStringToDataType(innerType, precision, scale), containsNull = true)
  } else {
    elementTypeString match {
      case "TINYINT" => DataTypes.ByteType
      case "SMALLINT" => DataTypes.ShortType
      case "INT" => DataTypes.IntegerType
      case "BIGINT" => DataTypes.LongType
      case "FLOAT" => DataTypes.FloatType
      case "DOUBLE" => DataTypes.DoubleType
      case "STRING" => DataTypes.StringType
      case "BOOLEAN" => DataTypes.BooleanType
      case "DATE" => DataTypes.DateType
      case "DATETIME" => DataTypes.TimestampType
      case "BINARY" => DataTypes.BinaryType
      // Use the inferred precision/scale when valid; otherwise a safe default.
      case "DECIMAL" => if (precision > 0 && scale >= 0) DecimalType(precision, scale) else DecimalType(10, 0)
      case _ => DataTypes.StringType // default fallback for unrecognized types
    }
  }
}

@throws[IllegalArgumentException]
def toCatalystType(dorisType: String, precision: Int, scale: Int): DataType = {
dorisType match {
Expand Down Expand Up @@ -52,7 +91,11 @@ object SchemaConvertors {
case "DECIMAL128" => DecimalType(precision, scale)
case "TIME" => DataTypes.DoubleType
case "STRING" => DataTypes.StringType
case "ARRAY" => DataTypes.StringType
case "ARRAY" => ArrayType(DataTypes.StringType, containsNull = true)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this change the previous behavior?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does change the previous behavior, but the data content remains compatible.

Core Changes

Previous:
Schema: tags: string
Value: "[\"Alice\",\"Bob\"]" (string)
Array operations were not supported.

Now:
Schema: tags: array
Value: WrappedArray(Alice, Bob) (array)
Explode, array_contains, etc., are supported.

Impact

Schema type change: StringType → ArrayType(StringType)
Value type change: String → ArrayData
Data content compatibility: Elements are still strings ["Alice", "Bob"]

I understand your concerns. User code that relies on StringType checks or uses row.getString() will need to be adapted.

// Default to ArrayType(StringType) for backward compatibility.
// Actual element type is inferred from Arrow schema at runtime during data conversion.
// RowBatch will read actual element types from Arrow ListVector, and RowConvertors
// will handle type conversions automatically (e.g., Integer -> String if schema declares StringType).
case "MAP" => MapType(DataTypes.StringType, DataTypes.StringType)
case "STRUCT" => DataTypes.StringType
case "VARIANT" => DataTypes.StringType
Expand Down
Loading
Loading