Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.doris.nereids.trees.plans.distribute.worker.job.AssignedJobBuilder;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.BucketScanSource;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.DefaultScanSource;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.LocalShuffleAssignedJob;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.StaticAssignedJob;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.UnassignedJob;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.UnassignedJobBuilder;
Expand Down Expand Up @@ -175,16 +174,12 @@ private FragmentIdMapping<DistributedPlan> buildDistributePlans(
}

private FragmentIdMapping<DistributedPlan> linkPlans(FragmentIdMapping<DistributedPlan> plans) {
boolean enableShareHashTableForBroadcastJoin = statementContext.getConnectContext()
.getSessionVariable()
.enableShareHashTableForBroadcastJoin;
for (DistributedPlan receiverPlan : plans.values()) {
for (Entry<ExchangeNode, DistributedPlan> link : receiverPlan.getInputs().entries()) {
linkPipelinePlan(
(PipelineDistributedPlan) receiverPlan,
(PipelineDistributedPlan) link.getValue(),
link.getKey(),
enableShareHashTableForBroadcastJoin
link.getKey()
);
for (Entry<DataSink, List<AssignedJob>> kv :
((PipelineDistributedPlan) link.getValue()).getDestinations().entrySet()) {
Expand All @@ -205,11 +200,10 @@ private FragmentIdMapping<DistributedPlan> linkPlans(FragmentIdMapping<Distribut
private void linkPipelinePlan(
PipelineDistributedPlan receiverPlan,
PipelineDistributedPlan senderPlan,
ExchangeNode linkNode,
boolean enableShareHashTableForBroadcastJoin) {
ExchangeNode linkNode) {

List<AssignedJob> receiverInstances = filterInstancesWhichCanReceiveDataFromRemote(
receiverPlan, enableShareHashTableForBroadcastJoin, linkNode);
receiverPlan, linkNode);
if (linkNode.getPartitionType() == TPartitionType.BUCKET_SHFFULE_HASH_PARTITIONED) {
receiverInstances = getDestinationsByBuckets(receiverPlan, receiverInstances);
}
Expand Down Expand Up @@ -239,13 +233,15 @@ private List<AssignedJob> getDestinationsByBuckets(

private List<AssignedJob> filterInstancesWhichCanReceiveDataFromRemote(
PipelineDistributedPlan receiverPlan,
boolean enableShareHashTableForBroadcastJoin,
ExchangeNode linkNode) {
boolean useLocalShuffle = receiverPlan.getInstanceJobs().stream()
.anyMatch(LocalShuffleAssignedJob.class::isInstance);
if (useLocalShuffle) {
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else if (enableShareHashTableForBroadcastJoin && linkNode.isRightChildOfBroadcastHashJoin()) {
// isSerialOperator(): UNPARTITIONED or use_serial_exchange (operator-level)
// useSerialSource(): fragment is in pooling mode (fragment-level guard)
// Both must be true: serial exchange semantics AND pooling mode active.
// Note: cannot combine these into isSerialOperator() because useSerialSource()
// calls planRoot.isSerialOperator() which would cause infinite recursion.
if (linkNode.isSerialOperator()
&& linkNode.getFragment().useSerialSource(
statementContext.getConnectContext())) {
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else {
return receiverPlan.getInstanceJobs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ public void setMergeInfo(SortInfo info) {

@Override
protected void toThrift(TPlanNode msg) {
// If this fragment has another scan node, this exchange node is serial or not should be decided by the scan
// node.
msg.setIsSerialOperator((isSerialOperator() || fragment.hasSerialScanNode())
// is_serial = operator-level serial AND fragment-level pooling guard.
// isSerialOperator(): only UNPARTITIONED or use_serial_exchange (not bucket shuffle/hash).
// useSerialSource(): pooling mode where BE manages per-pipeline parallelism.
// Note: useSerialSource() calls planRoot.isSerialOperator(), so we must NOT call
// useSerialSource() inside isSerialOperator() to avoid infinite recursion.
msg.setIsSerialOperator(isSerialOperator()
&& fragment.useSerialSource(ConnectContext.get()));
msg.node_type = TPlanNodeType.EXCHANGE_NODE;
msg.exchange_node = new TExchangeNode();
Expand Down Expand Up @@ -153,8 +156,10 @@ public void setRightChildOfBroadcastHashJoin(boolean value) {
*/
@Override
public boolean isSerialOperator() {
    // Operator-level seriality only: this exchange is serial when the session
    // explicitly forces serial exchange, or when it is UNPARTITIONED, and it
    // performs no merge (a merging exchange needs per-sender ordered streams).
    // NOTE: deliberately does NOT consult fragment.useSerialSource() here --
    // useSerialSource() calls planRoot.isSerialOperator(), so consulting it
    // from this method would cause infinite recursion. Callers that need the
    // fragment-level pooling guard (e.g. toThrift) AND this check together
    // must combine them at the call site.
    boolean forceSerialExchange = ConnectContext.get() != null
            && ConnectContext.get().getSessionVariable().isUseSerialExchange();
    boolean unPartition = partitionType == TPartitionType.UNPARTITIONED;
    return (forceSerialExchange || unPartition) && mergeInfo == null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.plans.distribute;

import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorker;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.AssignedJob;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.BucketScanSource;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.LocalShuffleBucketJoinAssignedJob;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.ScanRanges;
import org.apache.doris.nereids.trees.plans.distribute.worker.job.UnassignedJob;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.thrift.TUniqueId;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Tests that bucket shuffle + pooling (LocalShuffle) produces correct destinations:
 * - destinations count == bucketNum
 * - same BE's buckets all point to the same (first) instance
 *
 * Uses reflection to call DistributePlanner.sortDestinationInstancesByBuckets (private method).
 */
public class BucketShuffleDestinationTest {

    /**
     * Simulate 3 workers with 6 buckets (distribution: worker0=[0,3], worker1=[1,4], worker2=[2,5]).
     * Input = first-per-worker instances (simulating filterInstancesWhichCanReceiveDataFromRemote output).
     * Each first instance's BucketScanSource contains ALL buckets for that worker.
     * sortDestinationInstancesByBuckets should produce 6 destinations where same-worker
     * buckets point to the same instance.
     */
    @Test
    public void testBucketShufflePoolingDestinations() throws Exception {
        // Three mock backends, distinguished only by their worker id.
        DistributedPlanWorker worker0 = Mockito.mock(DistributedPlanWorker.class);
        DistributedPlanWorker worker1 = Mockito.mock(DistributedPlanWorker.class);
        DistributedPlanWorker worker2 = Mockito.mock(DistributedPlanWorker.class);
        Mockito.when(worker0.id()).thenReturn(100L);
        Mockito.when(worker1.id()).thenReturn(200L);
        Mockito.when(worker2.id()).thenReturn(300L);

        ScanNode mockScanNode = Mockito.mock(ScanNode.class);
        UnassignedJob mockJob = Mockito.mock(UnassignedJob.class);

        // Worker0 first instance: has buckets 0 and 3
        BucketScanSource source0 = makeBucketScanSource(mockScanNode, 0, 3);
        LocalShuffleBucketJoinAssignedJob inst0 = new LocalShuffleBucketJoinAssignedJob(
                0, 0, new TUniqueId(0, 0), mockJob, worker0, source0, ImmutableSet.of(0, 3));

        // Worker1 first instance: has buckets 1 and 4
        BucketScanSource source1 = makeBucketScanSource(mockScanNode, 1, 4);
        LocalShuffleBucketJoinAssignedJob inst1 = new LocalShuffleBucketJoinAssignedJob(
                1, 0, new TUniqueId(0, 1), mockJob, worker1, source1, ImmutableSet.of(1, 4));

        // Worker2 first instance: has buckets 2 and 5
        BucketScanSource source2 = makeBucketScanSource(mockScanNode, 2, 5);
        LocalShuffleBucketJoinAssignedJob inst2 = new LocalShuffleBucketJoinAssignedJob(
                2, 0, new TUniqueId(0, 2), mockJob, worker2, source2, ImmutableSet.of(2, 5));

        // This simulates getFirstInstancePerWorker output
        List<AssignedJob> firstPerWorker = new ArrayList<>();
        firstPerWorker.add(inst0);
        firstPerWorker.add(inst1);
        firstPerWorker.add(inst2);

        // Invoke private sortDestinationInstancesByBuckets via reflection
        PipelineDistributedPlan mockPlan = Mockito.mock(PipelineDistributedPlan.class);
        Mockito.when(mockPlan.getFragmentJob()).thenReturn(mockJob);

        Method method = DistributePlanner.class.getDeclaredMethod(
                "sortDestinationInstancesByBuckets",
                PipelineDistributedPlan.class, List.class, int.class);
        method.setAccessible(true);

        // Need a DistributePlanner instance; use Objenesis directly (already on the
        // test classpath as a Mockito dependency) to bypass its constructor.
        // No reflection needed to build ObjenesisStd itself -- it is a
        // compile-time reference with a public no-arg constructor.
        Object planner = new org.objenesis.ObjenesisStd().newInstance(DistributePlanner.class);

        @SuppressWarnings("unchecked")
        List<AssignedJob> destinations = (List<AssignedJob>) method.invoke(planner, mockPlan, firstPerWorker, 6);

        // Verify: destinations count == bucketNum
        Assertions.assertEquals(6, destinations.size(), "destinations count should equal bucketNum");

        // Verify: each bucket maps to the correct worker's first instance
        Assertions.assertSame(inst0, destinations.get(0), "bucket 0 → worker0 first instance");
        Assertions.assertSame(inst1, destinations.get(1), "bucket 1 → worker1 first instance");
        Assertions.assertSame(inst2, destinations.get(2), "bucket 2 → worker2 first instance");
        Assertions.assertSame(inst0, destinations.get(3), "bucket 3 → worker0 first instance (same as bucket 0)");
        Assertions.assertSame(inst1, destinations.get(4), "bucket 4 → worker1 first instance (same as bucket 1)");
        Assertions.assertSame(inst2, destinations.get(5), "bucket 5 → worker2 first instance (same as bucket 2)");

        // Verify: same worker's buckets point to the SAME instance object
        Assertions.assertSame(destinations.get(0), destinations.get(3),
                "bucket 0 and 3 should be same instance (same worker)");
        Assertions.assertSame(destinations.get(1), destinations.get(4),
                "bucket 1 and 4 should be same instance (same worker)");
        Assertions.assertSame(destinations.get(2), destinations.get(5),
                "bucket 2 and 5 should be same instance (same worker)");
    }

    /**
     * Builds a BucketScanSource whose bucket-index map contains one entry per
     * requested bucket, each mapping the given scan node to an empty ScanRanges.
     */
    private static BucketScanSource makeBucketScanSource(ScanNode scanNode, int... bucketIndices) {
        // LinkedHashMap keeps bucket iteration order deterministic for the test.
        Map<Integer, Map<ScanNode, ScanRanges>> bucketMap = Maps.newLinkedHashMap();
        for (int idx : bucketIndices) {
            Map<ScanNode, ScanRanges> scanMap = ImmutableMap.of(scanNode, new ScanRanges());
            bucketMap.put(idx, scanMap);
        }
        return new BucketScanSource(bucketMap);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.planner;

import org.apache.doris.analysis.TupleId;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.thrift.TPartitionType;
import org.apache.doris.thrift.TPlanNode;

import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

import java.util.Collections;

/**
 * Tests that ExchangeNode.toThrift sets is_serial correctly.
 * is_serial is a per-operator property: only UNPARTITIONED (or use_serial_exchange)
 * exchanges should be serial, NOT bucket shuffle exchanges that share a fragment
 * with a serial scan.
 */
public class ExchangeNodeSerialTest {

    /**
     * Builds a partially-mocked ExchangeNode (real methods, mocked fragment)
     * with the given partition type and fragment-level flags.
     */
    private ExchangeNode createExchangeNode(TPartitionType partitionType,
            boolean hasSerialScanNode, boolean useSerialSource) {
        // Use Mockito spy to avoid complex PlanNode constructor dependencies
        ExchangeNode exchange = Mockito.mock(ExchangeNode.class, Mockito.withSettings()
                .defaultAnswer(Mockito.CALLS_REAL_METHODS));

        // Set the fields that toThrift accesses
        exchange.tupleIds = Lists.newArrayList(new TupleId(0));
        exchange.conjuncts = Collections.emptyList();

        // Use reflection to set private partitionType field
        try {
            java.lang.reflect.Field ptField = ExchangeNode.class.getDeclaredField("partitionType");
            ptField.setAccessible(true);
            ptField.set(exchange, partitionType);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        PlanFragment fragment = Mockito.mock(PlanFragment.class);
        Mockito.when(fragment.hasSerialScanNode()).thenReturn(hasSerialScanNode);
        Mockito.when(fragment.useSerialSource(Mockito.any())).thenReturn(useSerialSource);
        exchange.fragment = fragment;

        return exchange;
    }

    /**
     * Runs toThrift under a mocked ConnectContext (default session variables)
     * and returns the is_serial flag written into the thrift node.
     * Shared driver for all four scenarios below.
     */
    private boolean toThriftIsSerial(TPartitionType partitionType,
            boolean hasSerialScanNode, boolean useSerialSource) {
        ConnectContext ctx = new ConnectContext();
        ctx.setSessionVariable(new SessionVariable());
        try (MockedStatic<ConnectContext> mocked = Mockito.mockStatic(ConnectContext.class)) {
            mocked.when(ConnectContext::get).thenReturn(ctx);

            ExchangeNode exchange = createExchangeNode(
                    partitionType, hasSerialScanNode, useSerialSource);

            TPlanNode thriftNode = new TPlanNode();
            exchange.toThrift(thriftNode);
            return thriftNode.isIsSerialOperator();
        }
    }

    /**
     * BUCKET_SHUFFLE exchange in a pooling fragment with serial scan:
     * isSerialOperator()=false, hasSerialScanNode()=true, useSerialSource()=true
     * → is_serial must be false (scan seriality should not propagate to exchange)
     */
    @Test
    public void testBucketShuffleExchangeNotSerialInPoolingFragment() {
        Assertions.assertFalse(
                toThriftIsSerial(TPartitionType.BUCKET_SHFFULE_HASH_PARTITIONED, true, true),
                "BUCKET_SHUFFLE exchange should NOT be serial even when fragment has serial scan");
    }

    /**
     * UNPARTITIONED exchange in a pooling fragment:
     * isSerialOperator()=true (UNPARTITIONED), useSerialSource()=true
     * → is_serial must be true
     */
    @Test
    public void testUnpartitionedExchangeSerialInPoolingFragment() {
        Assertions.assertTrue(
                toThriftIsSerial(TPartitionType.UNPARTITIONED, true, true),
                "UNPARTITIONED exchange should be serial in pooling fragment");
    }

    /**
     * UNPARTITIONED exchange in a non-pooling fragment:
     * isSerialOperator()=true, useSerialSource()=false
     * → is_serial must be false (not in pooling mode)
     */
    @Test
    public void testUnpartitionedExchangeNotSerialWithoutPooling() {
        Assertions.assertFalse(
                toThriftIsSerial(TPartitionType.UNPARTITIONED, false, false),
                "UNPARTITIONED exchange should NOT be serial without pooling");
    }

    /**
     * HASH_PARTITIONED exchange in a pooling fragment with serial scan:
     * isSerialOperator()=false (HASH is not serial), useSerialSource()=true
     * → is_serial must be false
     */
    @Test
    public void testHashExchangeNotSerialInPoolingFragment() {
        Assertions.assertFalse(
                toThriftIsSerial(TPartitionType.HASH_PARTITIONED, true, true),
                "HASH exchange should NOT be serial even when fragment has serial scan");
    }
}
Loading