Skip to content

Commit 00976c9

Browse files
committed
[flang][CUDA] Apply intrinsic operator overrides
Fortran's intrinsic numeric and relational operators can be overridden with explicit interfaces so long as one or more of the dummy arguments have the DEVICE attribute. Semantics already allows this without complaint, but fails to replace the operations with the defined specific procedure calls when analyzing expressions.
1 parent f6e70c7 commit 00976c9

File tree

6 files changed

+126
-16
lines changed

6 files changed

+126
-16
lines changed

flang/include/flang/Semantics/semantics.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ class SemanticsContext {
162162
warningsAreErrors_ = x;
163163
return *this;
164164
}
165-
166165
SemanticsContext &set_debugModuleWriter(bool x) {
167166
debugModuleWriter_ = x;
168167
return *this;

flang/lib/Semantics/check-cuda.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
761761
// legal.
762762
if (nbLhs == 0 && nbRhs > 1) {
763763
context_.Say(lhsLoc,
764-
"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
764+
"More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US);
765765
}
766766

767-
if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
768-
Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
767+
if (evaluate::HasCUDADeviceAttrs(assign->lhs) &&
768+
evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
769769
if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
770-
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
771-
GetNbOfCUDADeviceSymbols(assign->rhs) == 1) {
770+
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) {
772771
return; // This is a special case handled on the host.
773772
}
774773
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);

flang/lib/Semantics/check-declarations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,7 +2081,7 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) {
20812081
}
20822082

