Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <unordered_set>

#include <rocRoller/CodeGen/Instruction.hpp>
#include <rocRoller/Context.hpp>
#include <rocRoller/GPUArchitecture/GPUInstructionInfo.hpp>
Expand Down Expand Up @@ -40,8 +42,9 @@ namespace rocRoller
*
* @param labelState
*/
void assertSafeToBranchTo(const WaitcntState& labelState,
std::string const& label) const;
void assertSafeToBranchTo(const WaitcntState& branchState,
std::string const& label,
bool strict) const;

private:
// These members are duplicates of the waitcntobserver members, except we're storing a
Expand Down Expand Up @@ -100,6 +103,24 @@ namespace rocRoller
// This member tracks, for every label, what the waitcnt state is everywhere a branch instruction targets that label.
std::unordered_map<std::string, std::vector<WaitcntState>> m_branchStates;

// Live snapshot of observer state at forward branch points, used to
// restore or merge state when reaching the target label.
struct LiveBranchState
{
WaitCntQueues instructionQueues;
WaitQueueMap<bool> needsWaitZero;
WaitQueueMap<GPUWaitQueueType> typeInQueue;
};
std::unordered_map<std::string, LiveBranchState> m_liveBranchStates;

// Set after an unconditional s_branch; cleared at the next label.
bool m_afterUnconditionalBranch = false;

// Labels that were targeted by a branch before the label was
// encountered in the linear traversal (forward branches, i.e.
// conditionals). Backward branches (loops) are not in this set.
std::unordered_set<std::string> m_forwardBranchLabels;

/**
* This function updates the given wait queue by applying the given waitcnt.
**/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include <sstream>
#include <algorithm>

