Skip to content

[flang][CUDA] Apply intrinsic operator overrides #151018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion flang/include/flang/Semantics/semantics.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ class SemanticsContext {
warningsAreErrors_ = x;
return *this;
}

SemanticsContext &set_debugModuleWriter(bool x) {
debugModuleWriter_ = x;
return *this;
Expand Down
9 changes: 4 additions & 5 deletions flang/lib/Semantics/check-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
// legal.
if (nbLhs == 0 && nbRhs > 1) {
context_.Say(lhsLoc,
"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
"More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US);
}

if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
if (evaluate::HasCUDADeviceAttrs(assign->lhs) &&
evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
GetNbOfCUDADeviceSymbols(assign->rhs) == 1) {
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) {
return; // This is a special case handled on the host.
}
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Semantics/check-declarations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) {
}

static bool ConflictsWithIntrinsicOperator(
const GenericKind &kind, const Procedure &proc) {
const GenericKind &kind, const Procedure &proc, SemanticsContext &context) {
if (!kind.IsIntrinsicOperator()) {
return false;
}
Expand Down Expand Up @@ -2167,7 +2167,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind,
}
} else if (!checkDefinedOperatorArgs(opName, specific, proc)) {
return false; // error was reported
} else if (ConflictsWithIntrinsicOperator(kind, proc)) {
} else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) {
msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US;
}
if (msg) {
Expand Down
77 changes: 70 additions & 7 deletions flang/lib/Semantics/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,17 @@ class ArgumentAnalyzer {
bool CheckForNullPointer(const char *where = "as an operand here");
bool CheckForAssumedRank(const char *where = "as an operand here");

bool AnyCUDADeviceData() const;
// Returns true if an interface has been defined for an intrinsic operator
// with one or more device operands.
bool HasDeviceDefinedIntrinsicOpOverride(const char *) const;
// Convenience form taking an operator value instead of a spelling: the
// names reported for `opr` by the language feature control's GetNames()
// are handed to another HasDeviceDefinedIntrinsicOpOverride overload.
template <typename E> bool HasDeviceDefinedIntrinsicOpOverride(E opr) const {
  const auto &features{context_.context().languageFeatures()};
  return HasDeviceDefinedIntrinsicOpOverride(features.GetNames(opr));
}

// Find and return a user-defined operator or report an error.
// The provided message is used if there is no such operator.
// If a definedOpSymbolPtr is provided, the caller must check
// for its accessibility.
MaybeExpr TryDefinedOp(
const char *, parser::MessageFixedText, bool isUserOp = false);
template <typename E>
Expand All @@ -183,6 +190,8 @@ class ArgumentAnalyzer {
void Dump(llvm::raw_ostream &);

private:
bool HasDeviceDefinedIntrinsicOpOverride(
const std::vector<const char *> &) const;
MaybeExpr TryDefinedOp(
const std::vector<const char *> &, parser::MessageFixedText);
MaybeExpr TryBoundOp(const Symbol &, int passIndex);
Expand All @@ -202,7 +211,7 @@ class ArgumentAnalyzer {
void SayNoMatch(
const std::string &, bool isAssignment = false, bool isAmbiguous = false);
std::string TypeAsFortran(std::size_t);
bool AnyUntypedOrMissingOperand();
bool AnyUntypedOrMissingOperand() const;

ExpressionAnalyzer &context_;
ActualArguments actuals_;
Expand Down Expand Up @@ -4497,13 +4506,20 @@ void ArgumentAnalyzer::Analyze(
bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr,
const DynamicType &leftType, const DynamicType &rightType) const {
CHECK(actuals_.size() == 2);
return semantics::IsIntrinsicRelational(
opr, leftType, GetRank(0), rightType, GetRank(1));
return !(context_.context().languageFeatures().IsEnabled(
common::LanguageFeature::CUDA) &&
HasDeviceDefinedIntrinsicOpOverride(opr)) &&
semantics::IsIntrinsicRelational(
opr, leftType, GetRank(0), rightType, GetRank(1));
}

bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const {
std::optional<DynamicType> leftType{GetType(0)};
if (actuals_.size() == 1) {
if (context_.context().languageFeatures().IsEnabled(
common::LanguageFeature::CUDA) &&
HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) {
return false;
} else if (actuals_.size() == 1) {
if (IsBOZLiteral(0)) {
return opr == NumericOperator::Add; // unary '+'
} else {
Expand Down Expand Up @@ -4617,6 +4633,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
return true;
}

bool ArgumentAnalyzer::AnyCUDADeviceData() const {
for (const std::optional<ActualArgument> &arg : actuals_) {
if (arg) {
if (const Expr<SomeType> *expr{arg->UnwrapExpr()}) {
if (HasCUDADeviceAttrs(*expr)) {
return true;
}
}
}
}
return false;
}

// An explicit, non-type-bound generic interface may define an operation
// whose types and ranks would erroneously conflict with an intrinsic
// operation, while one or more of its dummy arguments have the DEVICE
// attribute.
//
// Returns true when all of the following hold:
//   - some actual argument has CUDA device attributes,
//   - no operand is untyped or missing,
//   - a symbol named "operator(<opr>)" is found from the scope of source_,
//   - AnalyzeDefinedOp() yields a result for that generic and the current
//     actual arguments.
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
    const char *opr) const {
  if (!AnyCUDADeviceData() || AnyUntypedOrMissingOperand()) {
    return false;
  }
  // The CharBlock below refers to this string, so it must stay alive for
  // the duration of the lookup.
  std::string oprNameString{"operator("s + opr + ')'};
  parser::CharBlock oprName{oprNameString};
  // NOTE(review): presumably this redirects any messages produced by the
  // trial analysis into a local throwaway buffer until `restorer` goes out
  // of scope -- confirm against SetMessages().
  parser::Messages buffer;
  auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
  const auto &scope{context_.context().FindScope(source_)};
  Symbol *generic{scope.FindSymbol(oprName)};
  if (!generic) {
    return false;
  }
  parser::Name name{generic->name(), generic};
  const Symbol *resultSymbol{nullptr};
  if (context_.AnalyzeDefinedOp(
          name, ActualArguments{actuals_}, resultSymbol)) {
    return true;
  }
  return false;
}

bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
const std::vector<const char *> &oprNames) const {
for (const char *opr : oprNames) {
if (HasDeviceDefinedIntrinsicOpOverride(opr)) {
return true;
}
}
return false;
}

MaybeExpr ArgumentAnalyzer::TryDefinedOp(
const char *opr, parser::MessageFixedText error, bool isUserOp) {
if (AnyUntypedOrMissingOperand()) {
Expand Down Expand Up @@ -5135,7 +5198,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
}
}

bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() {
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
for (const auto &actual : actuals_) {
if (!actual ||
(!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {
Expand Down
49 changes: 49 additions & 0 deletions flang/test/Semantics/bug1214.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
! Test fixture: user-defined operator(*) and assignment(=) interfaces whose
! specific procedures take REAL arguments, some with the DEVICE attribute.
module overrides
! Derived type returned by the multiplication specifics below.
type realResult
real a
end type
! The two specifics differ only in which operand has the DEVICE attribute.
interface operator(*)
procedure :: multHostDevice, multDeviceHost
end interface
! The two specifics differ only in whether lhs has the DEVICE attribute.
interface assignment(=)
procedure :: assignHostResult, assignDeviceResult
end interface
contains
! x without DEVICE, y with DEVICE; the product is stored in result%a.
elemental function multHostDevice(x, y) result(result)
real, intent(in) :: x
real, intent(in), device :: y
type(realResult) result
result%a = x * y
end
! x with DEVICE, y without DEVICE; the product is stored in result%a.
elemental function multDeviceHost(x, y) result(result)
real, intent(in), device :: x
real, intent(in) :: y
type(realResult) result
result%a = x * y
end
! Copies rhs%a into a REAL lhs that has no DEVICE attribute.
elemental subroutine assignHostResult(lhs, rhs)
real, intent(out) :: lhs
type(realResult), intent(in) :: rhs
lhs = rhs%a
end
! Copies rhs%a into a REAL lhs that has the DEVICE attribute.
elemental subroutine assignDeviceResult(lhs, rhs)
real, intent(out), device :: lhs
type(realResult), intent(in) :: rhs
lhs = rhs%a
end
end

! Driver: each statement multiplies a REAL literal by the DEVICE variable da
! and assigns to a DEVICE (db) or non-DEVICE (ha) variable. The FileCheck
! directive before each statement gives the expected unparsed form: a call to
! one of the defined assignments wrapping one of the defined multiplications.
program p
use overrides
real, device :: da, db
real :: ha, hb
!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da))
db = 2. * da
!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4))
db = da * 2.
!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da))
ha = 2. * da
!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4))
ha = da * 2.
end
2 changes: 1 addition & 1 deletion flang/test/Semantics/cuf11.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ subroutine sub1()
real, device :: adev(10), bdev(10)
real :: ahost(10)

!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
!ERROR: More than one reference to a CUDA object on the right hand side of the assignment
ahost = adev + bdev

ahost = adev + adev
Expand Down