Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.AfterFirst;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Trigger;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.KV;
Expand Down Expand Up @@ -84,6 +87,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.DateTimeUtils;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.format.DateTimeFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -159,7 +163,8 @@ static void ensureEOSSupport() {
@Override
public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
String topic = Preconditions.checkStateNotNull(spec.getTopic());

int numElements = spec.getEosTriggerNumElements();
Duration timeout = spec.getEosTriggerTimeout();
int numShards = spec.getNumShards();
if (numShards <= 0) {
try (Consumer<?, ?> consumer = openConsumer(spec)) {
Expand All @@ -172,17 +177,34 @@ public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
}
}
checkState(numShards > 0, "Could not set number of shards");

Trigger.OnceTrigger trigger = null;
if (timeout != null) {
trigger =
AfterFirst.of(
AfterPane.elementCountAtLeast(numElements),
AfterProcessingTime.pastFirstElementInPane().plusDelayOf(timeout));
} else {
// fallback to default
trigger = AfterPane.elementCountAtLeast(numElements);
}
return input
.apply(
Window.<ProducerRecord<K, V>>into(new GlobalWindows()) // Everything into global window.
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
.triggering(Repeatedly.forever(trigger))
.discardingFiredPanes())
.apply(
String.format("Shuffle across %d shards", numShards),
ParDo.of(new Reshard<>(numShards)))
.apply("Persist sharding", GroupByKey.create())
.apply("Assign sequential ids", ParDo.of(new Sequencer<>()))
// Reapply the windowing configuration as the continuation trigger doesn't maintain the
// desired batching.
.apply(
"Windowing",
Window.<KV<Integer, KV<Long, TimestampedValue<ProducerRecord<K, V>>>>>into(
new GlobalWindows()) // Everything into global window.
.triggering(Repeatedly.forever(trigger))
.discardingFiredPanes())
.apply("Persist ids", GroupByKey.create())
.apply(
String.format("Write to Kafka topic '%s'", spec.getTopic()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ public static <K, V> WriteRecords<K, V> writeRecords() {
return new AutoValue_KafkaIO_WriteRecords.Builder<K, V>()
.setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
.setEOS(false)
.setEosTriggerNumElements(1) // keep default numElements
.setEosTriggerTimeout(null) // keep default trigger (timeout)
.setNumShards(0)
.setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
Expand Down Expand Up @@ -3185,6 +3187,10 @@ public abstract static class WriteRecords<K, V>
@Pure
public abstract boolean isEOS();

public abstract int getEosTriggerNumElements();

public abstract @Nullable Duration getEosTriggerTimeout();

@Pure
public abstract @Nullable String getSinkGroupId();

Expand Down Expand Up @@ -3221,6 +3227,10 @@ abstract Builder<K, V> setPublishTimestampFunction(

abstract Builder<K, V> setEOS(boolean eosEnabled);

abstract Builder<K, V> setEosTriggerNumElements(int numElements);

abstract Builder<K, V> setEosTriggerTimeout(@Nullable Duration timeout);

abstract Builder<K, V> setSinkGroupId(String sinkGroupId);

abstract Builder<K, V> setNumShards(int numShards);
Expand Down Expand Up @@ -3368,6 +3378,15 @@ public WriteRecords<K, V> withEOS(int numShards, String sinkGroupId) {
return toBuilder().setEOS(true).setNumShards(numShards).setSinkGroupId(sinkGroupId).build();
}

public WriteRecords<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
checkArgument(numElements >= 1, "numElements should be >= 1");
checkArgument(timeout != null, "timeout is required for exactly-once sink");
return toBuilder()
.setEosTriggerNumElements(numElements)
.setEosTriggerTimeout(timeout)
.build();
}

/**
* When exactly-once semantics are enabled (see {@link #withEOS(int, String)}), the sink needs
* to fetch previously stored state with Kafka topic. Fetching the metadata requires a consumer.
Expand Down Expand Up @@ -3653,6 +3672,19 @@ public Write<K, V> withEOS(int numShards, String sinkGroupId) {
return withWriteRecordsTransform(getWriteRecordsTransform().withEOS(numShards, sinkGroupId));
}

/**
* Set the frequency and numElements threshold at which messages are triggered.
*
* <p>This is only applicable when the write method is set to EOS.
*
* <p>Every timeout duration, or numElements (repeated, after first condition is met) collection
* of elements written.
*/
public Write<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
return withWriteRecordsTransform(
getWriteRecordsTransform().withEOSTriggerConfig(numElements, timeout));
}

/**
* Wrapper method over {@link WriteRecords#withConsumerFactoryFn(SerializableFunction)}, used to
* keep the compatibility with old API based on KV type of element.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ static class KafkaIOWriteTranslator implements TransformPayloadTranslator<Write<
.addNullableByteArrayField("producer_factory_fn")
.addNullableByteArrayField("publish_timestamp_fn")
.addBooleanField("eos")
.addInt32Field("eos_trigger_num_elements")
.addNullableInt64Field("eos_trigger_timeout_ms")
.addInt32Field("num_shards")
.addNullableStringField("sink_group_id")
.addNullableByteArrayField("consumer_factory_fn")
Expand Down Expand Up @@ -547,6 +549,11 @@ public Row toConfigRow(Write<?, ?> transform) {
}

fieldValues.put("eos", writeRecordsTransform.isEOS());
org.joda.time.Duration eosTriggerTimeout = writeRecordsTransform.getEosTriggerTimeout();
if (eosTriggerTimeout != null) {
fieldValues.put("eos_trigger_timeout_ms", eosTriggerTimeout.getMillis());
}
fieldValues.put("eos_trigger_num_elements", writeRecordsTransform.getEosTriggerNumElements());
fieldValues.put("num_shards", writeRecordsTransform.getNumShards());

if (writeRecordsTransform.getSinkGroupId() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public class KafkaIOTranslationTest {
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValueSerializer", "value_serializer");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getPublishTimestampFunction", "publish_timestamp_fn");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("isEOS", "eos");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerTimeout", "eos_trigger_timeout_ms");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerNumElements", "eos_trigger_num_elements");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSinkGroupId", "sink_group_id");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getNumShards", "num_shards");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn");
Expand Down
Loading