Skip to content

Commit e5960dc

Browse files
committed
[mlir][python] auto-locs
1 parent 3212704 commit e5960dc

File tree

8 files changed

+323
-21
lines changed

8 files changed

+323
-21
lines changed

mlir/lib/Bindings/Python/Globals.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

1212
#include <optional>
13+
#include <regex>
1314
#include <string>
15+
#include <unordered_set>
1416
#include <vector>
1517

1618
#include "NanobindUtils.h"
1719
#include "mlir-c/IR.h"
1820
#include "mlir/CAPI/Support.h"
1921
#include "llvm/ADT/DenseMap.h"
22+
#include "llvm/ADT/StringExtras.h"
2023
#include "llvm/ADT/StringRef.h"
2124
#include "llvm/ADT/StringSet.h"
25+
#include "llvm/Support/Regex.h"
2226

2327
namespace mlir {
2428
namespace python {
@@ -114,6 +118,37 @@ class PyGlobals {
114118
std::optional<nanobind::object>
115119
lookupOperationClass(llvm::StringRef operationName);
116120

121+
class TracebackLoc {
122+
public:
123+
bool locTracebacksEnabled() const;
124+
125+
void setLocTracebacksEnabled(bool value);
126+
127+
size_t locTracebackFramesLimit() const;
128+
129+
void setLocTracebackFramesLimit(size_t value);
130+
131+
void registerTracebackFileInclusion(const std::string &file);
132+
133+
void registerTracebackFileExclusion(const std::string &file);
134+
135+
bool isUserTracebackFilename(llvm::StringRef file);
136+
137+
private:
138+
nanobind::ft_mutex mutex;
139+
bool locTracebackEnabled_ = false;
140+
size_t locTracebackFramesLimit_ = 10;
141+
std::unordered_set<std::string> userTracebackIncludeFiles;
142+
std::unordered_set<std::string> userTracebackExcludeFiles;
143+
std::regex userTracebackIncludeRegex;
144+
bool rebuildUserTracebackIncludeRegex = false;
145+
std::regex userTracebackExcludeRegex;
146+
bool rebuildUserTracebackExcludeRegex = false;
147+
llvm::StringMap<bool> isUserTracebackFilenameCache;
148+
};
149+
150+
TracebackLoc &getTracebackLoc() { return tracebackLoc; }
151+
117152
private:
118153
static PyGlobals *instance;
119154

@@ -134,6 +169,8 @@ class PyGlobals {
134169
/// Set of dialect namespaces that we have attempted to import implementation
135170
/// modules for.
136171
llvm::StringSet<> loadedDialectModules;
172+
173+
TracebackLoc tracebackLoc;
137174
};
138175

139176
} // namespace python

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 121 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
#include "nanobind/nanobind.h"
2121
#include "llvm/ADT/ArrayRef.h"
2222
#include "llvm/ADT/SmallVector.h"
23-
#include "llvm/Support/raw_ostream.h"
2423

2524
#include <optional>
26-
#include <system_error>
27-
#include <utility>
2825

2926
namespace nb = nanobind;
3027
using namespace nb::literals;
@@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
15231520
llvm::ArrayRef<MlirValue> operands,
15241521
std::optional<nb::dict> attributes,
15251522
std::optional<std::vector<PyBlock *>> successors,
1526-
int regions, DefaultingPyLocation ___location,
1523+
int regions, PyLocation ___location,
15271524
const nb::object &maybeIp, bool inferType) {
15281525
llvm::SmallVector<MlirType, 4> mlirResults;
15291526
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
16271624
if (!operation.ptr)
16281625
throw nb::value_error("Operation creation failed");
16291626
PyOperationRef created =
1630-
PyOperation::createDetached(___location->getContext(), operation);
1627+
PyOperation::createDetached(___location.getContext(), operation);
16311628
maybeInsertOperation(created, maybeIp);
16321629

16331630
return created.getObject();
@@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
19371934
std::optional<nb::list> resultTypeList, nb::list operandList,
19381935
std::optional<nb::dict> attributes,
19391936
std::optional<std::vector<PyBlock *>> successors,
1940-
std::optional<int> regions, DefaultingPyLocation ___location,
1937+
std::optional<int> regions, PyLocation ___location,
19411938
const nb::object &maybeIp) {
1942-
PyMlirContextRef context = ___location->getContext();
1939+
PyMlirContextRef context = ___location.getContext();
19431940

19441941
// Class level operation construction metadata.
19451942
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2786,111 @@ class PyOpAttributeMap {
27892786
PyOperationRef operation;
27902787
};
27912788

2789+
// copied/borrow from
2790+
// https://github.com/python/pythoncapi-compat/blob/b541b98df1e3e5aabb5def27422a75c876f5a88a/pythoncapi_compat.h#L222
2791+
// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1
2792+
#if PY_VERSION_HEX < 0x030b00b1 && !defined(PYPY_VERSION)
2793+
int PyFrame_GetLasti(PyFrameObject *frame) {
2794+
#if PY_VERSION_HEX >= 0x030a00a7
2795+
// bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset,
2796+
// not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes)
2797+
// instructions.
2798+
if (frame->f_lasti < 0) {
2799+
return -1;
2800+
}
2801+
return frame->f_lasti * 2;
2802+
#else
2803+
return frame->f_lasti;
2804+
#endif
2805+
}
2806+
#endif
2807+
2808+
constexpr size_t kMaxFrames = 512;
2809+
2810+
MlirLocation tracebackToLocation(MlirContext ctx) {
2811+
size_t framesLimit =
2812+
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
2813+
// We use a thread_local here mostly to avoid requiring a large amount of
2814+
// space.
2815+
thread_local std::array<MlirLocation, kMaxFrames> frames;
2816+
size_t count = 0;
2817+
2818+
assert(PyGILState_Check());
2819+
2820+
PyThreadState *tstate = PyThreadState_GET();
2821+
2822+
PyFrameObject *next;
2823+
for (PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2824+
pyFrame != nullptr && count < framesLimit;
2825+
next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2826+
PyCodeObject *code = PyFrame_GetCode(pyFrame);
2827+
auto fileNameStr =
2828+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2829+
llvm::StringRef fileName(fileNameStr);
2830+
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2831+
continue;
2832+
2833+
#if PY_VERSION_HEX < 0x030b00f0
2834+
std::string name =
2835+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2836+
llvm::StringRef funcName(name);
2837+
int startLine = PyFrame_GetLineNumber(pyFrame);
2838+
MlirLocation loc =
2839+
mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2840+
#else
2841+
// co_qualname and PyCode_Addr2Location added in py3.11
2842+
std::string name =
2843+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2844+
llvm::StringRef funcName(name);
2845+
int startLine, startCol, endLine, endCol;
2846+
int lasti = PyFrame_GetLasti(pyFrame);
2847+
if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2848+
&endCol)) {
2849+
throw nb::python_error();
2850+
}
2851+
MlirLocation loc = mlirLocationFileLineColRangeGet(
2852+
ctx, wrap(fileName), startLine, startCol, endLine, endCol);
2853+
#endif
2854+
2855+
frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
2856+
++count;
2857+
if (count > framesLimit)
2858+
break;
2859+
}
2860+
2861+
if (count == 0)
2862+
return mlirLocationUnknownGet(ctx);
2863+
if (count == 1)
2864+
return frames[0];
2865+
2866+
MlirLocation callee = frames[0];
2867+
MlirLocation caller = frames[count - 1];
2868+
for (int i = count - 2; i >= 1; i--)
2869+
caller = mlirLocationCallSiteGet(frames[i], caller);
2870+
2871+
return mlirLocationCallSiteGet(callee, caller);
2872+
}
2873+
2874+
PyLocation
2875+
maybeGetTracebackLocation(const std::optional<PyLocation> &___location) {
2876+
MlirLocation mlirLoc;
2877+
MlirContext mlirCtx;
2878+
if (!___location.has_value() &&
2879+
PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) {
2880+
mlirCtx = DefaultingPyMlirContext::resolve().get();
2881+
mlirLoc = tracebackToLocation(mlirCtx);
2882+
} else if (!___location.has_value()) {
2883+
mlirLoc = DefaultingPyLocation::resolve();
2884+
mlirCtx = mlirLocationGetContext(mlirLoc);
2885+
} else {
2886+
mlirLoc = *___location;
2887+
mlirCtx = mlirLocationGetContext(mlirLoc);
2888+
}
2889+
assert(!mlirLocationIsNull(mlirLoc) && "expected non-null mlirLoc");
2890+
PyMlirContextRef ctx = PyMlirContext::forContext(mlirCtx);
2891+
return {ctx, mlirLoc};
2892+
}
2893+
27922894
} // namespace
27932895

