Skip to content

Commit a9627b7

Browse files
committed
[CUDA] Add partial support for recent CUDA versions.
Generate PTX using newer PTX versions (PTX 6.5 for CUDA 10.2, PTX 7.0 for CUDA 11.0) and allow using sm_80 with CUDA 11.0. None of the new features of CUDA 10.2+ have been implemented yet, so using these versions will still produce a warning. Differential Revision: https://reviews.llvm.org/D77670
1 parent 33386b2 commit a9627b7

File tree

6 files changed

+59
-22
lines changed

6 files changed

+59
-22
lines changed

clang/include/clang/Basic/Cuda.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ enum class CudaVersion {
2727
CUDA_92,
2828
CUDA_100,
2929
CUDA_101,
30-
LATEST = CUDA_101,
30+
CUDA_102,
31+
CUDA_110,
32+
LATEST = CUDA_110,
33+
LATEST_SUPPORTED = CUDA_101,
3134
};
3235
const char *CudaVersionToString(CudaVersion V);
3336
// Input is "Major.Minor"
@@ -50,6 +53,7 @@ enum class CudaArch {
5053
SM_70,
5154
SM_72,
5255
SM_75,
56+
SM_80,
5357
GFX600,
5458
GFX601,
5559
GFX700,

clang/lib/Basic/Cuda.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ const char *CudaVersionToString(CudaVersion V) {
2828
return "10.0";
2929
case CudaVersion::CUDA_101:
3030
return "10.1";
31+
case CudaVersion::CUDA_102:
32+
return "10.2";
33+
case CudaVersion::CUDA_110:
34+
return "11.0";
3135
}
3236
llvm_unreachable("invalid enum");
3337
}
@@ -42,6 +46,8 @@ CudaVersion CudaStringToVersion(const llvm::Twine &S) {
4246
.Case("9.2", CudaVersion::CUDA_92)
4347
.Case("10.0", CudaVersion::CUDA_100)
4448
.Case("10.1", CudaVersion::CUDA_101)
49+
.Case("10.2", CudaVersion::CUDA_102)
50+
.Case("11.0", CudaVersion::CUDA_110)
4551
.Default(CudaVersion::UNKNOWN);
4652
}
4753

