Skip to content

Commit fcb2a48

Browse files
authored
Almost-exact graph recognizes equivalence in the split-split pattern (#5986)
1 parent 8ae5ab7 commit fcb2a48

File tree

4 files changed

+146
-34
lines changed

4 files changed

+146
-34
lines changed

csrc/id_model/id_model.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "device_lower/lower2device.h"
1616
#include "device_lower/utils.h"
1717
#include "disjoint_set.h"
18+
#include "expr_simplifier.h"
1819
#include "id_model/loop_promotion.h"
1920
#include "id_model/to_string.h"
2021
#include "id_model/transform_replay.h"
@@ -481,6 +482,66 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
481482
return mapped_ids;
482483
}
483484

485+
// The following is a subpattern of
486+
// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
487+
//
488+
// outer, _ = split(root)
489+
// outermost_grand, _ = split(outer)
490+
// outer', _ = split(root)
491+
//
492+
// If outermost_grand and outer' have the same extent, map them.
493+
// The splits must be divisible for this mapping to be valid.
494+
void mapDivisibleSplits(ValGraph& graph) {
495+
auto is_divisible = [](Split* s) {
496+
return simplifyExpr(s->isDivisible())->isTrue();
497+
};
498+
499+
std::vector<std::pair<Val*, Val*>> ids_to_map;
500+
for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
501+
const ExprGroups& uses_of_root = graph.getUses(root);
502+
std::vector<ValGroup> outermost_grands;
503+
for (const ExprGroup& use_of_root : uses_of_root) {
504+
auto* split0 = dynamic_cast<Split*>(use_of_root->front());
505+
if (split0 == nullptr || !is_divisible(split0)) {
506+
continue;
507+
}
508+
// Only follow the outer output of the first split; outer and inner
509+
// must not be conflated.
510+
const ValGroup& outer = graph.toGroup(split0->outer());
511+
for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
512+
auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
513+
if (split1 == nullptr || !is_divisible(split1)) {
514+
continue;
515+
}
516+
const ValGroup& outermost_grand = graph.toGroup(split1->outer());
517+
outermost_grands.push_back(outermost_grand);
518+
}
519+
}
520+
521+
for (const ValGroup& outermost_grand : outermost_grands) {
522+
Val* extent_of_grand =
523+
outermost_grand->front()->as<IterDomain>()->extent();
524+
525+
for (const ExprGroup& use_of_root : uses_of_root) {
526+
auto* split = dynamic_cast<Split*>(use_of_root->front());
527+
if (split == nullptr || !is_divisible(split)) {
528+
continue;
529+
}
530+
531+
const ValGroup& outer = graph.toGroup(split->outer());
532+
if (outer->front()->as<IterDomain>()->extent()->sameAs(
533+
extent_of_grand)) {
534+
ids_to_map.emplace_back(outermost_grand->front(), outer->front());
535+
}
536+
}
537+
}
538+
}
539+
540+
for (const auto& [id1, id2] : ids_to_map) {
541+
graph.mapVals(id1, id2);
542+
}
543+
}
544+
484545
} // namespace
485546

