Skip to content

Commit f63eaf2

Browse files
committed
[mlir][python] auto-locs
1 parent 3212704 commit f63eaf2

File tree

6 files changed

+250
-20
lines changed

6 files changed

+250
-20
lines changed

mlir/lib/Bindings/Python/Globals.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

1212
#include <optional>
13+
#include <regex>
1314
#include <string>
15+
#include <unordered_set>
1416
#include <vector>
1517

1618
#include "NanobindUtils.h"
1719
#include "mlir-c/IR.h"
1820
#include "mlir/CAPI/Support.h"
1921
#include "llvm/ADT/DenseMap.h"
22+
#include "llvm/ADT/StringExtras.h"
2023
#include "llvm/ADT/StringRef.h"
2124
#include "llvm/ADT/StringSet.h"
25+
#include "llvm/Support/Regex.h"
2226

2327
namespace mlir {
2428
namespace python {
@@ -114,11 +118,66 @@ class PyGlobals {
114118
std::optional<nanobind::object>
115119
lookupOperationClass(llvm::StringRef operationName);
116120

121+
bool locTracebacksEnabled() {
122+
nanobind::ft_lock_guard lock(mutex);
123+
return locTracebackEnabled_;
124+
}
125+
126+
/// Turns traceback-derived locations on (`true`) or off (`false`). The flag
/// is written while holding the globals mutex.
void setLocTracebacksEnabled(bool value) {
  nanobind::ft_lock_guard guard(mutex);
  locTracebackEnabled_ = value;
}
130+
131+
size_t locTracebackFramesLimit() {
132+
nanobind::ft_lock_guard lock(mutex);
133+
return locTracebackFramesLimit_;
134+
}
135+
136+
/// Stores a new maximum number of traceback frames. The value is written
/// while holding the globals mutex.
///
/// NOTE: no upper bound is enforced here; the value is stored as given.
void setLocTracebackFramesLimit(size_t value) {
  nanobind::ft_lock_guard guard(mutex);
  locTracebackFramesLimit_ = value;
}
140+
141+
/// Registers `file` as a path prefix that is always treated as user code.
///
/// The path is regex-escaped and anchored at the start of the string, then
/// the combined inclusion regex is rebuilt as the alternation of every
/// registered prefix. The per-file decision cache is cleared because earlier
/// answers may no longer hold.
///
/// \param file Path (or path prefix) to include. Taken by const reference so
///             callers can pass temporaries and const strings; it is only
///             read.
void registerTracebackFileInclusion(const std::string &file) {
  nanobind::ft_lock_guard lock(mutex);
  userTracebackIncludeFiles.insert("^" + llvm::Regex::escape(file));
  userTracebackIncludeRegex.assign(
      llvm::join(userTracebackIncludeFiles, "|"));
  // Cached verdicts were computed against the previous regex.
  isUserTracebackFilenameCache.clear();
}
148+
149+
/// Registers `file` as a path prefix to be filtered out of traceback-derived
/// locations (unless it also matches an inclusion prefix).
///
/// The path is regex-escaped and anchored at the start of the string, then
/// the combined exclusion regex is rebuilt as the alternation of every
/// registered prefix. The per-file decision cache is cleared because earlier
/// answers may no longer hold.
///
/// \param file Path (or path prefix) to exclude. Taken by const reference so
///             callers can pass temporaries and const strings; it is only
///             read.
void registerTracebackFileExclusion(const std::string &file) {
  nanobind::ft_lock_guard lock(mutex);
  userTracebackExcludeFiles.insert("^" + llvm::Regex::escape(file));
  userTracebackExcludeRegex.assign(
      llvm::join(userTracebackExcludeFiles, "|"));
  // Cached verdicts were computed against the previous regex.
  isUserTracebackFilenameCache.clear();
}
156+
157+
/// Decides whether `file` should appear in traceback-derived locations.
///
/// A file qualifies when it matches the inclusion regex, or when it does not
/// match the exclusion regex. The verdict is memoized per file name so the
/// regex searches run only on the first query for a given file.
bool isUserTracebackFilename(llvm::StringRef file) {
  nanobind::ft_lock_guard guard(mutex);

  auto cached = isUserTracebackFilenameCache.find(file);
  if (cached != isUserTracebackFilenameCache.end())
    return cached->getValue();

  const std::string path = file.str();
  const bool included = std::regex_search(path, userTracebackIncludeRegex);
  const bool excluded = std::regex_search(path, userTracebackExcludeRegex);
  const bool isUserFile = included || !excluded;
  isUserTracebackFilenameCache[file] = isUserFile;
  return isUserFile;
}
167+
117168
private:
118169
static PyGlobals *instance;
119170

120171
nanobind::ft_mutex mutex;
121172

173+
bool locTracebackEnabled_ = false;
174+
size_t locTracebackFramesLimit_ = 10;
175+
std::unordered_set<std::string> userTracebackIncludeFiles;
176+
std::unordered_set<std::string> userTracebackExcludeFiles;
177+
std::regex userTracebackIncludeRegex;
178+
std::regex userTracebackExcludeRegex;
179+
llvm::StringMap<bool> isUserTracebackFilenameCache;
180+
122181
/// Module name prefixes to search under for dialect implementation modules.
123182
std::vector<std::string> dialectSearchPrefixes;
124183
/// Map of dialect namespace to external dialect class object.

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 124 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
#include "nanobind/nanobind.h"
2121
#include "llvm/ADT/ArrayRef.h"
2222
#include "llvm/ADT/SmallVector.h"
23-
#include "llvm/Support/raw_ostream.h"
2423

2524
#include <optional>
26-
#include <system_error>
27-
#include <utility>
2825

2926
namespace nb = nanobind;
3027
using namespace nb::literals;
@@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
15231520
llvm::ArrayRef<MlirValue> operands,
15241521
std::optional<nb::dict> attributes,
15251522
std::optional<std::vector<PyBlock *>> successors,
1526-
int regions, DefaultingPyLocation ___location,
1523+
int regions, PyLocation ___location,
15271524
const nb::object &maybeIp, bool inferType) {
15281525
llvm::SmallVector<MlirType, 4> mlirResults;
15291526
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
16271624
if (!operation.ptr)
16281625
throw nb::value_error("Operation creation failed");
16291626
PyOperationRef created =
1630-
PyOperation::createDetached(___location->getContext(), operation);
1627+
PyOperation::createDetached(___location.getContext(), operation);
16311628
maybeInsertOperation(created, maybeIp);
16321629

16331630
return created.getObject();
@@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
19371934
std::optional<nb::list> resultTypeList, nb::list operandList,
19381935
std::optional<nb::dict> attributes,
19391936
std::optional<std::vector<PyBlock *>> successors,
1940-
std::optional<int> regions, DefaultingPyLocation ___location,
1937+
std::optional<int> regions, PyLocation ___location,
19411938
const nb::object &maybeIp) {
1942-
PyMlirContextRef context = ___location->getContext();
1939+
PyMlirContextRef context = ___location.getContext();
19431940

19441941
// Class level operation construction metadata.
19451942
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2786,91 @@ class PyOpAttributeMap {
27892786
PyOperationRef operation;
27902787
};
27912788

2789+
// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1
2790+
#if PY_VERSION_HEX < 0x030b00b1 && !defined(PYPY_VERSION)
2791+
// Fallback definition of PyFrame_GetLasti() for interpreters that do not ship
// it: returns the frame's last-instruction offset in bytes, or a negative
// value when no instruction has run yet.
int PyFrame_GetLasti(PyFrameObject *frame) {
  const int lastInstr = frame->f_lasti;
#if PY_VERSION_HEX >= 0x030a00a7
  // bpo-27129: from Python 3.10.0a7 on, f_lasti counts instructions rather
  // than bytes. Instructions are 16-bit "wordcode" units, so scale by two to
  // get a byte offset.
  return lastInstr < 0 ? -1 : lastInstr * 2;
#else
  return lastInstr;
#endif
}
2804+
#endif
2805+
2806+
constexpr size_t kMaxFrames = 512;
2807+
2808+
/// Builds an MLIR location that mirrors the current Python call stack.
///
/// Returns std::nullopt when traceback-derived locations are disabled in
/// PyGlobals. Otherwise the Python frames of the calling thread are walked
/// from the innermost frame outwards. Frames whose file is rejected by
/// PyGlobals::isUserTracebackFilename are skipped; every other frame becomes a
/// name location (function name) wrapping its file position. At most
/// min(PyGlobals::locTracebackFramesLimit(), kMaxFrames) frames are kept.
///
/// The kept frames are chained into nested call-site locations with the
/// innermost frame as the outermost callee. If no frame is kept, an unknown
/// location is returned.
///
/// Must be called with the GIL held. Throws nb::python_error if
/// PyCode_Addr2Location reports failure.
std::optional<MlirLocation> tracebackToLocation(MlirContext ctx) {
  assert(PyGILState_Check());

  if (!PyGlobals::get().locTracebacksEnabled())
    return std::nullopt;

  // thread_local keeps this kMaxFrames-element buffer out of every call's
  // stack frame.
  thread_local std::array<MlirLocation, kMaxFrames> frames;

  // The limit is user-configurable; clamp it so `frames` is never overrun.
  size_t framesLimit = PyGlobals::get().locTracebackFramesLimit();
  if (framesLimit > frames.size())
    framesLimit = frames.size();

  // PyThreadState_GetFrame, PyFrame_GetBack and PyFrame_GetCode all return
  // strong references. Holding them in nb::object releases each one on every
  // exit path (normal iteration, `continue`, hitting the limit, or a throw).
  auto own = [](auto *pyObj) {
    return nb::steal<nb::object>(
        nb::handle(reinterpret_cast<PyObject *>(pyObj)));
  };

  size_t count = 0;
  for (nb::object frameRef = own(PyThreadState_GetFrame(PyThreadState_GET()));
       frameRef.is_valid() && count < framesLimit;
       frameRef = own(PyFrame_GetBack(
           reinterpret_cast<PyFrameObject *>(frameRef.ptr())))) {
    auto *pyFrame = reinterpret_cast<PyFrameObject *>(frameRef.ptr());
    nb::object codeRef = own(PyFrame_GetCode(pyFrame));
    auto *code = reinterpret_cast<PyCodeObject *>(codeRef.ptr());

    auto fileNameStr =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
    llvm::StringRef fileName(fileNameStr);
    if (!PyGlobals::get().isUserTracebackFilename(fileName))
      continue;

#if PY_VERSION_HEX < 0x030b00f0
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
    llvm::StringRef funcName(name);
    int startLine = PyFrame_GetLineNumber(pyFrame);
    MlirLocation loc =
        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
#else
    // co_qualname added in py3.11
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
    llvm::StringRef funcName(name);
    int startLine, startCol, endLine, endCol;
    int lasti = PyFrame_GetLasti(pyFrame);
    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
                              &endCol)) {
      throw nb::python_error();
    }
    MlirLocation loc = mlirLocationFileLineColRangeGet(
        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
#endif

    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
    ++count;
  }

  if (count == 0)
    return mlirLocationUnknownGet(ctx);

  // Fold from the outermost kept frame inwards: each step makes the next
  // inner frame the callee of everything accumulated so far. With a single
  // kept frame the loop body never runs and that frame is returned as is.
  MlirLocation result = frames[count - 1];
  for (size_t i = count - 1; i-- > 0;)
    result = mlirLocationCallSiteGet(frames[i], result);

  return result;
}
2873+
27922874
} // namespace
27932875

27942876
//------------------------------------------------------------------------------
@@ -3241,7 +3323,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32413323
.def_static(
32423324
"create",
32433325
[](DefaultingPyLocation loc) {
3244-
MlirModule module = mlirModuleCreateEmpty(loc);
3326+
PyMlirContextRef ctx = loc->getContext();
3327+
MlirLocation mlirLoc = loc;
3328+
if (auto tloc = tracebackToLocation(ctx->get()))
3329+
mlirLoc = *tloc;
3330+
MlirModule module = mlirModuleCreateEmpty(mlirLoc);
32453331
return PyModule::forModule(module).releaseObject();
32463332
},
32473333
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3467,9 +3553,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34673553
}
34683554
}
34693555