@@ -64,6 +70,7 @@ CudaArchToStringMap arch_names[] = {
6470
SM(60), SM(61), SM(62), // Pascal
6571
SM(70), SM(72), // Volta
6672
SM(75), // Turing
73+
SM(80), // Ampere
6774
GFX(600), // tahiti
6875
GFX(601), // pitcairn, verde, oland,hainan
6976
GFX(700), // kaveri
@@ -140,6 +147,8 @@ CudaVersion MinVersionForCudaArch(CudaArch A) {
140147
return CudaVersion::CUDA_91;
141148
case CudaArch::SM_75:
142149
return CudaVersion::CUDA_100;
150+
case CudaArch::SM_80:
151+
return CudaVersion::CUDA_110;
143152
default:
144153
llvm_unreachable("invalid enum");
145154
}

clang/lib/Basic/Targets/NVPTX.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ NVPTXTargetInfo::NVPTXTargetInfo(const llvm::Triple &Triple,
4444
if (!Feature.startswith("+ptx"))
4545
continue;
4646
PTXVersion = llvm::StringSwitch<unsigned>(Feature)
47+
.Case("+ptx70", 70)
48+
.Case("+ptx65", 65)
4749
.Case("+ptx64", 64)
4850
.Case("+ptx63", 63)
4951
.Case("+ptx61", 61)
@@ -231,6 +233,8 @@ void NVPTXTargetInfo::getTargetDefines(const LangOptions &Opts,
231233
return "720";
232234
case CudaArch::SM_75:
233235
return "750";
236+
case CudaArch::SM_80:
237+
return "800";
234238
}
235239
llvm_unreachable("unhandled CudaArch");
236240
}();

clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4992,6 +4992,7 @@ void CGOpenMPRuntimeNVPTX::processRequiresDirective(
49924992
case CudaArch::SM_70:
49934993
case CudaArch::SM_72:
49944994
case CudaArch::SM_75:
4995+
case CudaArch::SM_80:
49954996
case CudaArch::GFX600:
49964997
case CudaArch::GFX601:
49974998
case CudaArch::GFX700:
@@ -5049,6 +5050,7 @@ static std::pair<unsigned, unsigned> getSMsBlocksPerSM(CodeGenModule &CGM) {
50495050
case CudaArch::SM_70:
50505051
case CudaArch::SM_72:
50515052
case CudaArch::SM_75:
5053+
case CudaArch::SM_80:
50525054
return {84, 32};
50535055
case CudaArch::GFX600:
50545056
case CudaArch::GFX601:

clang/lib/Driver/ToolChains/Cuda.cpp

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,22 @@ void CudaInstallationDetector::ParseCudaVersionFile(llvm::StringRef V) {
4545
return;
4646
DetectedVersion = join_items(".", VersionParts[0], VersionParts[1]);
4747
Version = CudaStringToVersion(DetectedVersion);
48-
if (Version != CudaVersion::UNKNOWN)
48+
if (Version != CudaVersion::UNKNOWN) {
49+
// TODO(tra): remove the warning once we have all features of 10.2 and 11.0
50+
// implemented.
51+
DetectedVersionIsNotSupported = Version > CudaVersion::LATEST_SUPPORTED;
4952
return;
53+
}
5054

51-
Version = CudaVersion::LATEST;
55+
Version = CudaVersion::LATEST_SUPPORTED;
5256
DetectedVersionIsNotSupported = true;
5357
}
5458

5559
void CudaInstallationDetector::WarnIfUnsupportedVersion() {
5660
if (DetectedVersionIsNotSupported)
5761
D.Diag(diag::warn_drv_unknown_cuda_version)
58-
<< DetectedVersion << CudaVersionToString(Version);
62+
<< DetectedVersion
63+
<< CudaVersionToString(CudaVersion::LATEST_SUPPORTED);
5964
}
6065

6166
CudaInstallationDetector::CudaInstallationDetector(
@@ -639,24 +644,30 @@ void CudaToolChain::addClangTargetOptions(
639644
// by new PTX version, so we need to raise PTX level to enable them in NVPTX
640645
// back-end.
641646
const char *PtxFeature = nullptr;
642-
switch(CudaInstallation.version()) {
643-
case CudaVersion::CUDA_101:
644-
PtxFeature = "+ptx64";
645-
break;
646-
case CudaVersion::CUDA_100:
647-
PtxFeature = "+ptx63";
648-
break;
649-
case CudaVersion::CUDA_92:
650-
PtxFeature = "+ptx61";
651-
break;
652-
case CudaVersion::CUDA_91:
653-
PtxFeature = "+ptx61";
654-
break;
655-
case CudaVersion::CUDA_90:
656-
PtxFeature = "+ptx60";
657-
break;
658-
default:
659-
PtxFeature = "+ptx42";
647+
switch (CudaInstallation.version()) {
648+
case CudaVersion::CUDA_110:
649+
PtxFeature = "+ptx70";
650+
break;
651+
case CudaVersion::CUDA_102:
652+
PtxFeature = "+ptx65";
653+
break;
654+
case CudaVersion::CUDA_101:
655+
PtxFeature = "+ptx64";
656+
break;
657+
case CudaVersion::CUDA_100:
658+
PtxFeature = "+ptx63";
659+
break;
660+
case CudaVersion::CUDA_92:
661+
PtxFeature = "+ptx61";
662+
break;
663+
case CudaVersion::CUDA_91:
664+
PtxFeature = "+ptx61";
665+
break;
666+
case CudaVersion::CUDA_90:
667+
PtxFeature = "+ptx60";
668+
break;
669+
default:
670+
PtxFeature = "+ptx42";
660671
}
661672
CC1Args.append({"-target-feature", PtxFeature});
662673
if (DriverArgs.hasFlag(options::OPT_fcuda_short_ptr,

llvm/lib/Target/NVPTX/NVPTX.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def SM72 : SubtargetFeature<"sm_72", "SmVersion", "72",
5555
"Target SM 7.2">;
5656
def SM75 : SubtargetFeature<"sm_75", "SmVersion", "75",
5757
"Target SM 7.5">;
58+
def SM80 : SubtargetFeature<"sm_80", "SmVersion", "80",
59+
"Target SM 8.0">;
5860

5961
// PTX Versions
6062
def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
@@ -77,6 +79,10 @@ def PTX63 : SubtargetFeature<"ptx63", "PTXVersion", "63",
7779
"Use PTX version 6.3">;
7880
def PTX64 : SubtargetFeature<"ptx64", "PTXVersion", "64",
7981
"Use PTX version 6.4">;
82+
def PTX65 : SubtargetFeature<"ptx65", "PTXVersion", "65",
83+
"Use PTX version 6.5">;
84+
def PTX70 : SubtargetFeature<"ptx70", "PTXVersion", "70",
85+
"Use PTX version 7.0">;
8086

8187
//===----------------------------------------------------------------------===//
8288
// NVPTX supported processors.
@@ -100,6 +106,7 @@ def : Proc<"sm_62", [SM62, PTX50]>;
100106
def : Proc<"sm_70", [SM70, PTX60]>;
101107
def : Proc<"sm_72", [SM72, PTX61]>;
102108
def : Proc<"sm_75", [SM75, PTX63]>;
109+
def : Proc<"sm_80", [SM80, PTX70]>;
103110

104111
def NVPTXInstrInfo : InstrInfo {
105112
}

0 commit comments

Comments
 (0)