Skip to content

Commit 34d5f14

Browse files
committed
[mlir][python] auto-locs
1 parent 3212704 commit 34d5f14

File tree

5 files changed

+233
-18
lines changed

5 files changed

+233
-18
lines changed

mlir/lib/Bindings/Python/Globals.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,20 @@
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

1212
#include <optional>
13+
#include <regex>
1314
#include <string>
1415
#include <vector>
1516

1617
#include "NanobindUtils.h"
1718
#include "mlir-c/IR.h"
1819
#include "mlir/CAPI/Support.h"
1920
#include "llvm/ADT/DenseMap.h"
21+
#include "llvm/ADT/DenseSet.h"
22+
#include "llvm/ADT/SmallVectorExtras.h"
23+
#include "llvm/ADT/StringExtras.h"
2024
#include "llvm/ADT/StringRef.h"
2125
#include "llvm/ADT/StringSet.h"
26+
#include "llvm/Support/Regex.h"
2227

2328
namespace mlir {
2429
namespace python {
@@ -114,11 +119,72 @@ class PyGlobals {
114119
std::optional<nanobind::object>
115120
lookupOperationClass(llvm::StringRef operationName);
116121

122+
bool locTracebacksEnabled() {
123+
nanobind::ft_lock_guard lock(mutex);
124+
return locTracebackEnabled_;
125+
}
126+
127+
void setLocTracebacksEnabled(bool value) {
128+
nanobind::ft_lock_guard lock(mutex);
129+
locTracebackEnabled_ = value;
130+
}
131+
132+
bool locTracebackFramesLimit() {
133+
nanobind::ft_lock_guard lock(mutex);
134+
return locTracebackFramesLimit_;
135+
}
136+
137+
void setLocTracebackFramesLimit(size_t value) {
138+
nanobind::ft_lock_guard lock(mutex);
139+
locTracebackFramesLimit_ = value;
140+
}
141+
142+
void registerTracebackFileInclusion(llvm::StringRef file) {
143+
nanobind::ft_lock_guard lock(mutex);
144+
userTracebackIncludeFiles.insert(file);
145+
userTracebackIncludeRegex.assign(
146+
llvm::join(llvm::map_range(userTracebackIncludeFiles,
147+
[](llvm::StringRef path) {
148+
return "^" + llvm::Regex::escape(path);
149+
}),
150+
"|"));
151+
}
152+
153+
void registerTracebackFileExclusion(llvm::StringRef file) {
154+
nanobind::ft_lock_guard lock(mutex);
155+
userTracebackExcludeFiles.insert(file);
156+
userTracebackExcludeRegex.assign(
157+
llvm::join(llvm::map_range(userTracebackExcludeFiles,
158+
[](llvm::StringRef path) {
159+
return "^" + llvm::Regex::escape(path);
160+
}),
161+
"|"));
162+
}
163+
164+
bool isUserTracebackFilename(llvm::StringRef file) {
165+
nanobind::ft_lock_guard lock(mutex);
166+
if (!isUserTracebackFilenameCache.contains(file)) {
167+
std::string fileStr = file.str();
168+
isUserTracebackFilenameCache[file] =
169+
!std::regex_search(fileStr, userTracebackIncludeRegex) ||
170+
std::regex_search(fileStr, userTracebackExcludeRegex);
171+
}
172+
return isUserTracebackFilenameCache[file];
173+
}
174+
117175
private:
118176
static PyGlobals *instance;
119177

120178
nanobind::ft_mutex mutex;
121179

180+
bool locTracebackEnabled_ = false;
181+
size_t locTracebackFramesLimit_ = 10;
182+
llvm::DenseSet<llvm::StringRef> userTracebackIncludeFiles;
183+
llvm::DenseSet<llvm::StringRef> userTracebackExcludeFiles;
184+
std::regex userTracebackIncludeRegex;
185+
std::regex userTracebackExcludeRegex;
186+
llvm::StringMap<bool> isUserTracebackFilenameCache;
187+
122188
/// Module name prefixes to search under for dialect implementation modules.
123189
std::vector<std::string> dialectSearchPrefixes;
124190
/// Map of dialect namespace to external dialect class object.

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,7 +1523,7 @@ nb::object PyOperation::create(std::string_view name,
15231523
llvm::ArrayRef<MlirValue> operands,
15241524
std::optional<nb::dict> attributes,
15251525
std::optional<std::vector<PyBlock *>> successors,
1526-
int regions, DefaultingPyLocation ___location,
1526+
int regions, PyLocation ___location,
15271527
const nb::object &maybeIp, bool inferType) {
15281528
llvm::SmallVector<MlirType, 4> mlirResults;
15291529
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1627,7 @@ nb::object PyOperation::create(std::string_view name,
16271627
if (!operation.ptr)
16281628
throw nb::value_error("Operation creation failed");
16291629
PyOperationRef created =
1630-
PyOperation::createDetached(___location->getContext(), operation);
1630+
PyOperation::createDetached(___location.getContext(), operation);
16311631
maybeInsertOperation(created, maybeIp);
16321632

16331633
return created.getObject();
@@ -1937,9 +1937,9 @@ nb::object PyOpView::buildGeneric(
19371937
std::optional<nb::list> resultTypeList, nb::list operandList,
19381938
std::optional<nb::dict> attributes,
19391939
std::optional<std::vector<PyBlock *>> successors,
1940-
std::optional<int> regions, DefaultingPyLocation ___location,
1940+
std::optional<int> regions, PyLocation ___location,
19411941
const nb::object &maybeIp) {
1942-
PyMlirContextRef context = ___location->getContext();
1942+
PyMlirContextRef context = ___location.getContext();
19431943

19441944
// Class level operation construction metadata.
19451945
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2789,94 @@ class PyOpAttributeMap {
27892789
PyOperationRef operation;
27902790
};
27912791

2792+
// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1
2793+
#if PY_VERSION_HEX < 0x030b00b1 && !defined(PYPY_VERSION)
2794+
int PyFrame_GetLasti(PyFrameObject *frame) {
2795+
#if PY_VERSION_HEX >= 0x030a00a7
2796+
// bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset,
2797+
// not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes)
2798+
// instructions.
2799+
if (frame->f_lasti < 0) {
2800+
return -1;
2801+
}
2802+
return frame->f_lasti * 2;
2803+
#else
2804+
return frame->f_lasti;
2805+
#endif
2806+
}
2807+
#endif
2808+
2809+
std::optional<MlirLocation> tracebackToLocation(MlirContext ctx) {
2810+
size_t framesLimit = PyGlobals::get().locTracebackFramesLimit();
2811+
// We use a thread_local here mostly to avoid requiring a large amount of
2812+
// space.
2813+
thread_local std::vector<MlirLocation> frames;
2814+
frames.reserve(framesLimit);
2815+
size_t count = 0;
2816+
2817+
assert(PyGILState_Check());
2818+
2819+
if (!PyGlobals::get().locTracebacksEnabled())
2820+
return std::nullopt;
2821+
2822+
PyThreadState *tstate = PyThreadState_GET();
2823+
2824+
PyFrameObject *next;
2825+
for (PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2826+
pyFrame != nullptr && count < framesLimit; pyFrame = next) {
2827+
PyCodeObject *code = PyFrame_GetCode(pyFrame);
2828+
auto fileNameStr =
2829+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2830+
llvm::StringRef fileName(fileNameStr);
2831+
if (!PyGlobals::get().isUserTracebackFilename(fileName))
2832+
continue;
2833+
2834+
#if PY_VERSION_HEX < 0x030b00f0
2835+
std::string name =
2836+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2837+
llvm::StringRef funcName(name);
2838+
int startLine = PyFrame_GetLineNumber(pyFrame);
2839+
MlirLocation loc =
2840+
mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2841+
#else
2842+
// co_qualname added in py3.11
2843+
std::string name =
2844+
nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2845+
llvm::StringRef funcName(name);
2846+
int startLine, startCol, endLine, endCol;
2847+
int lasti = PyFrame_GetLasti(pyFrame);
2848+
if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2849+
&endCol)) {
2850+
throw nb::python_error();
2851+
}
2852+
MlirLocation loc = mlirLocationFileLineColRangeGet(
2853+
ctx, wrap(fileName), startCol, startCol, endLine, endCol);
2854+
#endif
2855+
2856+
frames.push_back(mlirLocationNameGet(ctx, wrap(funcName), loc));
2857+
++count;
2858+
next = PyFrame_GetBack(pyFrame);
2859+
Py_XDECREF(pyFrame);
2860+
2861+
if (frames.size() > framesLimit)
2862+
break;
2863+
}
2864+
2865+
if (frames.empty())
2866+
return mlirLocationUnknownGet(ctx);
2867+
if (frames.size() == 1)
2868+
return frames.front();
2869+
2870+
MlirLocation callee = frames.front();
2871+
frames.erase(frames.begin());
2872+
MlirLocation caller = frames.back();
2873+
for (const MlirLocation &frame :
2874+
llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2875+
caller = mlirLocationCallSiteGet(frame, caller);
2876+
2877+
return mlirLocationCallSiteGet(callee, caller);
2878+
}
2879+
27922880
} // namespace
27932881

27942882
//------------------------------------------------------------------------------
@@ -3241,7 +3329,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32413329
.def_static(
32423330
"create",
32433331
[](DefaultingPyLocation loc) {
3244-
MlirModule module = mlirModuleCreateEmpty(loc);
3332+
PyMlirContextRef ctx = loc->getContext();
3333+
MlirLocation mlirLoc = loc;
3334+
if (auto tloc = tracebackToLocation(ctx->get()))
3335+
mlirLoc = *tloc;
3336+
MlirModule module = mlirModuleCreateEmpty(mlirLoc);
32453337
return PyModule::forModule(module).releaseObject();
32463338
},
32473339
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3467,9 +3559,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34673559
}
34683560
}
34693561

3562+
PyMlirContextRef ctx = ___location->getContext();
3563+
if (auto loc = tracebackToLocation(ctx->get())) {
3564+
return PyOperation::create(name, results, mlirOperands,
3565+
attributes, successors, regions,
3566+
{ctx, *loc}, maybeIp, inferType);
3567+
}
34703568
return PyOperation::create(name, results, mlirOperands, attributes,
3471-
successors, regions, ___location, maybeIp,
3472-
inferType);
3569+
successors, regions, *___location.get(),
3570+
maybeIp, inferType);
34733571
},
34743572
nb::arg("name"), nb::arg("results").none() = nb::none(),
34753573
nb::arg("operands").none() = nb::none(),
@@ -3514,10 +3612,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35143612
std::optional<std::vector<PyBlock *>> successors,
35153613
std::optional<int> regions, DefaultingPyLocation ___location,
35163614
const nb::object &maybeIp) {
3517-
new (self) PyOpView(PyOpView::buildGeneric(
3518-
name, opRegionSpec, operandSegmentSpecObj,
3519-
resultSegmentSpecObj, resultTypeList, operandList,
3520-
attributes, successors, regions, ___location, maybeIp));
3615+
PyMlirContextRef ctx = ___location->getContext();
3616+
if (auto loc = tracebackToLocation(ctx->get())) {
3617+
new (self) PyOpView(PyOpView::buildGeneric(
3618+
name, opRegionSpec, operandSegmentSpecObj,
3619+
resultSegmentSpecObj, resultTypeList, operandList,
3620+
attributes, successors, regions, {ctx, *loc}, maybeIp));
3621+
} else {
3622+
new (self) PyOpView(PyOpView::buildGeneric(
3623+
name, opRegionSpec, operandSegmentSpecObj,
3624+
resultSegmentSpecObj, resultTypeList, operandList,
3625+
attributes, successors, regions, *___location.get(),
3626+
maybeIp));
3627+
}
35213628
},
35223629
nb::arg("name"), nb::arg("opRegionSpec"),
35233630
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3558,10 +3665,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35583665
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
35593666
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
35603667
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3668+
3669+
PyMlirContextRef ctx = ___location->getContext();
3670+
if (auto loc = tracebackToLocation(ctx->get())) {
3671+
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3672+
resultSegmentSpec, resultTypeList,
3673+
operandList, attributes, successors,
3674+
regions, {ctx, *loc}, maybeIp);
3675+
}
35613676
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
35623677
resultSegmentSpec, resultTypeList,
35633678
operandList, attributes, successors,
3564-
regions, ___location, maybeIp);
3679+
regions, *___location.get(), maybeIp);
35653680
},
35663681
nb::arg("cls"), nb::arg("results").none() = nb::none(),
35673682
nb::arg("operands").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
722722
llvm::ArrayRef<MlirValue> operands,
723723
std::optional<nanobind::dict> attributes,
724724
std::optional<std::vector<PyBlock *>> successors, int regions,
725-
DefaultingPyLocation ___location, const nanobind::object &ip,
726-
bool inferType);
725+
PyLocation ___location, const nanobind::object &ip, bool inferType);
727726

