@@ -59,7 +59,8 @@ const std::vector<IterDomain*>& getDomainOf(
5959std::pair<Val*, bool > computeLoopIndex (
6060 IterDomain* id,
6161 const std::vector<IterDomain*>& sources,
62- std::unordered_map<IterDomain*, std::pair<Val*, bool >>& id_to_index) {
62+ std::unordered_map<IterDomain*, std::pair<Val*, bool >>& id_to_index,
63+ const std::unordered_map<ParallelType, Val*>& pt_to_index) {
6364 if (id == nullptr ) {
6465 return {nullptr , false };
6566 }
@@ -86,7 +87,9 @@ std::pair<Val*, bool> computeLoopIndex(
8687 div (in_info.first , inner->extent ()), in_info.second };
8788 id_to_index[inner] = {
8889 mod (in_info.first , inner->extent ()), in_info.second };
89- } else if (auto * merge = dynamic_cast <Merge*>(transform)) {
90+ continue ;
91+ }
92+ if (auto * merge = dynamic_cast <Merge*>(transform)) {
9093 auto * outer = merge->outer ()->as <IterDomain>();
9194 auto * inner = merge->inner ()->as <IterDomain>();
9295 auto * out = merge->out ()->as <IterDomain>();
@@ -96,9 +99,22 @@ std::pair<Val*, bool> computeLoopIndex(
9699 id_to_index[out] = {
97100 add (mul (outer_info.first , inner->extent ()), inner_info.first ),
98101 outer_info.second || inner_info.second };
99- } else {
100- NVF_THROW (" Unexpected transform: " , transform);
102+ continue ;
101103 }
104+ if (auto * swizzle = dynamic_cast <Swizzle1D*>(transform)) {
105+ auto * in = swizzle->in ()->as <IterDomain>();
106+ auto * out = swizzle->out ()->as <IterDomain>();
107+
108+ const auto & in_info = id_to_index.at (in);
109+ Val* extent = out->extent ();
110+ Val* pt_val = pt_to_index.at (swizzle->parallelType ());
111+ // Inverse of the swizzle formula in_idx = (out_idx + pt_val) % extent:
112+ // out_idx = (in_idx - pt_val + extent) % extent
113+ id_to_index[out] = {
114+ mod (add (sub (in_info.first , pt_val), extent), extent), in_info.second };
115+ continue ;
116+ }
117+ NVF_THROW (" Unexpected transform: " , transform);
102118 }
103119
104120 return id_to_index.at (id);
@@ -241,9 +257,26 @@ bool haveDifferentShardings(
241257 std::vector<Val*> assumptions;
242258 assumptions.reserve (
243259 (producer->getLogicalDomain ().size () +
244- consumer->getMaybeRootDomain ().size ()) *
260+ consumer->getMaybeRootDomain ().size () + kParallelTypeDIDs . size () ) *
245261 2 );
246262
263+ // Create symbolic Vals for each device parallel type present in the mesh,
264+ // representing the device's index within the team for that type. These are
265+ // used by computeLoopIndex to symbolically compute Swizzle1D outputs.
266+ std::unordered_map<ParallelType, Val*> pt_to_index;
267+ const DeviceMesh& mesh = producer->getDeviceMesh ();
268+ for (ParallelType pt : kParallelTypeDIDs ) {
269+ if (!mesh.hasParallelType (pt)) {
270+ continue ;
271+ }
272+ Val* device_idx = IrBuilder::create<Val>(DataType::Index);
273+ pt_to_index[pt] = device_idx;
274+ Val* team_size = IrBuilder::create<Val>(mesh.size (pt), DataType::Index);
275+ assumptions.push_back (
276+ SimplifyingIrBuilder::leExpr (fusion->zeroVal (), device_idx));
277+ assumptions.push_back (SimplifyingIrBuilder::ltExpr (device_idx, team_size));
278+ }
279+
247280 auto create_index = [&](IterDomain* id, bool mapped) {
248281 auto * index = IrBuilder::create<Val>(DataType::Index);
249282 NVF_ERROR (id_to_index.emplace (id, std::make_pair (index, mapped)).second );
@@ -311,7 +344,10 @@ bool haveDifferentShardings(
311344 Val* p_index = nullptr ;
312345 bool p_mapped = false ;
313346 std::tie (p_index, p_mapped) = computeLoopIndex (
314- p_id, getDomainOf (producer, DomainType::kLogical ), id_to_index);
347+ p_id,
348+ getDomainOf (producer, DomainType::kLogical ),
349+ id_to_index,
350+ pt_to_index);
315351 if (!p_mapped) {
316352 p_index = nullptr ;
317353 }
@@ -320,7 +356,10 @@ bool haveDifferentShardings(
320356 Val* c_index = nullptr ;
321357 bool c_mapped = false ;
322358 std::tie (c_index, c_mapped) = computeLoopIndex (
323- c_id, getDomainOf (consumer, DomainType::kRoot ), id_to_index);
359+ c_id,
360+ getDomainOf (consumer, DomainType::kRoot ),
361+ id_to_index,
362+ pt_to_index);
324363 if (!c_mapped) {
325364 c_index = nullptr ;
326365 }
0 commit comments