Skip to content

[mlir][python] automatic ___location inference #151246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include <optional>
#include <regex>
#include <string>
#include <unordered_set>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Regex.h"

namespace mlir {
namespace python {
Expand Down Expand Up @@ -114,6 +118,37 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);

class TracebackLoc {
public:
bool locTracebacksEnabled();

void setLocTracebacksEnabled(bool value);

size_t locTracebackFramesLimit();

void setLocTracebackFramesLimit(size_t value);

void registerTracebackFileInclusion(const std::string &file);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to say the name would make me think its overriding and suggeste append. But I see this is consistent with the diagnostic handlers wording.


void registerTracebackFileExclusion(const std::string &file);

bool isUserTracebackFilename(llvm::StringRef file);

private:
nanobind::ft_mutex mutex;
bool locTracebackEnabled_ = false;
size_t locTracebackFramesLimit_ = 10;
std::unordered_set<std::string> userTracebackIncludeFiles;
std::unordered_set<std::string> userTracebackExcludeFiles;
std::regex userTracebackIncludeRegex;
bool rebuildUserTracebackIncludeRegex = false;
std::regex userTracebackExcludeRegex;
bool rebuildUserTracebackExcludeRegex = false;
llvm::StringMap<bool> isUserTracebackFilenameCache;
};

TracebackLoc &getTracebackLoc() { return tracebackLoc; }

private:
static PyGlobals *instance;

Expand All @@ -134,6 +169,8 @@ class PyGlobals {
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;

TracebackLoc tracebackLoc;
};

} // namespace python
Expand Down
116 changes: 101 additions & 15 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <system_error>
#include <utility>

namespace nb = nanobind;
using namespace nb::literals;
Expand Down Expand Up @@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation ___location,
int regions, PyLocation ___location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
Expand Down Expand Up @@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(___location->getContext(), operation);
PyOperation::createDetached(___location.getContext(), operation);
maybeInsertOperation(created, maybeIp);

return created.getObject();
Expand Down Expand Up @@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation ___location,
std::optional<int> regions, PyLocation ___location,
const nb::object &maybeIp) {
PyMlirContextRef context = ___location->getContext();
PyMlirContextRef context = ___location.getContext();

// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
Expand Down Expand Up @@ -2789,6 +2786,91 @@ class PyOpAttributeMap {
PyOperationRef operation;
};

constexpr size_t kMaxFrames = 512;

MlirLocation tracebackToLocation(MlirContext ctx) {
size_t framesLimit =
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
// Use a thread_local here to avoid requiring a large amount of space.
thread_local std::array<MlirLocation, kMaxFrames> frames;
size_t count = 0;

assert(PyGILState_Check());

PyThreadState *tstate = PyThreadState_GET();

PyFrameObject *next;
for (PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
pyFrame != nullptr && count < framesLimit;
next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
PyCodeObject *code = PyFrame_GetCode(pyFrame);
auto fileNameStr =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
llvm::StringRef fileName(fileNameStr);
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;

#if PY_VERSION_HEX < 0x030b00f0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
llvm::StringRef funcName(name);
int startLine = PyFrame_GetLineNumber(pyFrame);
MlirLocation loc =
mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
#else
// co_qualname and PyCode_Addr2Location added in py3.11
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
llvm::StringRef funcName(name);
int startLine, startCol, endLine, endCol;
int lasti = PyFrame_GetLasti(pyFrame);
if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
&endCol)) {
throw nb::python_error();
}
MlirLocation loc = mlirLocationFileLineColRangeGet(
ctx, wrap(fileName), startLine, startCol, endLine, endCol);
#endif

frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
++count;
if (count > framesLimit)
break;
}

if (count == 0)
return mlirLocationUnknownGet(ctx);
if (count == 1)
return frames[0];

MlirLocation callee = frames[0];
MlirLocation caller = frames[count - 1];
for (int i = count - 2; i >= 1; i--)
caller = mlirLocationCallSiteGet(frames[i], caller);

return mlirLocationCallSiteGet(callee, caller);
}

