Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <chrono>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>

namespace TensileLite
Expand All @@ -15,9 +17,65 @@ namespace TensileLite
// Set via command line: --timing-instrumentation
inline bool g_timingInstrumentationEnabled = false;

// Simple RAII timer that prints timing on destruction
// Buffer for timing output to avoid per-event stderr writes.
// Accumulates into a single ostringstream and flushes to stderr
// when flush() is called or the buffer exceeds a size threshold.
class TimingBuffer
{
public:
static TimingBuffer& instance()
{
static TimingBuffer buf;
return buf;
}

void append(const char* data, size_t len)
{
std::lock_guard<std::mutex> lock(m_mutex);
m_stream.write(data, len);
m_stream.put('\n');
m_size += len + 1;
if(m_size >= FlushThreshold)
flushLocked();
}

void flush()
{
std::lock_guard<std::mutex> lock(m_mutex);
flushLocked();
}

private:
static constexpr size_t FlushThreshold = 1 << 20; // 1 MB

void flushLocked()
{
if(m_size == 0)
return;
std::cerr << m_stream.str();
m_stream.str(std::string());
m_stream.clear();
m_size = 0;
}

TimingBuffer() = default;
std::mutex m_mutex;
std::ostringstream m_stream;
size_t m_size = 0;
};

inline void flushTimingBuffer()
{
TimingBuffer::instance().flush();
}

// Simple RAII timer that records timing on destruction
// Output format: TIMING:<category>:<duration_ms>
// This format is easily parseable by post-processing scripts
//
// Timing records are buffered in memory and flushed periodically
// (every ~1 MB) or via flushTimingBuffer() to avoid per-event
// stderr syscall overhead.
class ScopedTimer
{
public:
Expand All @@ -35,7 +93,11 @@ namespace TensileLite
{
auto end = clock::now();
auto duration = std::chrono::duration<double, std::milli>(end - m_start);
std::cerr << "TIMING:" << m_category << ":" << duration.count() << std::endl;
char buf[256];
int n = snprintf(buf, sizeof(buf), "TIMING:%s:%.6f",
m_category.c_str(), duration.count());
if(n > 0)
TimingBuffer::instance().append(buf, n);
}
}

Expand All @@ -48,7 +110,7 @@ namespace TensileLite
}

private:
std::string m_category;
std::string m_category;
std::chrono::time_point<clock> m_start;
};

Expand All @@ -57,7 +119,11 @@ namespace TensileLite
{
if(g_timingInstrumentationEnabled)
{
std::cerr << "TIMING:" << category << ":" << ms << std::endl;
char buf[256];
int n = snprintf(buf, sizeof(buf), "TIMING:%s:%.6f",
category.c_str(), ms);
if(n > 0)
TimingBuffer::instance().append(buf, n);
}
}

Expand All @@ -67,9 +133,12 @@ namespace TensileLite
{
if(g_timingInstrumentationEnabled)
{
std::cerr << "TIMING_CONTEXT:M=" << M << ",N=" << N << ",K=" << K
<< ",batch=" << batchCount << ",typeA=" << typeA << ",typeD=" << typeD
<< std::endl;
char buf[256];
int n = snprintf(buf, sizeof(buf),
"TIMING_CONTEXT:M=%zu,N=%zu,K=%zu,batch=%zu,typeA=%s,typeD=%s",
M, N, K, batchCount, typeA.c_str(), typeD.c_str());
if(n > 0)
TimingBuffer::instance().append(buf, n);
}
}

Expand All @@ -80,11 +149,14 @@ namespace TensileLite
{
if(g_timingInstrumentationEnabled)
{
std::cerr << "TIMING_CONTEXT_GROUPED:index=" << index
<< ",total=" << totalGemms
<< ",M=" << M << ",N=" << N << ",K=" << K
<< ",batch=" << batchCount << ",typeA=" << typeA << ",typeD=" << typeD
<< std::endl;
char buf[256];
int n = snprintf(buf, sizeof(buf),
"TIMING_CONTEXT_GROUPED:index=%zu,total=%zu,"
"M=%zu,N=%zu,K=%zu,batch=%zu,typeA=%s,typeD=%s",
index, totalGemms, M, N, K, batchCount,
typeA.c_str(), typeD.c_str());
if(n > 0)
TimingBuffer::instance().append(buf, n);
}
}

Expand Down
4 changes: 4 additions & 0 deletions projects/hipblaslt/tensilelite/client/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,7 @@ int main(int argc, const char* argv[])

if(exitOnError && listeners.error() > 0)
{
flushTimingBuffer();
// error range in shell is [0-255]
return std::min(listeners.error(), 255);
}
Expand All @@ -956,6 +957,9 @@ int main(int argc, const char* argv[])
listeners.finalizeReport();
}

// Flush all buffered timing records to stderr
flushTimingBuffer();

// error range in shell is [0-255]
return std::min(listeners.error(), 255);
}
Loading