Encapsulate communication lowering variables into a struct#6037
Conversation
Greptile Summary: This PR encapsulates the four communication-lowering variables (`backend`, `my_device_idx`, `host_loop_index`, and `reduction_op`) into a single `CommunicationLoweringParams` struct that is passed to each `lowerTo*` function.
The refactoring is clean and functionally equivalent. The one pre-existing concern (noted in the previous review thread) is that the new `op_type` lambda silently returns `std::nullopt` for unexpected expressions, dropping diagnostic context. Confidence Score: 4/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["convertSingleOpToCommunication(e, my_device_idx, host_loop_index, backend)"]
    B["op_type(e)\n→ std::optional<BinaryOpType>"]
    C["Build CommunicationLoweringParams\n{backend, my_device_idx,\n host_loop_index, reduction_op}"]
    D{communication_info->type}
    A --> B
    B --> C
    C --> D
    D --> E["lowerToScatter(params)"]
    D --> F["lowerToGather(params)"]
    D --> G["lowerToAllgather(params)\nuses my_device_idx"]
    D --> H["lowerToBroadcast(params)\nroot = host_loop_index"]
    D --> I["lowerToSendRecv(params)"]
    D --> J["lowerToReduce(params)\nchecks reduction_op"]
    D --> K["lowerToAllreduce(params)\nchecks reduction_op, my_device_idx"]
    D --> L["lowerToReduceScatter(params)\nchecks reduction_op, my_device_idx"]
    D --> M["lowerToAllToAll(params)"]
    D --> N["lowerToCollectivePermute(params)\nchecks host_loop_index != nullptr,\nuses my_device_idx"]
```
Last reviewed commit: "Merge branch 'main' ..."
```diff
 auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
   if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
     return reduce->getReductionOpType();
   }
-  NVF_ERROR(e != nullptr);
-  if (e->isA<SqueezeOp>()) {
+  if (e != nullptr && e->isA<SqueezeOp>()) {
     return BinaryOpType::Add;
   }
-  NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);
+  return std::nullopt;
 };
```
**Loss of diagnostic information in `op_type` lambda**

The old lambda threw an explicit, informative error when the expression was neither a `ReductionOp` nor a `SqueezeOp`:

```cpp
NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);
```

The new lambda silently returns `std::nullopt` for any such expression. The error is now deferred to the individual `lowerToReduce`/`lowerToAllreduce`/`lowerToReduceScatter` functions, but those error messages ("Reduce communication requires reduction_op in params", etc.) don't include the actual expression `e` that caused the problem.

As the comment in base.h above `valueOrError` itself notes: "If you prefer a better error message, have the caller check has_value() instead so it can provide more context." The current error messages do check `has_value()` but drop the expression context. Consider propagating `e` into those error messages, or restoring the `NVF_THROW` in the lambda for the case when `e` is non-null and of an unrecognized type:
```cpp
auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
  if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
    return reduce->getReductionOpType();
  }
  if (e != nullptr && e->isA<SqueezeOp>()) {
    return BinaryOpType::Add;
  }
  // Optionally retain the informative message for unexpected non-null expressions:
  // NVF_ERROR(e == nullptr, "Expected a ReductionOp or a SqueezeOp, but got: ", e);
  return std::nullopt;
};
```
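As a standalone, compilable analogue of the suggested behavior (the IR classes below are stand-ins, not nvFuser's real `Expr`/`ReductionOp`/`SqueezeOp`, and `isA<T>()` is replaced with a plain `dynamic_cast`):

```cpp
#include <optional>

// Stand-in IR node hierarchy, illustrative only.
enum class BinaryOpType { Add, Mul };

struct Expr {
  virtual ~Expr() = default;  // polymorphic base so dynamic_cast works
};
struct ReductionOp : Expr {
  explicit ReductionOp(BinaryOpType t) : type(t) {}
  BinaryOpType getReductionOpType() const { return type; }
  BinaryOpType type;
};
struct SqueezeOp : Expr {};

// Mirrors the lambda: reduction -> its op type, squeeze -> Add,
// anything else (including nullptr) -> nullopt.
std::optional<BinaryOpType> opType(Expr* e) {
  if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
    return reduce->getReductionOpType();
  }
  if (e != nullptr && dynamic_cast<SqueezeOp*>(e) != nullptr) {
    return BinaryOpType::Add;
  }
  return std::nullopt;
}
```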
!test
```cpp
  return std::nullopt;
};

CommunicationLoweringParams params{
```
Looks good. You may want to consider the builder design pattern used by at::TensorOptions: https://docs.pytorch.org/cppdocs/notes/tensor_creation.html#configuring-properties-of-the-tensor. It's a more fluent way to build the parameter without having to mechanically create a struct and overwrite fields.
For now, I prefer the struct. All the fields need to be overwritten (even if I set the common values as defaults). If a case arises where some fields are always going to be these default values, then I can consider switching the API.
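For reference, the builder style being contrasted here could look roughly like this. This is a sketch only: the field types and setter names are illustrative assumptions, not the actual nvFuser `CommunicationLoweringParams` API.

```cpp
#include <cstdint>
#include <optional>

enum class BinaryOpType { Add, Mul };  // illustrative stand-in

struct CommunicationLoweringParams {
  int64_t my_device_idx = -1;
  int64_t host_loop_index = -1;
  std::optional<BinaryOpType> reduction_op;

  // TensorOptions-style fluent setters: each returns *this, so calls chain.
  CommunicationLoweringParams& deviceIdx(int64_t d) {
    my_device_idx = d;
    return *this;
  }
  CommunicationLoweringParams& hostLoopIndex(int64_t i) {
    host_loop_index = i;
    return *this;
  }
  CommunicationLoweringParams& reductionOp(BinaryOpType op) {
    reduction_op = op;
    return *this;
  }
};
```

With the builder, call sites read like `CommunicationLoweringParams().deviceIdx(0).reductionOp(BinaryOpType::Add)`; the plain-struct style preferred in the reply is instead direct aggregate initialization of all fields.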
Force-pushed from a3cb641 to 2a66774
!test
This allows me to easily pass in new variables without changing the individual signatures of all the `lowerTo*` methods.
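That rationale can be sketched as follows, with stand-in types rather than the real nvFuser declarations: because every `lowerTo*` helper takes the same single struct, adding a field (the hypothetical `profile` flag below) touches only the struct and the sites that read it, never the helper signatures.

```cpp
#include <cstdint>
#include <optional>

enum class BinaryOpType { Add };  // illustrative stand-in

struct CommunicationLoweringParams {
  int64_t my_device_idx = -1;
  int64_t host_loop_index = -1;
  std::optional<BinaryOpType> reduction_op;
  // Hypothetical new variable: adding it here changes no lowerTo* signature.
  bool profile = false;
};

// Mirrors the review's note that reduce-style lowerings require reduction_op.
bool lowerToReduce(const CommunicationLoweringParams& params) {
  return params.reduction_op.has_value();
}

bool lowerToBroadcast(const CommunicationLoweringParams& params) {
  return params.host_loop_index >= 0;  // root derived from host_loop_index
}
```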