27942896
//------------------------------------------------------------------------------
@@ -3240,8 +3342,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32403342
kModuleParseDocstring)
32413343
.def_static(
32423344
"create",
3243-
[](DefaultingPyLocation loc) {
3244-
MlirModule module = mlirModuleCreateEmpty(loc);
3345+
[](std::optional<PyLocation> loc) {
3346+
PyLocation pyLoc = maybeGetTracebackLocation(loc);
3347+
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
32453348
return PyModule::forModule(module).releaseObject();
32463349
},
32473350
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3454,7 +3557,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34543557
std::optional<std::vector<PyValue *>> operands,
34553558
std::optional<nb::dict> attributes,
34563559
std::optional<std::vector<PyBlock *>> successors, int regions,
3457-
DefaultingPyLocation ___location, const nb::object &maybeIp,
3560+
std::optional<PyLocation> ___location, const nb::object &maybeIp,
34583561
bool inferType) {
34593562
// Unpack/validate operands.
34603563
llvm::SmallVector<MlirValue, 4> mlirOperands;
@@ -3467,8 +3570,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34673570
}
34683571
}
34693572

3573+
PyLocation pyLoc = maybeGetTracebackLocation(___location);
34703574
return PyOperation::create(name, results, mlirOperands, attributes,
3471-
successors, regions, ___location, maybeIp,
3575+
successors, regions, pyLoc, maybeIp,
34723576
inferType);
34733577
},
34743578
nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3512,12 +3616,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35123616
std::optional<nb::list> resultTypeList, nb::list operandList,
35133617
std::optional<nb::dict> attributes,
35143618
std::optional<std::vector<PyBlock *>> successors,
3515-
std::optional<int> regions, DefaultingPyLocation ___location,
3619+
std::optional<int> regions, std::optional<PyLocation> ___location,
35163620
const nb::object &maybeIp) {
3621+
PyLocation pyLoc = maybeGetTracebackLocation(___location);
35173622
new (self) PyOpView(PyOpView::buildGeneric(
35183623
name, opRegionSpec, operandSegmentSpecObj,
35193624
resultSegmentSpecObj, resultTypeList, operandList,
3520-
attributes, successors, regions, ___location, maybeIp));
3625+
attributes, successors, regions, pyLoc, maybeIp));
35213626
},
35223627
nb::arg("name"), nb::arg("opRegionSpec"),
35233628
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3551,17 +3656,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35513656
[](nb::handle cls, std::optional<nb::list> resultTypeList,
35523657
nb::list operandList, std::optional<nb::dict> attributes,
35533658
std::optional<std::vector<PyBlock *>> successors,
3554-
std::optional<int> regions, DefaultingPyLocation ___location,
3659+
std::optional<int> regions, std::optional<PyLocation> ___location,
35553660
const nb::object &maybeIp) {
35563661
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
35573662
std::tuple<int, bool> opRegionSpec =
35583663
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
35593664
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
35603665
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3666+
PyLocation pyLoc = maybeGetTracebackLocation(___location);
35613667
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
35623668
resultSegmentSpec, resultTypeList,
35633669
operandList, attributes, successors,
3564-
regions, ___location, maybeIp);
3670+
regions, pyLoc, maybeIp);
35653671
},
35663672
nb::arg("cls"), nb::arg("results").none() = nb::none(),
35673673
nb::arg("operands").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
#include "Globals.h"
1515
#include "NanobindUtils.h"
16+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1617
#include "mlir-c/Support.h"
1718
#include "mlir/Bindings/Python/Nanobind.h"
18-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1919

