Skip to content

Commit 61b0669

Browse files
committed
[mlir][python] auto-locs
1 parent 3212704 commit 61b0669

File tree

5 files changed

+139
-15
lines changed

5 files changed

+139
-15
lines changed

mlir/lib/Bindings/Python/Globals.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,23 @@ class PyGlobals {
114114
std::optional<nanobind::object>
115115
lookupOperationClass(llvm::StringRef operationName);
116116

117+
bool tracebacksEnabled() {
118+
nanobind::ft_lock_guard lock(mutex);
119+
return tracebackEnabled_;
120+
}
121+
122+
void setTracebacksEnabled(bool value) {
123+
nanobind::ft_lock_guard lock(mutex);
124+
tracebackEnabled_ = value;
125+
}
126+
117127
private:
118128
static PyGlobals *instance;
119129

120130
nanobind::ft_mutex mutex;
121131

132+
bool tracebackEnabled_ = false;
133+
122134
/// Module name prefixes to search under for dialect implementation modules.
123135
std::vector<std::string> dialectSearchPrefixes;
124136
/// Map of dialect namespace to external dialect class object.

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 102 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,7 +1523,7 @@ nb::object PyOperation::create(std::string_view name,
15231523
llvm::ArrayRef<MlirValue> operands,
15241524
std::optional<nb::dict> attributes,
15251525
std::optional<std::vector<PyBlock *>> successors,
1526-
int regions, DefaultingPyLocation ___location,
1526+
int regions, PyLocation ___location,
15271527
const nb::object &maybeIp, bool inferType) {
15281528
llvm::SmallVector<MlirType, 4> mlirResults;
15291529
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1627,7 @@ nb::object PyOperation::create(std::string_view name,
16271627
if (!operation.ptr)
16281628
throw nb::value_error("Operation creation failed");
16291629
PyOperationRef created =
1630-
PyOperation::createDetached(___location->getContext(), operation);
1630+
PyOperation::createDetached(___location.getContext(), operation);
16311631
maybeInsertOperation(created, maybeIp);
16321632

16331633
return created.getObject();
@@ -1937,9 +1937,9 @@ nb::object PyOpView::buildGeneric(
19371937
std::optional<nb::list> resultTypeList, nb::list operandList,
19381938
std::optional<nb::dict> attributes,
19391939
std::optional<std::vector<PyBlock *>> successors,
1940-
std::optional<int> regions, DefaultingPyLocation ___location,
1940+
std::optional<int> regions, PyLocation ___location,
19411941
const nb::object &maybeIp) {
1942-
PyMlirContextRef context = ___location->getContext();
1942+
PyMlirContextRef context = ___location.getContext();
19431943

19441944
// Class level operation construction metadata.
19451945
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2789,70 @@ class PyOpAttributeMap {
27892789
PyOperationRef operation;
27902790
};
27912791

2792+
std::optional<MlirLocation> tracebackToLocation(MlirContext ctx) {
2793+
// We use a thread_local here mostly to avoid requiring a large amount of
2794+
// space.
2795+
size_t frames_limit = 100;
2796+
thread_local std::vector<MlirLocation> frames;
2797+
frames.reserve(frames_limit);
2798+
int count = 0;
2799+
2800+
assert(PyGILState_Check());
2801+
2802+
if (!PyGlobals::get().tracebacksEnabled())
2803+
return std::nullopt;
2804+
2805+
PyThreadState *thread_state = PyThreadState_GET();
2806+
2807+
PyFrameObject *next;
2808+
for (PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state);
2809+
py_frame != nullptr && count < frames_limit; py_frame = next) {
2810+
PyCodeObject *code = PyFrame_GetCode(py_frame);
2811+
int lasti = PyFrame_GetLasti(py_frame);
2812+
MlirStringRef fileName = mlirStringRefCreateFromCString(
2813+
nb::borrow<nb::str>(code->co_filename).c_str());
2814+
2815+
#if PY_VERSION_HEX < 0x030b00f0
2816+
MlirStringRef funcName = mlirStringRefCreateFromCString(
2817+
nb::borrow<nb::str>(frame.code->co_name).c_str());
2818+
auto line = PyCode_Addr2Line(frame.code, frame.lasti);
2819+
auto loc = mlirLocationFileLineColGet(ctx, fileName, line, 0);
2820+
#else
2821+
MlirStringRef funcName = mlirStringRefCreateFromCString(
2822+
nb::borrow<nb::str>(code->co_qualname).c_str());
2823+
int start_line, start_column, end_line, end_column;
2824+
if (!PyCode_Addr2Location(code, lasti, &start_line, &start_column,
2825+
&end_line, &end_column)) {
2826+
throw nb::python_error();
2827+
}
2828+
auto loc = mlirLocationFileLineColRangeGet(
2829+
ctx, fileName, start_column, start_column, end_line, end_column);
2830+
#endif
2831+
2832+
frames.push_back(mlirLocationNameGet(ctx, funcName, loc));
2833+
++count;
2834+
next = PyFrame_GetBack(py_frame);
2835+
Py_XDECREF(py_frame);
2836+
2837+
if (frames.size() > frames_limit)
2838+
break;
2839+
}
2840+
2841+
if (frames.empty())
2842+
return mlirLocationUnknownGet(ctx);
2843+
if (frames.size() == 1)
2844+
return frames.front();
2845+
2846+
MlirLocation callee = frames.front();
2847+
frames.erase(frames.begin());
2848+
MlirLocation caller = frames.back();
2849+
for (const MlirLocation &frame :
2850+
llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2851+
caller = mlirLocationCallSiteGet(frame, caller);
2852+
2853+
return mlirLocationCallSiteGet(callee, caller);
2854+
}
2855+
27922856
} // namespace
27932857

27942858
//------------------------------------------------------------------------------
@@ -3241,6 +3305,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32413305
.def_static(
32423306
"create",
32433307
[](DefaultingPyLocation loc) {
3308+
PyMlirContextRef ctx = loc->getContext();
3309+
MlirLocation mlirLoc = loc;
3310+
if (auto tloc = tracebackToLocation(ctx->get()))
3311+
mlirLoc = *tloc;
32443312
MlirModule module = mlirModuleCreateEmpty(loc);
32453313
return PyModule::forModule(module).releaseObject();
32463314
},
@@ -3467,9 +3535,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34673535
}
34683536
}
34693537

3538+
PyMlirContextRef ctx = ___location->getContext();
3539+
if (auto loc = tracebackToLocation(ctx->get())) {
3540+
return PyOperation::create(name, results, mlirOperands,
3541+
attributes, successors, regions,
3542+
{ctx, *loc}, maybeIp, inferType);
3543+
}
34703544
return PyOperation::create(name, results, mlirOperands, attributes,
3471-
successors, regions, ___location, maybeIp,
3472-
inferType);
3545+
successors, regions, *___location.get(),
3546+
maybeIp, inferType);
34733547
},
34743548
nb::arg("name"), nb::arg("results").none() = nb::none(),
34753549
nb::arg("operands").none() = nb::none(),
@@ -3514,10 +3588,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35143588
std::optional<std::vector<PyBlock *>> successors,
35153589
std::optional<int> regions, DefaultingPyLocation ___location,
35163590
const nb::object &maybeIp) {
3517-
new (self) PyOpView(PyOpView::buildGeneric(
3518-
name, opRegionSpec, operandSegmentSpecObj,
3519-
resultSegmentSpecObj, resultTypeList, operandList,
3520-
attributes, successors, regions, ___location, maybeIp));
3591+
PyMlirContextRef ctx = ___location->getContext();
3592+
if (auto loc = tracebackToLocation(ctx->get())) {
3593+
new (self) PyOpView(PyOpView::buildGeneric(
3594+
name, opRegionSpec, operandSegmentSpecObj,
3595+
resultSegmentSpecObj, resultTypeList, operandList,
3596+
attributes, successors, regions, {ctx, *loc}, maybeIp));
3597+
} else {
3598+
new (self) PyOpView(PyOpView::buildGeneric(
3599+
name, opRegionSpec, operandSegmentSpecObj,
3600+
resultSegmentSpecObj, resultTypeList, operandList,
3601+
attributes, successors, regions, *___location.get(),
3602+
maybeIp));
3603+
}
35213604
},
35223605
nb::arg("name"), nb::arg("opRegionSpec"),
35233606
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3558,10 +3641,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35583641
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
35593642
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
35603643
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3644+
3645+
PyMlirContextRef ctx = ___location->getContext();
3646+
if (auto loc = tracebackToLocation(ctx->get())) {
3647+
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3648+
resultSegmentSpec, resultTypeList,
3649+
operandList, attributes, successors,
3650+
regions, {ctx, *loc}, maybeIp);
3651+
}
35613652
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
35623653
resultSegmentSpec, resultTypeList,
35633654
operandList, attributes, successors,
3564-
regions, ___location, maybeIp);
3655+
regions, *___location.get(), maybeIp);
35653656
},
35663657
nb::arg("cls"), nb::arg("results").none() = nb::none(),
35673658
nb::arg("operands").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
722722
llvm::ArrayRef<MlirValue> operands,
723723
std::optional<nanobind::dict> attributes,
724724
std::optional<std::vector<PyBlock *>> successors, int regions,
725-
DefaultingPyLocation ___location, const nanobind::object &ip,
726-
bool inferType);
725+
PyLocation ___location, const nanobind::object &ip, bool inferType);
727726