486547
ValGraph& IdModel::buildAlmostExactGraph() {
@@ -540,6 +601,8 @@ ValGraph& IdModel::buildAlmostExactGraph() {
540601
almost_exact_graph.mapVals(id1, id2);
541602
}
542603

604+
mapDivisibleSplits(almost_exact_graph);
605+
543606
almost_exact_graph.validateConsistency();
544607

545608
if (!allow_self_mapping_) {

csrc/val_graph.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ std::string ValGraph::toString() const {
269269
ss << "IdGraph { \n";
270270
ss << "Disjoint Ids:\n"
271271
<< idGroupsString(*this, 1) << "\n\nDisjoint Expression groups:\n"
272-
<< exprGroupsString(*this, 1) << std::endl;
273-
ss << " } IdGraph\n" << std::endl;
272+
<< exprGroupsString(*this, 1) << '\n';
273+
ss << " } IdGraph\n";
274274
return ss.str();
275275
}
276276

@@ -397,11 +397,12 @@ const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const {
397397

398398
const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
399399
NVF_ERROR(val_group, "Nullptr not allowed");
400+
401+
static const ExprGroups empty_expr_groups;
400402
const auto it = unique_uses_.find(val_group);
401-
NVF_ERROR(
402-
it != unique_uses_.end(),
403-
"Use group not found for ",
404-
nvfuser::toString(val_group));
403+
if (it == unique_uses_.end()) {
404+
return empty_expr_groups;
405+
}
405406
return it->second;
406407
}
407408

tests/cpp/test_id_model.cpp

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
*/
77
// clang-format on
88

9-
#include <fstream>
10-
119
#include <gmock/gmock-matchers.h>
1210
#include <gtest/gtest.h>
1311

@@ -17,7 +15,6 @@
1715
#include "id_model/loop_promotion.h"
1816
#include "id_model/schedule.h"
1917
#include "id_model/to_string.h"
20-
#include "ir/graphviz.h"
2118
#include "ops/all_ops.h"
2219
#include "scheduler/tools/inlining.h"
2320
#include "scheduler/tools/resize_utils.h"
@@ -235,8 +232,7 @@ void validateIELResolution(
235232
auto promotion_id = iel_promotion_map_it->second;
236233
ASSERT_TRUE(
237234
exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id))
238-
<< "Unexpected promotion. "
239-
<< "Expected: " << ref_id->toString()
235+
<< "Unexpected promotion. Expected: " << ref_id->toString()
240236
<< ". Actual: " << promotion_id->toString();
241237
ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id))
242238
<< "Promotion of " << id->toString()
@@ -376,9 +372,9 @@ void checkStep4Results(
376372
const auto& iel_promotion_map = tester.s4_iel_promotion_map;
377373

378374
EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size())
379-
<< "Mismatched Step-4 result map. "
380-
<< "Expected to have " << ref_promotion_map.size()
381-
<< " mappings but found " << iel_promotion_map.size();
375+
<< "Mismatched Step-4 result map. Expected to have "
376+
<< ref_promotion_map.size() << " mappings but found "
377+
<< iel_promotion_map.size();
382378