PyLocation
maybeGetTracebackLocation(const std::optional<PyLocation> &___location) {
MlirLocation mlirLoc;
MlirContext mlirCtx;
if (!___location.has_value() &&
PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) {
mlirCtx = DefaultingPyMlirContext::resolve().get();
mlirLoc = tracebackToLocation(mlirCtx);
} else if (!___location.has_value()) {
mlirLoc = DefaultingPyLocation::resolve();
mlirCtx = mlirLocationGetContext(mlirLoc);
} else {
mlirLoc = *___location;
mlirCtx = mlirLocationGetContext(mlirLoc);
}
assert(!mlirLocationIsNull(mlirLoc) && "expected non-null mlirLoc");
PyMlirContextRef ctx = PyMlirContext::forContext(mlirCtx);
return {ctx, mlirLoc};
}

} // namespace

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3240,8 +3322,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
[](DefaultingPyLocation loc) {
MlirModule module = mlirModuleCreateEmpty(loc);
[](std::optional<PyLocation> loc) {
PyLocation pyLoc = maybeGetTracebackLocation(loc);
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
Expand Down Expand Up @@ -3454,7 +3537,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation ___location, const nb::object &maybeIp,
std::optional<PyLocation> ___location, const nb::object &maybeIp,
bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
Expand All @@ -3467,8 +3550,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}

PyLocation pyLoc = maybeGetTracebackLocation(___location);
return PyOperation::create(name, results, mlirOperands, attributes,
successors, regions, ___location, maybeIp,
successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
Expand Down Expand Up @@ -3512,12 +3596,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation ___location,
std::optional<int> regions, std::optional<PyLocation> ___location,
const nb::object &maybeIp) {
PyLocation pyLoc = maybeGetTracebackLocation(___location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
attributes, successors, regions, ___location, maybeIp));
attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
Expand Down Expand Up @@ -3551,17 +3636,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation ___location,
std::optional<int> regions, std::optional<PyLocation> ___location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
PyLocation pyLoc = maybeGetTracebackLocation(___location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
regions, ___location, maybeIp);
regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
Expand Down
70 changes: 69 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

#include "Globals.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}

bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
nanobind::ft_lock_guard lock(mutex);
return locTracebackEnabled_;
}

void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
nanobind::ft_lock_guard lock(mutex);
locTracebackEnabled_ = value;
}

size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
nanobind::ft_lock_guard lock(mutex);
return locTracebackFramesLimit_;
}

void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
nanobind::ft_lock_guard lock(mutex);
locTracebackFramesLimit_ = value;
}

void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
const std::string &file) {
nanobind::ft_lock_guard lock(mutex);
auto reg = "^" + llvm::Regex::escape(file);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this always has to match prefix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the choice JAX made so I went with it - I think it's reasonable wrt being conservative.

if (userTracebackIncludeFiles.insert(reg).second)
rebuildUserTracebackIncludeRegex = true;
if (userTracebackExcludeFiles.count(reg)) {
if (userTracebackExcludeFiles.erase(reg))
rebuildUserTracebackExcludeRegex = true;
}
}

void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
const std::string &file) {
nanobind::ft_lock_guard lock(mutex);
auto reg = "^" + llvm::Regex::escape(file);
if (userTracebackExcludeFiles.insert(reg).second)
rebuildUserTracebackExcludeRegex = true;
if (userTracebackIncludeFiles.count(reg)) {
if (userTracebackIncludeFiles.erase(reg))
rebuildUserTracebackIncludeRegex = true;
}
}

bool PyGlobals::TracebackLoc::isUserTracebackFilename(
const llvm::StringRef file) {
nanobind::ft_lock_guard lock(mutex);
if (rebuildUserTracebackIncludeRegex) {
userTracebackIncludeRegex.assign(
llvm::join(userTracebackIncludeFiles, "|"));
rebuildUserTracebackIncludeRegex = false;
isUserTracebackFilenameCache.clear();
}
if (rebuildUserTracebackExcludeRegex) {
userTracebackExcludeRegex.assign(
llvm::join(userTracebackExcludeFiles, "|"));
rebuildUserTracebackExcludeRegex = false;
isUserTracebackFilenameCache.clear();
}
if (!isUserTracebackFilenameCache.contains(file)) {
std::string fileStr = file.str();
bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
isUserTracebackFilenameCache[file] = include || !exclude;
}
return isUserTracebackFilenameCache[file];
}
5 changes: 2 additions & 3 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
llvm::ArrayRef<MlirValue> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation ___location, const nanobind::object &ip,
bool inferType);
PyLocation ___location, const nanobind::object &ip, bool inferType);

/// Creates an OpView suitable for this operation.
nanobind::object createOpView();
Expand Down Expand Up @@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
nanobind::list operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation ___location,
std::optional<int> regions, PyLocation ___location,
const nanobind::object &maybeIp);

/// Construct an instance of a class deriving from OpView, bypassing its
Expand Down
Loading