2020
namespace nb = nanobind;
2121
using namespace mlir;
@@ -197,3 +197,69 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
197197
// Not found and loading did not yield a registration.
198198
return std::nullopt;
199199
}
200+
201+
bool PyGlobals::TracebackLoc::locTracebacksEnabled() const {
202+
return locTracebackEnabled_;
203+
}
204+
205+
void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
206+
nanobind::ft_lock_guard lock(mutex);
207+
locTracebackEnabled_ = value;
208+
}
209+
210+
size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() const {
211+
return locTracebackFramesLimit_;
212+
}
213+
214+
void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
215+
nanobind::ft_lock_guard lock(mutex);
216+
locTracebackFramesLimit_ = value;
217+
}
218+
219+
void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
220+
const std::string &file) {
221+
nanobind::ft_lock_guard lock(mutex);
222+
auto reg = "^" + llvm::Regex::escape(file);
223+
if (userTracebackIncludeFiles.insert(reg).second)
224+
rebuildUserTracebackIncludeRegex = true;
225+
if (userTracebackExcludeFiles.count(reg)) {
226+
if (userTracebackExcludeFiles.erase(reg))
227+
rebuildUserTracebackExcludeRegex = true;
228+
}
229+
}
230+
231+
void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
232+
const std::string &file) {
233+
nanobind::ft_lock_guard lock(mutex);
234+
auto reg = "^" + llvm::Regex::escape(file);
235+
if (userTracebackExcludeFiles.insert(reg).second)
236+
rebuildUserTracebackExcludeRegex = true;
237+
if (userTracebackIncludeFiles.count(reg)) {
238+
if (userTracebackIncludeFiles.erase(reg))
239+
rebuildUserTracebackIncludeRegex = true;
240+
}
241+
}
242+
243+
bool PyGlobals::TracebackLoc::isUserTracebackFilename(
244+
const llvm::StringRef file) {
245+
nanobind::ft_lock_guard lock(mutex);
246+
if (rebuildUserTracebackIncludeRegex) {
247+
userTracebackIncludeRegex.assign(
248+
llvm::join(userTracebackIncludeFiles, "|"));
249+
rebuildUserTracebackIncludeRegex = false;
250+
isUserTracebackFilenameCache.clear();
251+
}
252+
if (rebuildUserTracebackExcludeRegex) {
253+
userTracebackExcludeRegex.assign(
254+
llvm::join(userTracebackExcludeFiles, "|"));
255+
rebuildUserTracebackExcludeRegex = false;
256+
isUserTracebackFilenameCache.clear();
257+
}
258+
if (!isUserTracebackFilenameCache.contains(file)) {
259+
std::string fileStr = file.str();
260+
bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
261+
bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
262+
isUserTracebackFilenameCache[file] = include || !exclude;
263+
}
264+
return isUserTracebackFilenameCache[file];
265+
}

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
722722
llvm::ArrayRef<MlirValue> operands,
723723
std::optional<nanobind::dict> attributes,
724724
std::optional<std::vector<PyBlock *>> successors, int regions,
725-
DefaultingPyLocation ___location, const nanobind::object &ip,
726-
bool inferType);
725+
PyLocation ___location, const nanobind::object &ip, bool inferType);
727726

728727
/// Creates an OpView suitable for this operation.
729728
nanobind::object createOpView();
@@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
781780
nanobind::list operandList,
782781
std::optional<nanobind::dict> attributes,
783782
std::optional<std::vector<PyBlock *>> successors,
784-
std::optional<int> regions, DefaultingPyLocation ___location,
783+
std::optional<int> regions, PyLocation ___location,
785784
const nanobind::object &maybeIp);
786785

787786
/// Construct an instance of a class deriving from OpView, bypassing its

0 commit comments

Comments
 (0)