3556+
PyMlirContextRef ctx = ___location->getContext();
3557+
if (auto loc = tracebackToLocation(ctx->get())) {
3558+
return PyOperation::create(name, results, mlirOperands,
3559+
attributes, successors, regions,
3560+
{ctx, *loc}, maybeIp, inferType);
3561+
}
34703562
return PyOperation::create(name, results, mlirOperands, attributes,
3471-
successors, regions, ___location, maybeIp,
3472-
inferType);
3563+
successors, regions, *___location.get(),
3564+
maybeIp, inferType);
34733565
},
34743566
nb::arg("name"), nb::arg("results").none() = nb::none(),
34753567
nb::arg("operands").none() = nb::none(),
@@ -3514,10 +3606,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35143606
std::optional<std::vector<PyBlock *>> successors,
35153607
std::optional<int> regions, DefaultingPyLocation ___location,
35163608
const nb::object &maybeIp) {
3517-
new (self) PyOpView(PyOpView::buildGeneric(
3518-
name, opRegionSpec, operandSegmentSpecObj,
3519-
resultSegmentSpecObj, resultTypeList, operandList,
3520-
attributes, successors, regions, ___location, maybeIp));
3609+
PyMlirContextRef ctx = ___location->getContext();
3610+
if (auto loc = tracebackToLocation(ctx->get())) {
3611+
new (self) PyOpView(PyOpView::buildGeneric(
3612+
name, opRegionSpec, operandSegmentSpecObj,
3613+
resultSegmentSpecObj, resultTypeList, operandList,
3614+
attributes, successors, regions, {ctx, *loc}, maybeIp));
3615+
} else {
3616+
new (self) PyOpView(PyOpView::buildGeneric(
3617+
name, opRegionSpec, operandSegmentSpecObj,
3618+
resultSegmentSpecObj, resultTypeList, operandList,
3619+
attributes, successors, regions, *___location.get(),
3620+
maybeIp));
3621+
}
35213622
},
35223623
nb::arg("name"), nb::arg("opRegionSpec"),
35233624
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3558,10 +3659,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35583659
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
35593660
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
35603661
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3662+
3663+
PyMlirContextRef ctx = ___location->getContext();
3664+
if (auto loc = tracebackToLocation(ctx->get())) {
3665+
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3666+
resultSegmentSpec, resultTypeList,
3667+
operandList, attributes, successors,
3668+
regions, {ctx, *loc}, maybeIp);
3669+
}
35613670
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
35623671
resultSegmentSpec, resultTypeList,
35633672
operandList, attributes, successors,
3564-
regions, ___location, maybeIp);
3673+
regions, *___location.get(), maybeIp);
35653674
},
35663675
nb::arg("cls"), nb::arg("results").none() = nb::none(),
35673676
nb::arg("operands").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
722722
llvm::ArrayRef<MlirValue> operands,
723723
std::optional<nanobind::dict> attributes,
724724
std::optional<std::vector<PyBlock *>> successors, int regions,
725-
DefaultingPyLocation ___location, const nanobind::object &ip,
726-
bool inferType);
725+
PyLocation ___location, const nanobind::object &ip, bool inferType);
727726