728727
/// Creates an OpView suitable for this operation.
729728
nanobind::object createOpView();
@@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
781780
nanobind::list operandList,
782781
std::optional<nanobind::dict> attributes,
783782
std::optional<std::vector<PyBlock *>> successors,
784-
std::optional<int> regions, DefaultingPyLocation ___location,
783+
std::optional<int> regions, PyLocation ___location,
785784
const nanobind::object &maybeIp);
786785

787786
/// Construct an instance of a class deriving from OpView, bypassing its

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
109
#include "Globals.h"
1110
#include "IRModule.h"
1211
#include "NanobindUtils.h"
@@ -44,7 +43,15 @@ NB_MODULE(_mlir, m) {
4443
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
4544
"operation_name"_a, "operation_class"_a, nb::kw_only(),
4645
"replace"_a = false,
47-
"Testing hook for directly registering an operation");
46+
"Testing hook for directly registering an operation")
47+
.def("loc_tracebacks_enabled", &PyGlobals::locTracebacksEnabled)
48+
.def("set_loc_tracebacks_enabled", &PyGlobals::setLocTracebacksEnabled)
49+
.def("set_loc_tracebacks_frame_limit",
50+
&PyGlobals::setLocTracebackFramesLimit)
51+
.def("register_traceback_file_inclusion",
52+
&PyGlobals::registerTracebackFileInclusion)
53+
.def("register_traceback_file_exclusion",
54+
&PyGlobals::registerTracebackFileExclusion);
4855

