Skip to content

Commit f65a3f7

Browse files
committed
Make MLIR Pass Timing output configurable through injection
This makes it possible for the client to control where the pass timings will be printed. Differential Revision: https://reviews.llvm.org/D78891
1 parent cd84bfb commit f65a3f7

File tree

3 files changed

+70
-28
lines changed

3 files changed

+70
-28
lines changed

mlir/include/mlir/Pass/PassManager.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,38 @@ class PassManager : public OpPassManager {
229229
//===--------------------------------------------------------------------===//
230230
// Pass Timing
231231

232+
/// A configuration struct provided to the pass timing feature.
233+
class PassTimingConfig {
234+
public:
235+
using PrintCallbackFn = function_ref<void(raw_ostream &)>;
236+
237+
/// Initialize the configuration.
238+
/// * 'displayMode' switch between list or pipeline display (see the
239+
/// `PassDisplayMode` enum documentation).
240+
explicit PassTimingConfig(
241+
PassDisplayMode displayMode = PassDisplayMode::Pipeline)
242+
: displayMode(displayMode) {}
243+
244+
virtual ~PassTimingConfig();
245+
246+
/// A hook that may be overridden by a derived config to control the
247+
/// printing. The callback is supplied by the framework and the config is
248+
/// responsible to call it back with a stream for the output.
249+
virtual void printTiming(PrintCallbackFn printCallback);
250+
251+
/// Return the `PassDisplayMode` this config was created with.
252+
PassDisplayMode getDisplayMode() { return displayMode; }
253+
254+
private:
255+
PassDisplayMode displayMode;
256+
};
257+
232258
/// Add an instrumentation to time the execution of passes and the computation
233259
/// of analyses.
234260
/// Note: Timing should be enabled after all other instrumentations to avoid
235261
/// any potential "ghost" timing from other instrumentations being
236262
/// unintentionally included in the timing results.
237-
void enableTiming(PassDisplayMode displayMode = PassDisplayMode::Pipeline);
263+
void enableTiming(std::unique_ptr<PassTimingConfig> config = nullptr);
238264

239265
/// Prompts the pass manager to print the statistics collected for each of the
240266
/// held passes after each call to 'run'.

mlir/lib/Pass/PassManagerOptions.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
141141
/// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
142142
void PassManagerOptions::addTimingInstrumentation(PassManager &pm) {
143143
if (passTiming)
144-
pm.enableTiming(passTimingDisplayMode);
144+
pm.enableTiming(
145+
std::make_unique<PassManager::PassTimingConfig>(passTimingDisplayMode));
145146
}
146147

147148
void mlir::registerPassManagerCLOptions() {

mlir/lib/Pass/PassTiming.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ struct Timer {
160160
};
161161

162162
struct PassTiming : public PassInstrumentation {
163-
PassTiming(PassDisplayMode displayMode) : displayMode(displayMode) {}
163+
PassTiming(std::unique_ptr<PassManager::PassTimingConfig> config)
164+
: config(std::move(config)) {}
164165
~PassTiming() override { print(); }
165166

166167
/// Setup the instrumentation hooks.
@@ -231,8 +232,8 @@ struct PassTiming : public PassInstrumentation {
231232
/// A stack of the currently active pass timers per thread.
232233
DenseMap<uint64_t, SmallVector<Timer *, 4>> activeThreadTimers;
233234

234-
/// The display mode to use when printing the timing results.
235-
PassDisplayMode displayMode;
235+
/// The configuration object to use when printing the timing results.
236+
std::unique_ptr<PassManager::PassTimingConfig> config;
236237

237238
/// A mapping of pipeline timers that need to be merged into the parent
238239
/// collection. The timers are mapped to the parent info to merge into.
@@ -353,28 +354,37 @@ void PassTiming::print() {
353354
return;
354355

355356
assert(rootTimers.size() == 1 && "expected one remaining root timer");
356-
auto &rootTimer = rootTimers.begin()->second;
357-
auto os = llvm::CreateInfoOutputFile();
358-
359-
// Print the timer header.
360-
TimeRecord totalTime = rootTimer->getTotalTime();
361-
printTimerHeader(*os, totalTime);
362-
363-
// Defer to a specialized printer for each display mode.
364-
switch (displayMode) {
365-
case PassDisplayMode::List:
366-
printResultsAsList(*os, rootTimer.get(), totalTime);
367-
break;
368-
case PassDisplayMode::Pipeline:
369-
printResultsAsPipeline(*os, rootTimer.get(), totalTime);
370-
break;
371-
}
372-
printTimeEntry(*os, 0, "Total", totalTime, totalTime);
373-
os->flush();
374357

375-
// Reset root timers.
376-
rootTimers.clear();
377-
activeThreadTimers.clear();
358+
auto printCallback = [&](raw_ostream &os) {
359+
auto &rootTimer = rootTimers.begin()->second;
360+
// Print the timer header.
361+
TimeRecord totalTime = rootTimer->getTotalTime();
362+
printTimerHeader(os, totalTime);
363+
// Defer to a specialized printer for each display mode.
364+
switch (config->getDisplayMode()) {
365+
case PassDisplayMode::List:
366+
printResultsAsList(os, rootTimer.get(), totalTime);
367+
break;
368+
case PassDisplayMode::Pipeline:
369+
printResultsAsPipeline(os, rootTimer.get(), totalTime);
370+
break;
371+
}
372+
printTimeEntry(os, 0, "Total", totalTime, totalTime);
373+
os.flush();
374+
375+
// Reset root timers.
376+
rootTimers.clear();
377+
activeThreadTimers.clear();
378+
};
379+
380+
config->printTiming(printCallback);
381+
}
382+
383+
// The default implementation for printTiming uses
384+
// `llvm::CreateInfoOutputFile()` as stream, it can be overridden by clients
385+
// to customize the output.
386+
void PassManager::PassTimingConfig::printTiming(PrintCallbackFn printCallback) {
387+
printCallback(*llvm::CreateInfoOutputFile());
378388
}
379389

380390
/// Print the timing result in list mode.
@@ -449,16 +459,21 @@ void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root,
449459
printTimer(0, topLevelTimer.second.get());
450460
}
451461

462+
// Out-of-line as key function.
463+
PassManager::PassTimingConfig::~PassTimingConfig() {}
464+
452465
//===----------------------------------------------------------------------===//
453466
// PassManager
454467
//===----------------------------------------------------------------------===//
455468

456469
/// Add an instrumentation to time the execution of passes and the computation
457470
/// of analyses.
458-
void PassManager::enableTiming(PassDisplayMode displayMode) {
471+
void PassManager::enableTiming(std::unique_ptr<PassTimingConfig> config) {
459472
// Check if pass timing is already enabled.
460473
if (passTiming)
461474
return;
462-
addInstrumentation(std::make_unique<PassTiming>(displayMode));
475+
if (!config)
476+
config = std::make_unique<PassManager::PassTimingConfig>();
477+
addInstrumentation(std::make_unique<PassTiming>(std::move(config)));
463478
passTiming = true;
464479
}

0 commit comments

Comments
 (0)