383379
for (const auto& ref_promotion_pair : ref_promotion_map) {
384380
const auto& ref_promotion_group = ref_promotion_pair.first;
@@ -2937,9 +2933,8 @@ TEST_F(IdModelTest, LoopPromotionCyclicGraphWar) {
29372933
// Test to verify the split-aware covered group analysis. See
29382934
// also https://github.com/NVIDIA/Fuser/pull/3877.
29392935
TEST_F(IdModelTest, CoveredGroups) {
2940-
auto fusion_ptr = std::make_unique<Fusion>();
2941-
auto& fusion = *fusion_ptr;
2942-
FusionGuard fg(fusion_ptr.get());
2936+
Fusion fusion;
2937+
FusionGuard fg(&fusion);
29432938

29442939
auto tv0 = makeContigConcreteTensor({-1, 1});
29452940
fusion.addInput(tv0);
@@ -3000,7 +2995,7 @@ TEST_F(IdModelTest, CoveredGroups) {
30002995
TEST_F(IdModelTest, InvalidLoopPromotion) {
30012996
auto fusion_ptr = std::make_unique<Fusion>();
30022997
auto& fusion = *fusion_ptr;
3003-
FusionGuard fg(fusion_ptr.get());
2998+
FusionGuard fg(&fusion);
30042999

30053000
auto T0 = makeContigConcreteTensor({1, 32, 6});
30063001
fusion.addInput(T0);
@@ -3086,9 +3081,8 @@ TEST_F(IdModelTest, InvalidLoopPromotion) {
30863081
// When a loop group only includes broadcast IDs, the group should not
30873082
// need to be promoted
30883083
TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
3089-
auto fusion_ptr = std::make_unique<Fusion>();
3090-
auto& fusion = *fusion_ptr;
3091-
FusionGuard fg(fusion_ptr.get());
3084+
Fusion fusion;
3085+
FusionGuard fg(&fusion);
30923086

30933087
auto tv0 = makeContigConcreteTensor({-1, 1});
30943088
fusion.addInput(tv0);
@@ -3130,9 +3124,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
31303124

31313125
// Scatter output uses unique mapping schemes
31323126
TEST_F(IdModelTest, ScatterLoopMapping) {
3133-
auto fusion_ptr = std::make_unique<Fusion>();
3134-
auto& fusion = *fusion_ptr;
3135-
FusionGuard fg(fusion_ptr.get());
3127+
Fusion fusion;
3128+
FusionGuard fg(&fusion);
31363129

31373130
auto tv0 = makeContigTensor(1);
31383131
fusion.addInput(tv0);
@@ -3185,8 +3178,7 @@ TEST_F(IdModelTest, ScatterLoopMapping) {
31853178
// required but is a WAR for special ops like
31863179
// PreprocessGroupedMatmulInputSf. See also issue #5391.
31873180
TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
3188-
auto fusion_ptr = std::make_unique<Fusion>();
3189-
Fusion& fusion = *fusion_ptr.get();
3181+
Fusion fusion;
31903182
FusionGuard fg(&fusion);
31913183

31923184
auto tv0 = makeSymbolicTensor(2);
@@ -3219,8 +3211,7 @@ TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
32193211
}
32203212

32213213
TEST_F(IdModelTest, PermissiveResizeGraph) {
3222-
auto fusion_ptr = std::make_unique<Fusion>();
3223-
Fusion& fusion = *fusion_ptr.get();
3214+
Fusion fusion;
32243215
FusionGuard fg(&fusion);
32253216

32263217
auto tv0 = makeConcreteTensor({36});
@@ -3270,8 +3261,7 @@ TEST_F(IdModelTest, PermissiveResizeGraph) {
32703261
// This is the failing segment of the reproducer of
32713262
// https://github.com/NVIDIA/Fuser/issues/5803.
32723263
TEST_F(IdModelTest, ReproIssue5803) {
3273-
auto fusion_ptr = std::make_unique<Fusion>();
3274-
Fusion& fusion = *fusion_ptr.get();
3264+
Fusion fusion;
32753265
FusionGuard fg(&fusion);
32763266

32773267
auto tv2 = makeContigConcreteTensor({4}, DataType::Int);
@@ -3312,8 +3302,7 @@ TEST_F(IdModelTest, ReproIssue5803) {
33123302
// This is a minimal fusion pattern to trigger the loop promotion
33133303
// issue as reported in https://github.com/NVIDIA/Fuser/issues/5803
33143304
TEST_F(IdModelTest, ReproIssue5803Minimal) {
3315-
auto fusion_ptr = std::make_unique<Fusion>();
3316-
Fusion& fusion = *fusion_ptr.get();
3305+
Fusion fusion;
33173306
FusionGuard fg(&fusion);
33183307

33193308
auto tv0 = makeConcreteTensor({4, 8});
@@ -3337,4 +3326,63 @@ TEST_F(IdModelTest, ReproIssue5803Minimal) {
33373326
IdModel id_model(&fusion, true);
33383327
}
33393328

3329+
TEST_F(IdModelTest, SplittingReshape_Mapped) {
3330+
Fusion fusion;
3331+
FusionGuard fg(&fusion);
3332+
3333+
TensorView* in = makeContigConcreteTensor({2 * 2 * 2});
3334+
fusion.addInput(in);
3335+
TensorView* out = reshape(in, {2 * 2 * 2}, {2 * 2, 2});
3336+
fusion.addOutput(out);
3337+
3338+
in->outer_split(0, 2);
3339+
out->outer_split(0, 2);
3340+
3341+
IdModel id_model(&fusion);
3342+
const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
3343+
EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped(
3344+
in->axis(0), out->axis(0)));
3345+
EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
3346+
in->axis(0), out->axis(1)));
3347+
EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
3348+
in->axis(1), out->axis(2)));
3349+
}
3350+
3351+
TEST_F(IdModelTest, SplitingReshape_DifferentExtents_NotMapped) {
3352+
Fusion fusion;
3353+
FusionGuard fg(&fusion);
3354+
3355+
TensorView* in = makeContigConcreteTensor({12});
3356+
fusion.addInput(in);
3357+
TensorView* out = reshape(in, {12}, {6, 2});
3358+
fusion.addOutput(out);
3359+
3360+
in->outer_split(0, 2);
3361+
out->outer_split(0, 3);
3362+
3363+
IdModel id_model(&fusion);
3364+
const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
3365+
EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
3366+
in->axis(0), out->axis(0)));
3367+
}
3368+
3369+
TEST_F(IdModelTest, NonDivisibleSplits_NotMapped) {
3370+
Fusion fusion;
3371+
FusionGuard fg(&fusion);
3372+
3373+
TensorView* in = makeContigConcreteTensor({6});
3374+
fusion.addInput(in);
3375+
TensorView* out = set(in);
3376+
fusion.addOutput(out);
3377+
3378+
in->outer_split(0, 2);
3379+
out->inner_split(0, 4);
3380+
out->outer_split(0, 2);
3381+
3382+
IdModel id_model(&fusion);
3383+
const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
3384+
EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
3385+
in->axis(0), out->axis(0)));
3386+
}
3387+
33403388
} // namespace nvfuser

tests/cpp/test_indexing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,9 +860,9 @@ TEST_F(IndexingTest, Reshape) {
860860
// to provide the extent of the group. However, since everything
861861
// should be deterministic, string match should also work.
862862
return std::string(
863-
"( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 "
863+
"( ( ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) / 25 ) * 25 "
864864
") "
865-
"+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )");
865+
"+ ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) % 25 ) )");
866866
}
867867
default:
868868
return std::string();

0 commit comments

Comments
 (0)