#include <rocRoller/Context.hpp>
#include <rocRoller/GPUArchitecture/GPUInstructionInfo.hpp>
Expand Down Expand Up @@ -100,9 +100,8 @@ namespace rocRoller
{
if(waitCnt >= 0 && m_instructionQueues[queue].size() > (size_t)waitCnt)
{
if(!(m_needsWaitZero[queue]
&& waitCnt
> 0)) //Do not partially clear the queue if a waitcnt zero is needed.
// Do not partially clear the queue if a waitcnt zero is needed.
if(!(m_needsWaitZero[queue] && waitCnt > 0))
{
m_instructionQueues[queue].erase(m_instructionQueues[queue].begin(),
m_instructionQueues[queue].begin()
Expand All @@ -119,30 +118,154 @@ namespace rocRoller

inline void WaitcntObserver::addLabelState(std::string const& label)
{
if(m_afterUnconditionalBranch)
{
// After an unconditional branch the linear-traversal state is stale
// Restore the live state from the branch that targets this label so
// that subsequent instructions get correct wait-count computation.
if(m_liveBranchStates.contains(label))
{
auto& bs = m_liveBranchStates.at(label);
m_instructionQueues = bs.instructionQueues;
m_needsWaitZero = bs.needsWaitZero;
m_typeInQueue = bs.typeInQueue;
}
m_afterUnconditionalBranch = false;
}
else
{
// At join points (e.g. ConditionalBottom) the label is reached
// both by fall-through and by a forward branch whose path may
// have outstanding operations the fall-through path doesn't.
// Both paths diverged from the same pre-conditional state.
//
// We perform a per-position union of entries (back-aligned)
// so that every register that might be in-flight on either
// path is present in the merged queue at a position that
// yields a sufficient waitcnt. Simply taking the
// larger queue may lose track of registers that only
// appear in the shorter queue.
auto it = m_liveBranchStates.find(label);
if(m_liveBranchStates.contains(label))
{
auto& bs = m_liveBranchStates.at(label);
for(auto& [queue, branchEntries] : bs.instructionQueues)
{
auto& labelEntries = m_instructionQueues[queue];

// When the two queues have different sizes, we must
// align entries by their distance from the end so
// the computed waitcnt is ≤ the correct value on
// both paths. We pad the shorter queue from the
// front with empty entries.
size_t mergedSize = std::max(labelEntries.size(), branchEntries.size());

if(labelEntries.size() < mergedSize)
{
size_t padCount = mergedSize - labelEntries.size();
labelEntries.insert(
labelEntries.begin(), padCount, WaitQueueRegisters{});
}

// Offset so branch position i maps to merged
// position i + branchOffset (back-aligned).
size_t branchOffset = mergedSize - branchEntries.size();

// Per-position union: merge branch registers into
// the label entries at the back-aligned position.
for(size_t i = 0; i < branchEntries.size(); i++)
{
auto& mergedEntry = labelEntries[i + branchOffset];

for(auto const& branchReg : branchEntries[i])
{
if(!branchReg)
continue;

// Skip if this register is already tracked
// at this position.
bool alreadyPresent = std::any_of(
mergedEntry.begin(), mergedEntry.end(), [&](auto const& r) {
return r && r->intersects(branchReg);
});

if(alreadyPresent)
continue;

// Find an empty slot in the entry.
auto slotIt = std::find_if(mergedEntry.begin(),
mergedEntry.end(),
[](auto const& r) { return !r; });

if(slotIt != mergedEntry.end())
*slotIt = branchReg;
else // No room, fall back to wait-zero
m_needsWaitZero[queue] = true;
}
}

m_needsWaitZero[queue] = m_needsWaitZero[queue] || bs.needsWaitZero[queue];
if(m_typeInQueue[queue] == GPUWaitQueueType::None)
{
m_typeInQueue[queue] = bs.typeInQueue[queue];
}
else if(bs.typeInQueue[queue] != GPUWaitQueueType::None
&& bs.typeInQueue[queue] != m_typeInQueue[queue])
{
m_needsWaitZero[queue] = true;
}
}
}
}

m_labelStates[label]
= WaitcntState(m_needsWaitZero, m_typeInQueue, m_instructionQueues);
}

inline void WaitcntObserver::addBranchState(std::string const& label)
{
if(m_branchStates.find(label) == m_branchStates.end())
if(!m_branchStates.contains(label))
{
m_branchStates[label] = {};
}

m_branchStates[label].emplace_back(
WaitcntState(m_needsWaitZero, m_typeInQueue, m_instructionQueues));

// If the label hasn't been encountered yet, this is a forward
// branch (conditional pattern). Backward branches (loops) target
// labels that have already been recorded.
bool isForward = !m_labelStates.contains(label);
if(isForward)
{
m_forwardBranchLabels.insert(label);

// Save the native state so we can restore/merge it at the label.
if(!m_liveBranchStates.contains(label))
{
m_liveBranchStates[label]
= {m_instructionQueues, m_needsWaitZero, m_typeInQueue};
}
}
}

inline void WaitcntObserver::assertLabelConsistency()
{
for(auto label_state : m_labelStates)
for(auto const& label_state : m_labelStates)
{
if(m_branchStates.find(label_state.first) != m_branchStates.end())
if(m_branchStates.contains(label_state.first))
{
for(auto branch_state : m_branchStates[label_state.first])
// Forward branches (conditionals): relaxed check, label
// state may be a superset of branch state.
// Backward branches (loops): strict check, states must
// match exactly.
bool isForward = m_forwardBranchLabels.contains(label_state.first);
bool strict = !isForward;

for(auto const& branch_state : m_branchStates.at(label_state.first))
{
label_state.second.assertSafeToBranchTo(branch_state, label_state.first);
label_state.second.assertSafeToBranchTo(
branch_state, label_state.first, strict);
}
}
}
Expand Down
104 changes: 75 additions & 29 deletions shared/rocroller/lib/source/Observers/WaitcntObserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ namespace rocRoller
}
}

void WaitcntState::assertSafeToBranchTo(const WaitcntState& labelState,
std::string const& label) const
void WaitcntState::assertSafeToBranchTo(const WaitcntState& branchState,
std::string const& label,
bool strict) const
{
if(*this == labelState)
if(*this == branchState)
return;

bool fail = false;
Expand All @@ -62,46 +63,84 @@ namespace rocRoller
std::string msg
= "Branching to label '" + label + "' with a different waitcnt state.\n";

// If queues do not have needsWaitZero set, and none of the instructions
// contain a destination, it is still safe to branch, even if the
// queues do not match exactly.
for(auto const& [queue, instructions] : m_instructionQueues)
if(strict)
{
if(m_needsWaitZero.at(queue) || labelState.m_needsWaitZero.at(queue))
// Strict mode (loops / backward branches): require both sides
// to have empty queues if they don't match exactly.
for(auto const& [queue, instructions] : m_instructionQueues)
{
fail = true;
msg += concatenate(" Wait zero: ",
ShowValue(m_needsWaitZero.at(queue)),
ShowValue(labelState.m_needsWaitZero.at(queue)),
ShowValue(queue),
"\n");

if(!longErrMsg)
AssertFatal(!fail, msg);
}

for(auto const& instruction : instructions)
{
if(!instruction.empty())
if(m_needsWaitZero.at(queue) || branchState.m_needsWaitZero.at(queue))
{
fail = true;
msg += concatenate(" Extra register at label: ",
ShowValue(instruction),
msg += concatenate(" Wait zero: ",
ShowValue(m_needsWaitZero.at(queue)),
ShowValue(branchState.m_needsWaitZero.at(queue)),
ShowValue(queue),
"\n");

if(!longErrMsg)
AssertFatal(!fail, msg);
}
}

for(auto const& instruction : labelState.m_instructionQueues.at(queue))
for(auto const& instruction : instructions)
{
if(!instruction.empty())
{
fail = true;
msg += concatenate(" Extra register at label: ",
ShowValue(instruction),
ShowValue(queue),
"\n");

if(!longErrMsg)
AssertFatal(!fail, msg);
}
}

for(auto const& instruction : branchState.m_instructionQueues.at(queue))
{
if(!instruction.empty())
{
fail = true;
msg += concatenate(" Extra register at branch: ",
ShowValue(instruction),
ShowValue(queue),
"\n");

if(!longErrMsg)
AssertFatal(!fail, msg);
}
}
}
}
else
{
// Relaxed mode (conditionals / forward branches): the label
// state may be more conservative (superset) than the branch
// state. Only fail if the branch has entries or flags that the
// label doesn't know about.
for(auto const& [queue, labelInstructions] : m_instructionQueues)
{
if(!instruction.empty())
auto const& branchInstructions = branchState.m_instructionQueues.at(queue);

if(branchState.m_needsWaitZero.at(queue) && !m_needsWaitZero.at(queue))
{
fail = true;
msg += concatenate(
" Branch needs waitZero but label does not: ", ShowValue(queue), "\n");

if(!longErrMsg)
AssertFatal(!fail, msg);
}

if(branchInstructions.size() > labelInstructions.size())
{
fail = true;
msg += concatenate(" Extra register at branch: ",
ShowValue(instruction),
msg += concatenate(" Branch queue larger than label queue: ",
" branch=",
branchInstructions.size(),
" label=",
labelInstructions.size(),
ShowValue(queue),
"\n");

Expand Down Expand Up @@ -168,6 +207,13 @@ namespace rocRoller
"Branch without a label\n",
ShowValue(inst.toString(LogLevel::Debug)));
addBranchState(inst.getSrcs()[0]->toString());

// After an unconditional branch (e.g. end of the if-body),
// flag so that addLabelState resets the observer to the
// saved branch state instead of continuing with the
// state accumulated by the previous block.
if(inst.getOpCode() == "s_branch")
m_afterUnconditionalBranch = true;
}
else if(inst.isLabel())
{
Expand Down
Loading