Skip to content

Commit ea70a83

Browse files
committed
[flang][CUDA] Apply intrinsic operator overrides
Fortran's intrinsic numeric and relational operators can be overridden with explicit interfaces so long as one or more of the dummy arguments have the DEVICE attribute. Semantics already allows this without complaint, but fails to replace the operations with the defined specific procedure calls when analyzing expressions.
1 parent f6e70c7 commit ea70a83

File tree

6 files changed

+134
-16
lines changed

6 files changed

+134
-16
lines changed

flang/include/flang/Semantics/semantics.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class SemanticsContext {
123123
std::map<const Symbol *, SourceName> &moduleFileOutputRenamings() {
124124
return moduleFileOutputRenamings_;
125125
}
126+
bool anyCUDAHostDeviceIntrinsicOperatorOverride() const {
127+
return anyCUDAHostDeviceIntrinsicOperatorOverride_;
128+
}
126129

127130
SemanticsContext &set_location(
128131
const std::optional<parser::CharBlock> &___location) {
@@ -162,11 +165,15 @@ class SemanticsContext {
162165
warningsAreErrors_ = x;
163166
return *this;
164167
}
165-
166168
SemanticsContext &set_debugModuleWriter(bool x) {
167169
debugModuleWriter_ = x;
168170
return *this;
169171
}
172+
SemanticsContext &set_anyCUDAHostDeviceIntrinsicOperatorOverride(
173+
bool yes = true) {
174+
anyCUDAHostDeviceIntrinsicOperatorOverride_ = yes;
175+
return *this;
176+
}
170177

171178
const DeclTypeSpec &MakeNumericType(TypeCategory, int kind = 0);
172179
const DeclTypeSpec &MakeLogicalType(int kind = 0);
@@ -352,6 +359,7 @@ class SemanticsContext {
352359
std::map<const Symbol *, SourceName> moduleFileOutputRenamings_;
353360
UnorderedSymbolSet isDefined_;
354361
std::list<ProgramTree> programTrees_;
362+
bool anyCUDAHostDeviceIntrinsicOperatorOverride_{false};
355363
};
356364

357365
class Semantics {

flang/lib/Semantics/check-cuda.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
761761
// legal.
762762
if (nbLhs == 0 && nbRhs > 1) {
763763
context_.Say(lhsLoc,
764-
"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
764+
"More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US);
765765
}
766766

767-
if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
768-
Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
767+
if (evaluate::HasCUDADeviceAttrs(assign->lhs) &&
768+
evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
769769
if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
770-
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
771-
GetNbOfCUDADeviceSymbols(assign->rhs) == 1) {
770+
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) {
772771
return; // This is a special case handled on the host.
773772
}
774773
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);

flang/lib/Semantics/check-declarations.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,12 +2081,13 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) {
20812081
}
20822082

20832083
static bool ConflictsWithIntrinsicOperator(
2084-
const GenericKind &kind, const Procedure &proc) {
2084+
const GenericKind &kind, const Procedure &proc, SemanticsContext &context) {
20852085
if (!kind.IsIntrinsicOperator()) {
20862086
return false;
20872087
}
20882088
const auto &arg0Data{std::get<DummyDataObject>(proc.dummyArguments[0].u)};
20892089
if (CUDAHostDeviceDiffer(proc, arg0Data)) {
2090+
context.set_anyCUDAHostDeviceIntrinsicOperatorOverride();
20902091
return false;
20912092
}
20922093
const auto &arg0TnS{arg0Data.type};
@@ -2167,7 +2168,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind,
21672168
}
21682169
} else if (!checkDefinedOperatorArgs(opName, specific, proc)) {
21692170
return false; // error was reported
2170-
} else if (ConflictsWithIntrinsicOperator(kind, proc)) {
2171+
} else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) {
21712172
msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US;
21722173
}
21732174
if (msg) {

flang/lib/Semantics/expression.cpp

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,17 @@ class ArgumentAnalyzer {
165165
bool CheckForNullPointer(const char *where = "as an operand here");
166166
bool CheckForAssumedRank(const char *where = "as an operand here");
167167

168+
bool AnyCUDADeviceData() const;
169+
// Returns true if an interface has been defined for an intrinsic operator
170+
// with one or more device operands.
171+
bool HasDeviceDefinedIntrinsicOpOverride(const char *) const;
172+
template <typename E> bool HasDeviceDefinedIntrinsicOpOverride(E opr) const {
173+
return HasDeviceDefinedIntrinsicOpOverride(
174+
context_.context().languageFeatures().GetNames(opr));
175+
}
176+
168177
// Find and return a user-defined operator or report an error.
169178
// The provided message is used if there is no such operator.
170-
// If a definedOpSymbolPtr is provided, the caller must check
171-
// for its accessibility.
172179
MaybeExpr TryDefinedOp(
173180
const char *, parser::MessageFixedText, bool isUserOp = false);
174181
template <typename E>
@@ -183,6 +190,8 @@ class ArgumentAnalyzer {
183190
void Dump(llvm::raw_ostream &);
184191

185192
private:
193+
bool HasDeviceDefinedIntrinsicOpOverride(
194+
const std::vector<const char *> &) const;
186195
MaybeExpr TryDefinedOp(
187196
const std::vector<const char *> &, parser::MessageFixedText);
188197
MaybeExpr TryBoundOp(const Symbol &, int passIndex);
@@ -202,7 +211,7 @@ class ArgumentAnalyzer {
202211
void SayNoMatch(
203212
const std::string &, bool isAssignment = false, bool isAmbiguous = false);
204213
std::string TypeAsFortran(std::size_t);
205-
bool AnyUntypedOrMissingOperand();
214+
bool AnyUntypedOrMissingOperand() const;
206215

207216
ExpressionAnalyzer &context_;
208217
ActualArguments actuals_;
@@ -4497,13 +4506,18 @@ void ArgumentAnalyzer::Analyze(
44974506
bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr,
44984507
const DynamicType &leftType, const DynamicType &rightType) const {
44994508
CHECK(actuals_.size() == 2);
4500-
return semantics::IsIntrinsicRelational(
4501-
opr, leftType, GetRank(0), rightType, GetRank(1));
4509+
return !(context_.context().anyCUDAHostDeviceIntrinsicOperatorOverride() &&
4510+
HasDeviceDefinedIntrinsicOpOverride(opr)) &&
4511+
semantics::IsIntrinsicRelational(
4512+
opr, leftType, GetRank(0), rightType, GetRank(1));
45024513
}
45034514

45044515
bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const {
45054516
std::optional<DynamicType> leftType{GetType(0)};
4506-
if (actuals_.size() == 1) {
4517+
if (context_.context().anyCUDAHostDeviceIntrinsicOperatorOverride() &&
4518+
HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) {
4519+
return false;
4520+
} else if (actuals_.size() == 1) {
45074521
if (IsBOZLiteral(0)) {
45084522
return opr == NumericOperator::Add; // unary '+'
45094523
} else {
@@ -4617,6 +4631,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
46174631
return true;
46184632
}
46194633

4634+
bool ArgumentAnalyzer::AnyCUDADeviceData() const {
4635+
for (const std::optional<ActualArgument> &arg : actuals_) {
4636+
if (arg) {
4637+
if (const Expr<SomeType> *expr{arg->UnwrapExpr()}) {
4638+
if (HasCUDADeviceAttrs(*expr)) {
4639+
return true;
4640+
}
4641+
}
4642+
}
4643+
}
4644+
return false;
4645+
}
4646+
4647+
// Some operations can be defined with explicit non-type-bound interfaces
4648+
// that would erroneously conflict with intrinsic operations in their
4649+
// types and ranks but have one or more dummy arguments with the DEVICE
4650+
// attribute.
4651+
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
4652+
const char *opr) const {
4653+
if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
4654+
std::string oprNameString{"operator("s + opr + ')'};
4655+
parser::CharBlock oprName{oprNameString};
4656+
parser::Messages buffer;
4657+
auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
4658+
const auto &scope{context_.context().FindScope(source_)};
4659+
if (Symbol * generic{scope.FindSymbol(oprName)}) {
4660+
parser::Name name{generic->name(), generic};
4661+
const Symbol *resultSymbol{nullptr};
4662+
if (context_.AnalyzeDefinedOp(
4663+
name, ActualArguments{actuals_}, resultSymbol)) {
4664+
return true;
4665+
}
4666+
}
4667+
}
4668+
return false;
4669+
}
4670+
4671+
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
4672+
const std::vector<const char *> &oprNames) const {
4673+
for (const char *opr : oprNames) {
4674+
if (HasDeviceDefinedIntrinsicOpOverride(opr)) {
4675+
return true;
4676+
}
4677+
}
4678+
return false;
4679+
}
4680+
46204681
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
46214682
const char *opr, parser::MessageFixedText error, bool isUserOp) {
46224683
if (AnyUntypedOrMissingOperand()) {
@@ -5135,7 +5196,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
51355196
}
51365197
}
51375198

5138-
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() {
5199+
bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
51395200
for (const auto &actual : actuals_) {
51405201
if (!actual ||
51415202
(!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {

flang/test/Semantics/bug1214.cuf

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
2+
module overrides
3+
type realResult
4+
real a
5+
end type
6+
interface operator(*)
7+
procedure :: multHostDevice, multDeviceHost
8+
end interface
9+
interface assignment(=)
10+
procedure :: assignHostResult, assignDeviceResult
11+
end interface
12+
contains
13+
elemental function multHostDevice(x, y) result(result)
14+
real, intent(in) :: x
15+
real, intent(in), device :: y
16+
type(realResult) result
17+
result%a = x * y
18+
end
19+
elemental function multDeviceHost(x, y) result(result)
20+
real, intent(in), device :: x
21+
real, intent(in) :: y
22+
type(realResult) result
23+
result%a = x * y
24+
end
25+
elemental subroutine assignHostResult(lhs, rhs)
26+
real, intent(out) :: lhs
27+
type(realResult), intent(in) :: rhs
28+
lhs = rhs%a
29+
end
30+
elemental subroutine assignDeviceResult(lhs, rhs)
31+
real, intent(out), device :: lhs
32+
type(realResult), intent(in) :: rhs
33+
lhs = rhs%a
34+
end
35+
end
36+
37+
program p
38+
use overrides
39+
real, device :: da, db
40+
real :: ha, hb
41+
!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da))
42+
db = 2. * da
43+
!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4))
44+
db = da * 2.
45+
!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da))
46+
ha = 2. * da
47+
!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4))
48+
ha = da * 2.
49+
end

flang/test/Semantics/cuf11.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ subroutine sub1()
1616
real, device :: adev(10), bdev(10)
1717
real :: ahost(10)
1818

19-
!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
19+
!ERROR: More than one reference to a CUDA object on the right hand side of the assignment
2020
ahost = adev + bdev
2121

2222
ahost = adev + adev

0 commit comments

Comments
 (0)