728727
/// Creates an OpView suitable for this operation.
729728
nanobind::object createOpView();
@@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
781780
nanobind::list operandList,
782781
std::optional<nanobind::dict> attributes,
783782
std::optional<std::vector<PyBlock *>> successors,
784-
std::optional<int> regions, DefaultingPyLocation ___location,
783+
std::optional<int> regions, PyLocation ___location,
785784
const nanobind::object &maybeIp);
786785

787786
/// Construct an instance of a class deriving from OpView, bypassing its

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
109
#include "Globals.h"
1110
#include "IRModule.h"
1211
#include "NanobindUtils.h"
@@ -44,7 +43,15 @@ NB_MODULE(_mlir, m) {
4443
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
4544
"operation_name"_a, "operation_class"_a, nb::kw_only(),
4645
"replace"_a = false,
47-
"Testing hook for directly registering an operation");
46+
"Testing hook for directly registering an operation")
47+
.def("loc_tracebacks_enabled", &PyGlobals::locTracebacksEnabled)
48+
.def("set_loc_tracebacks_enabled", &PyGlobals::setLocTracebacksEnabled)
49+
.def("set_loc_tracebacks_frame_limit",
50+
&PyGlobals::setLocTracebackFramesLimit)
51+
.def("register_traceback_file_inclusion",
52+
&PyGlobals::registerTracebackFileInclusion)
53+
.def("register_traceback_file_exclusion",
54+
&PyGlobals::registerTracebackFileExclusion);
4855