20832083
static bool ConflictsWithIntrinsicOperator(
2084-
const GenericKind &kind, const Procedure &proc) {
2084+
const GenericKind &kind, const Procedure &proc, SemanticsContext &context) {
20852085
if (!kind.IsIntrinsicOperator()) {
20862086
return false;
20872087
}
@@ -2167,7 +2167,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind,
21672167
}
21682168
} else if (!checkDefinedOperatorArgs(opName, specific, proc)) {
21692169
return false; // error was reported
2170-
} else if (ConflictsWithIntrinsicOperator(kind, proc)) {
2170+
} else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) {
21712171
msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US;
21722172
}
21732173
if (msg) {

flang/lib/Semantics/expression.cpp

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,17 @@ class ArgumentAnalyzer {
165165
bool CheckForNullPointer(const char *where = "as an operand here");
166166
bool CheckForAssumedRank(const char *where = "as an operand here");
167167

168+
bool AnyCUDADeviceData() const;
169+
// Returns true if an interface has been defined for an intrinsic operator
170+
// with one or more device operands.
171+
bool HasDeviceDefinedIntrinsicOpOverride(const char *) const;
172+
template <typename E> bool HasDeviceDefinedIntrinsicOpOverride(E opr) const {
173+
return HasDeviceDefinedIntrinsicOpOverride(
174+
context_.context().languageFeatures().GetNames(opr));
175+
}
176+
168177
// Find and return a user-defined operator or report an error.
169178
// The provided message is used if there is no such operator.
170-
// If a definedOpSymbolPtr is provided, the caller must check
171-
// for its accessibility.
172179
MaybeExpr TryDefinedOp(
173180
const char *, parser::MessageFixedText, bool isUserOp = false);
174181
template <typename E>
@@ -183,6 +190,8 @@ class ArgumentAnalyzer {
183190
void Dump(llvm::raw_ostream &);
184191

185192
private:
193+
bool HasDeviceDefinedIntrinsicOpOverride(
194+
const std::vector<const char *> &) const;
186195
MaybeExpr TryDefinedOp(
187196
const std::vector<const char *> &, parser::MessageFixedText);
188197
MaybeExpr TryBoundOp(const Symbol &, int passIndex);
@@ -202,7 +211,7 @@ class ArgumentAnalyzer {
202211
void SayNoMatch(
203212
const std::string &, bool isAssignment = false, bool isAmbiguous = false);
204213
std::string TypeAsFortran(std::size_t);
205-
bool AnyUntypedOrMissingOperand();
214+
bool AnyUntypedOrMissingOperand() const;
206215

207216
ExpressionAnalyzer &context_;
208217
ActualArguments actuals_;
@@ -4497,13 +4506,20 @@ void ArgumentAnalyzer::Analyze(
44974506
bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr,
44984507
const DynamicType &leftType, const DynamicType &rightType) const {
44994508
CHECK(actuals_.size() == 2);
4500-
return semantics::IsIntrinsicRelational(
4501-
opr, leftType, GetRank(0), rightType, GetRank(1));
4509+
return !(context_.context().languageFeatures().IsEnabled(
4510+
common::LanguageFeature::CUDA) &&
4511+
HasDeviceDefinedIntrinsicOpOverride(opr)) &&
4512+
semantics::IsIntrinsicRelational(
4513+
opr, leftType, GetRank(0), rightType, GetRank(1));
45024514
}
45034515

45044516
bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const {
45054517
std::optional<DynamicType> leftType{GetType(0)};
4506-
if (actuals_.size() == 1) {
4518+
if (context_.context().languageFeatures().IsEnabled(
4519+
common::LanguageFeature::CUDA) &&
4520+
HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) {
4521+
return false;
4522+
} else if (actuals_.size() == 1) {
45074523
if (IsBOZLiteral(0)) {
45084524
return opr == NumericOperator::Add; // unary '+'
45094525
} else {
@@ -4617,6 +4633,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
46174633
return true;
46184634
}
46194635

4636+
bool ArgumentAnalyzer::AnyCUDADeviceData() const {
4637+
for (const std::optional<ActualArgument> &arg : actuals_) {
4638+
if (arg) {
4639+
if (const Expr<SomeType> *expr{arg->UnwrapExpr()}) {
4640+
if (HasCUDADeviceAttrs(*expr)) {
4641+
return true;
4642+
}
4643+
}
4644+
}
4645+
}
4646+
return false;
4647+
}
4648+
4649+
// Some operations can be defined with explicit non-type-bound interfaces
4650+
// that would erroneously conflict with intrinsic operations in their
4651+
// types and ranks but have one or more dummy arguments with the DEVICE
4652+
// attribute.
4653+
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
4654+
const char *opr) const {
4655+
if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
4656+
std::string oprNameString{"operator("s + opr + ')'};
4657+
parser::CharBlock oprName{oprNameString};
4658+
parser::Messages buffer;
4659+
auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
4660+
const auto &scope{context_.context().FindScope(source_)};
4661+
if (Symbol * generic{scope.FindSymbol(oprName)}) {
4662+
parser::Name name{generic->name(), generic};
4663+
const Symbol *resultSymbol{nullptr};
4664+
if (context_.AnalyzeDefinedOp(
4665+
name, ActualArguments{actuals_}, resultSymbol)) {
4666+
return true;
4667+
}
4668+
}
4669+
}
4670+
return false;
4671+
}
4672+
4673+
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
4674+
const std::vector<const char *> &oprNames) const {
4675+
for (const char *opr : oprNames) {
4676+
if (HasDeviceDefinedIntrinsicOpOverride(opr)) {
4677+
return true;
4678+
}
4679+
}
4680+
return false;
4681+
}
4682+
46204683
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
46214684
const char *opr, parser::MessageFixedText error, bool isUserOp) {
46224685
if (AnyUntypedOrMissingOperand()) {
@@ -5135,7 +5198,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
51355198
}
51365199
}
51375200

5138-
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() {
5201+
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
51395202
for (const auto &actual : actuals_) {
51405203
if (!actual ||
51415204
(!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {

flang/test/Semantics/bug1214.cuf

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
2+
module overrides
3+
type realResult
4+
real a
5+
end type
6+
interface operator(*)
7+
procedure :: multHostDevice, multDeviceHost
8+
end interface
9+
interface assignment(=)
10+
procedure :: assignHostResult, assignDeviceResult
11+
end interface
12+
contains
13+
elemental function multHostDevice(x, y) result(result)
14+
real, intent(in) :: x
15+
real, intent(in), device :: y
16+
type(realResult) result
17+
result%a = x * y
18+
end
19+
elemental function multDeviceHost(x, y) result(result)
20+
real, intent(in), device :: x
21+
real, intent(in) :: y
22+
type(realResult) result
23+
result%a = x * y
24+
end
25+
elemental subroutine assignHostResult(lhs, rhs)
26+
real, intent(out) :: lhs
27+
type(realResult), intent(in) :: rhs
28+
lhs = rhs%a
29+
end
30+
elemental subroutine assignDeviceResult(lhs, rhs)
31+
real, intent(out), device :: lhs
32+
type(realResult), intent(in) :: rhs
33+
lhs = rhs%a
34+
end
35+
end
36+
37+
program p
38+
use overrides
39+
real, device :: da, db
40+
real :: ha, hb
41+
!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da))
42+
db = 2. * da
43+
!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4))
44+
db = da * 2.
45+
!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da))
46+
ha = 2. * da
47+
!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4))
48+
ha = da * 2.
49+
end

flang/test/Semantics/cuf11.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ subroutine sub1()
1616
real, device :: adev(10), bdev(10)
1717
real :: ahost(10)
1818

19-
!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
19+
!ERROR: More than one reference to a CUDA object on the right hand side of the assignment
2020
ahost = adev + bdev
2121

2222
ahost = adev + adev

0 commit comments

Comments
 (0)