728727
/// Creates an OpView suitable for this operation.
729728
nanobind::object createOpView();
@@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
781780
nanobind::list operandList,
782781
std::optional<nanobind::dict> attributes,
783782
std::optional<std::vector<PyBlock *>> successors,
784-
std::optional<int> regions, DefaultingPyLocation ___location,
783+
std::optional<int> regions, PyLocation ___location,
785784
const nanobind::object &maybeIp);
786785

787786
/// Construct an instance of a class deriving from OpView, bypassing its

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ NB_MODULE(_mlir, m) {
4444
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
4545
"operation_name"_a, "operation_class"_a, nb::kw_only(),
4646
"replace"_a = false,
47-
"Testing hook for directly registering an operation");
47+
"Testing hook for directly registering an operation")
48+
.def("tracebacks_enabled", &PyGlobals::tracebacksEnabled)
49+
.def("set_tracebacks_enabled", &PyGlobals::setTracebacksEnabled);
4850

4951
// Aside from making the globals accessible to python, having python manage
5052
// it is necessary to make sure it is destroyed (and releases its python

mlir/test/python/ir/___location.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4+
from contextlib import contextmanager
5+
46
from mlir.ir import *
7+
from mlir.dialects._ods_common import _cext
58

69

710
def run(f):
@@ -27,6 +30,23 @@ def testUnknown():
2730
run(testUnknown)
2831

2932

33+
@contextmanager
34+
def with_infer_location():
35+
_cext.globals.set_tracebacks_enabled(True)
36+
yield
37+
_cext.globals.set_tracebacks_enabled(False)
38+
39+
40+
# CHECK-LABEL: TEST: testInferLocations
41+
def testInferLocations():
42+
with Context() as ctx, Location.unknown(), with_infer_location():
43+
ctx.allow_unregistered_dialects = True
44+
op = Operation.create("custom.op1")
45+
46+
47+
run(testInferLocations)
48+
49+
3050
# CHECK-LABEL: TEST: testLocationAttr
3151
def testLocationAttr():
3252
with Context() as ctxt:

0 commit comments

Comments
 (0)