4956
// Aside from making the globals accessible to python, having python manage
5057
// it is necessary to make sure it is destroyed (and releases its python

mlir/test/python/ir/auto_location.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
import gc
4+
from contextlib import contextmanager
5+
6+
from mlir.ir import *
7+
from mlir.dialects._ods_common import _cext
8+
from mlir.dialects import arith, _arith_ops_gen
9+
10+
11+
def run(f):
12+
print("\nTEST:", f.__name__)
13+
f()
14+
gc.collect()
15+
assert Context._get_live_count() == 0
16+
17+
18+
@contextmanager
19+
def with_infer_location():
20+
_cext.globals.set_loc_tracebacks_enabled(True)
21+
yield
22+
_cext.globals.set_loc_tracebacks_enabled(False)
23+
24+
25+
# CHECK-LABEL: TEST: testInferLocations
26+
def testInferLocations():
27+
with Context() as ctx, Location.unknown(), with_infer_location():
28+
ctx.allow_unregistered_dialects = True
29+
op = Operation.create("custom.op1")
30+
one = arith.constant(IndexType.get(), 1)
31+
_cext.globals.register_traceback_file_exclusion(arith.__file__)
32+
_cext.globals.register_traceback_file_exclusion(_arith_ops_gen.__file__)
33+
two = arith.constant(IndexType.get(), 2)
34+
35+
# CHECK: loc(callsite("testInferLocations"("{{.*}}/test/python/ir/auto_location.py":29:13 to :43)
36+
# CHECK-SAME: at callsite("run"("{{.*}}/test/python/ir/auto_location.py":13:4 to :7)
37+
# CHECK-SAME: at "<module>"("{{.*}}/test/python/ir/auto_location.py":54:0 to :23))))
38+
print(op.___location)
39+
40+
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}/mlir/dialects/_arith_ops_gen.py":404:4 to :218)
41+
# CHECK-SAME: at callsite("ConstantOp.__init__"("{{.*}}/mlir/dialects/arith.py":65:12 to :76)
42+
# CHECK-SAME: at callsite("constant"("{{.*}}/mlir/dialects/arith.py":110:40 to :81)
43+
# CHECK-SAME: at callsite("testInferLocations"("{{.*}}/test/python/ir/auto_location.py":30:14 to :48)
44+
# CHECK-SAME: at callsite("run"("{{.*}}/test/python/ir/auto_location.py":13:4 to :7)
45+
# CHECK-SAME: at "<module>"("{{.*}}/test/python/ir/auto_location.py":54:0 to :23)))))))
46+
print(one.___location)
47+
48+
# CHECK: loc(callsite("testInferLocations"("{{.*}}/test/python/ir/auto_location.py":33:14 to :48)
49+
# CHECK-SAME: at callsite("run"("{{.*}}/test/python/ir/auto_location.py":13:4 to :7)
50+
# CHECK-SAME: at "<module>"("{{.*}}/test/python/ir/auto_location.py":54:0 to :23))))
51+
print(two.___location)
52+
53+
54+
run(testInferLocations)

mlir/test/python/ir/lit.local.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
if "Windows" in config.host_os:
2+
config.excludes.add("auto_location.py")

0 commit comments

Comments
 (0)