Skip to content

Commit 5b210dd

Browse files
authored
Handle swizzle1d in isResharding (#6028)
1 parent 54d48ae commit 5b210dd

File tree

5 files changed

+99
-10
lines changed

5 files changed

+99
-10
lines changed

csrc/multidevice/device_mesh.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,9 +175,9 @@ std::vector<DeviceIdxType> DeviceMesh::getSlice(
175175
indices.reserve(rank());
176176
for (int64_t i : arange(rank())) {
177177
if (i == axis) {
178-
indices.push_back(at::indexing::Slice());
178+
indices.emplace_back(at::indexing::Slice());
179179
} else {
180-
indices.push_back(index[i]);
180+
indices.emplace_back(index[i]);
181181
}
182182
}
183183
at::Tensor slice = devices_.index(indices);

csrc/multidevice/device_mesh.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,8 +96,9 @@ class DeviceMesh final {
9696
}
9797

9898
// Returns the rank (number of dimensions) of the mesh.
99+
// Returns -1 if the mesh is empty.
99100
int64_t rank() const {
100-
return devices_.dim();
101+
return size() > 0 ? devices_.dim() : -1;
101102
}
102103

103104
bool operator==(const DeviceMesh& other) const {

csrc/multidevice/resharding.cpp

Lines changed: 46 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,8 @@ const std::vector<IterDomain*>& getDomainOf(
5959
std::pair<Val*, bool> computeLoopIndex(
6060
IterDomain* id,
6161
const std::vector<IterDomain*>& sources,
62-
std::unordered_map<IterDomain*, std::pair<Val*, bool>>& id_to_index) {
62+
std::unordered_map<IterDomain*, std::pair<Val*, bool>>& id_to_index,
63+
const std::unordered_map<ParallelType, Val*>& pt_to_index) {
6364
if (id == nullptr) {
6465
return {nullptr, false};
6566
}
@@ -86,7 +87,9 @@ std::pair<Val*, bool> computeLoopIndex(
8687
div(in_info.first, inner->extent()), in_info.second};
8788
id_to_index[inner] = {
8889
mod(in_info.first, inner->extent()), in_info.second};
89-
} else if (auto* merge = dynamic_cast<Merge*>(transform)) {
90+
continue;
91+
}
92+
if (auto* merge = dynamic_cast<Merge*>(transform)) {
9093
auto* outer = merge->outer()->as<IterDomain>();
9194
auto* inner = merge->inner()->as<IterDomain>();
9295
auto* out = merge->out()->as<IterDomain>();
@@ -96,9 +99,22 @@ std::pair<Val*, bool> computeLoopIndex(
9699
id_to_index[out] = {
97100
add(mul(outer_info.first, inner->extent()), inner_info.first),
98101
outer_info.second || inner_info.second};
99-
} else {
100-
NVF_THROW("Unexpected transform: ", transform);
102+
continue;
101103
}
104+
if (auto* swizzle = dynamic_cast<Swizzle1D*>(transform)) {
105+
auto* in = swizzle->in()->as<IterDomain>();
106+
auto* out = swizzle->out()->as<IterDomain>();
107+
108+
const auto& in_info = id_to_index.at(in);
109+
Val* extent = out->extent();
110+
Val* pt_val = pt_to_index.at(swizzle->parallelType());
111+
// Inverse of the swizzle formula in_idx = (out_idx + pt_val) % extent:
112+
// out_idx = (in_idx - pt_val + extent) % extent
113+
id_to_index[out] = {
114+
mod(add(sub(in_info.first, pt_val), extent), extent), in_info.second};
115+
continue;
116+
}
117+
NVF_THROW("Unexpected transform: ", transform);
102118
}
103119

104120
return id_to_index.at(id);
@@ -241,9 +257,26 @@ bool haveDifferentShardings(
241257
std::vector<Val*> assumptions;
242258
assumptions.reserve(
243259
(producer->getLogicalDomain().size() +
244-
consumer->getMaybeRootDomain().size()) *
260+
consumer->getMaybeRootDomain().size() + kParallelTypeDIDs.size()) *
245261
2);
246262

263+
// Create symbolic Vals for each device parallel type present in the mesh,
264+
// representing the device's index within the team for that type. These are
265+
// used by computeLoopIndex to symbolically compute Swizzle1D outputs.
266+
std::unordered_map<ParallelType, Val*> pt_to_index;
267+
const DeviceMesh& mesh = producer->getDeviceMesh();
268+
for (ParallelType pt : kParallelTypeDIDs) {
269+
if (!mesh.hasParallelType(pt)) {
270+
continue;
271+
}
272+
Val* device_idx = IrBuilder::create<Val>(DataType::Index);
273+
pt_to_index[pt] = device_idx;
274+
Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
275+
assumptions.push_back(
276+
SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
277+
assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
278+
}
279+
247280
auto create_index = [&](IterDomain* id, bool mapped) {
248281
auto* index = IrBuilder::create<Val>(DataType::Index);
249282
NVF_ERROR(id_to_index.emplace(id, std::make_pair(index, mapped)).second);
@@ -311,7 +344,10 @@ bool haveDifferentShardings(
311344
Val* p_index = nullptr;
312345
bool p_mapped = false;
313346
std::tie(p_index, p_mapped) = computeLoopIndex(
314-
p_id, getDomainOf(producer, DomainType::kLogical), id_to_index);
347+
p_id,
348+
getDomainOf(producer, DomainType::kLogical),
349+
id_to_index,
350+
pt_to_index);
315351
if (!p_mapped) {
316352
p_index = nullptr;
317353
}
@@ -320,7 +356,10 @@ bool haveDifferentShardings(
320356
Val* c_index = nullptr;
321357
bool c_mapped = false;
322358
std::tie(c_index, c_mapped) = computeLoopIndex(
323-
c_id, getDomainOf(consumer, DomainType::kRoot), id_to_index);
359+
c_id,
360+
getDomainOf(consumer, DomainType::kRoot),
361+
id_to_index,
362+
pt_to_index);
324363
if (!c_mapped) {
325364
c_index = nullptr;
326365
}

tests/cpp/test_multidevice_host_ir.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -455,8 +455,11 @@ TEST_F(MultiDeviceHostIrTest, SymmetricContiguousView) {
455455
FusionGuard::setCurFusion(hic.get());
456456

457457
// Create input and output TensorViews
458+
DeviceMesh mesh = DeviceMesh::createForNumDevices(communicator_size);
459+
458460
TensorView* input_tv = makeContigConcreteTensor(sharded_sizes);
459461
input_tv->setMemoryType(MemoryType::Symmetric);
462+
input_tv->setDeviceMesh(mesh);
460463
input_tv->axis(0)->parallelize(ParallelType::DIDx);
461464

462465
TensorView* output_tv = makeContigConcreteTensor(unsharded_sizes);

tests/cpp/test_resharding.cpp

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -631,4 +631,50 @@ TEST_F(ReshardingSelectOpTest, ReshardingSelectIntoNonDeviceDim) {
631631
EXPECT_TRUE(isResharding(tv1->definition()));
632632
}
633633

634+
TEST_F(ReshardingTest, Swizzle1D_DIDToStream) {
635+
Fusion fusion;
636+
FusionGuard fg(&fusion);
637+
const int d = 2;
638+
auto mesh = DeviceMesh::createForNumDevices(d);
639+
640+
TensorView* in = makeContigTensor(1);
641+
in->setDeviceMesh(mesh);
642+
in->outer_split(0, d);
643+
in->axis(0)->parallelize(ParallelType::DIDx);
644+
645+
TensorView* out = set(in);
646+
out->setDeviceMesh(mesh);
647+
out->outer_split(0, d);
648+
out->swizzle1d(0, ParallelType::DIDx);
649+
out->axis(0)->parallelize(ParallelType::Stream);
650+
651+
EXPECT_TRUE(haveDifferentShardings(
652+
in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
653+
654+
EXPECT_TRUE(haveDifferentShardings(
655+
in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::DIDx}));
656+
}
657+
658+
TEST_F(ReshardingTest, Swizzle1D_ConsistentSwizzle) {
659+
Fusion fusion;
660+
FusionGuard fg(&fusion);
661+
const int d = 2;
662+
auto mesh = DeviceMesh::createForNumDevices(d);
663+
664+
TensorView* in = makeContigTensor(1);
665+
in->setDeviceMesh(mesh);
666+
in->outer_split(0, d);
667+
in->swizzle1d(0, ParallelType::DIDx);
668+
in->axis(0)->parallelize(ParallelType::Stream);
669+
670+
TensorView* out = set(in);
671+
out->setDeviceMesh(mesh);
672+
out->outer_split(0, d);
673+
out->swizzle1d(0, ParallelType::DIDx);
674+
out->axis(0)->parallelize(ParallelType::Stream);
675+
676+
EXPECT_FALSE(haveDifferentShardings(
677+
in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
678+
}
679+
634680
} // namespace nvfuser

0 commit comments

Comments (0)