4956
// Aside from making the globals accessible to python, having python manage
5057
// it is necessary to make sure it is destroyed (and releases its python

mlir/test/python/ir/___location.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4+
from contextlib import contextmanager
5+
46
from mlir.ir import *
7+
from mlir.dialects._ods_common import _cext
8+
from mlir.dialects import arith
59

610

711
def run(f):
812
print("\nTEST:", f.__name__)
913
f()
1014
gc.collect()
11-
assert Context._get_live_count() == 0
15+
# assert Context._get_live_count() == 0
1216

1317

1418
# CHECK-LABEL: TEST: testUnknown
@@ -27,6 +31,30 @@ def testUnknown():
2731
run(testUnknown)
2832

2933

34+
@contextmanager
35+
def with_infer_location():
36+
_cext.globals.set_loc_tracebacks_enabled(True)
37+
yield
38+
_cext.globals.set_loc_tracebacks_enabled(False)
39+
40+
41+
# CHECK-LABEL: TEST: testInferLocations
42+
def testInferLocations():
43+
with Context() as ctx, Location.unknown(), with_infer_location():
44+
ctx.allow_unregistered_dialects = True
45+
op = Operation.create("custom.op1")
46+
print(op.___location)
47+
# CHECK: loc(
48+
# CHECK-SAME callsite("testInferLocations"("{{.*}}/test/python/ir/___location.py":13:13 to 44:43)
49+
# CHECK-SAME at callsite("run"("{{.*}}/test/python/ir/___location.py":4:4 to 12:7)
50+
# CHECK-SAME at "<module>"("{{.*}}/test/python/ir/___location.py":0:0 to 50:23))))
51+
one = arith.constant(IndexType.get(), 1)
52+
53+
54+
55+
run(testInferLocations)
56+
57+
3058
# CHECK-LABEL: TEST: testLocationAttr
3159
def testLocationAttr():
3260
with Context() as ctxt:

0 commit comments

Comments
 (0)