Skip to content

Commit edfc18a

Browse files
committed
[flang][runtime] Optimize Descriptor::FixedStride()
Put the common cases (scalars and rank-1 arrays) on fast paths, and don't depend on IsContiguous() in the general-case path. Add a unit test, too.
1 parent 35cabd6 commit edfc18a

File tree

5 files changed

+231
-20
lines changed

5 files changed

+231
-20
lines changed

flang-rt/include/flang-rt/runtime/descriptor.h

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ class Descriptor {
181181
const SubscriptValue *extent = nullptr,
182182
ISO::CFI_attribute_t attribute = CFI_attribute_other);
183183

184+
RT_API_ATTRS void UncheckedScalarEstablish(
185+
const typeInfo::DerivedType &, void *);
186+
184187
// To create a descriptor for a derived type the caller
185188
// must provide non-null dt argument.
186189
// The addendum argument is only used for testing purposes,
@@ -433,7 +436,9 @@ class Descriptor {
433436
bool stridesAreContiguous{true};
434437
for (int j{0}; j < leadingDimensions; ++j) {
435438
const Dimension &dim{GetDimension(j)};
436-
stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
439+
if (bytes != dim.ByteStride() && dim.Extent() != 1) {
440+
stridesAreContiguous = false;
441+
}
437442
bytes *= dim.Extent();
438443
}
439444
// One and zero element arrays are contiguous even if the descriptor
@@ -448,28 +453,44 @@ class Descriptor {
448453

449454
// The result, if any, is a fixed stride value that can be used to
450455
// address all elements. It generalizes contiguity by also allowing
451-
// the case of an array with extent 1 on all but one dimension.
456+
// the case of an array with extent 1 on all dimensions but one.
457+
// Returns 0 for an empty array, a byte stride if one is well-defined
458+
// for the array, or nullopt otherwise.
452459
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
453-
auto rank{static_cast<std::size_t>(raw_.rank)};
454-
common::optional<SubscriptValue> stride;
455-
for (std::size_t j{0}; j < rank; ++j) {
456-
const Dimension &dim{GetDimension(j)};
457-
auto extent{dim.Extent()};
458-
if (extent == 0) {
459-
break; // empty array
460-
} else if (extent == 1) { // ok
461-
} else if (stride) {
462-
// Extent > 1 on multiple dimensions
463-
if (IsContiguous()) {
464-
return ElementBytes();
460+
int rank{raw_.rank};
461+
auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
462+
if (rank == 0) {
463+
return elementBytes;
464+
} else if (rank == 1) {
465+
const Dimension &dim{GetDimension(0)};
466+
return dim.Extent() == 0 ? 0 : dim.ByteStride();
467+
} else {
468+
common::optional<SubscriptValue> stride;
469+
auto bytes{elementBytes};
470+
for (int j{0}; j < rank; ++j) {
471+
const Dimension &dim{GetDimension(j)};
472+
auto extent{dim.Extent()};
473+
if (extent == 0) {
474+
return 0; // empty array
475+
} else if (extent == 1) { // ok
465476
} else {
466-
return common::nullopt;
477+
if (stride) { // Extent > 1 on multiple dimensions
478+
if (bytes != dim.ByteStride()) { // discontiguity
479+
while (++j < rank) {
480+
if (GetDimension(j).Extent() == 0) {
481+
return 0; // empty array
482+
}
483+
}
484+
return common::nullopt; // nonempty, discontiguous
485+
}
486+
} else {
487+
stride = dim.ByteStride();
488+
}
489+
bytes *= extent;
467490
}
468-
} else {
469-
stride = dim.ByteStride();
470491
}
492+
return stride.value_or(elementBytes /*for singleton*/);
471493
}
472-
return stride.value_or(0); // 0 for scalars and empty arrays
473494
}
474495

475496
// Establishes a pointer to a section or element.

flang-rt/lib/runtime/derived.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
360360
} else if (component_->genre() == typeInfo::Component::Genre::Data &&
361361
component_->derivedType() &&
362362
!component_->derivedType()->noFinalizationNeeded()) {
363+
// todo: calculate and use fixedStride_ here as in DestroyTicket to
364+
// avoid subscripts and repeated descriptor establishment.
363365
SubscriptValue extents[maxRank];
364366
GetComponentExtents(extents, *component_, instance_);
365367
Descriptor &compDesc{componentDescriptor_.descriptor()};
@@ -452,6 +454,24 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
452454
} else if (component_->genre() == typeInfo::Component::Genre::Data) {
453455
if (!componentDerived || componentDerived->noDestructionNeeded()) {
454456
SkipToNextComponent();
457+
} else if (fixedStride_) {
458+
// faster path, no need for subscripts, can reuse descriptor
459+
char *p{instance_.OffsetElement<char>(
460+
elementAt_ * *fixedStride_ + component_->offset())};
461+
Descriptor &compDesc{componentDescriptor_.descriptor()};
462+
const typeInfo::DerivedType &compType{*componentDerived};
463+
compDesc.UncheckedScalarEstablish(compType, p);
464+
for (std::size_t j{elementAt_}; j < elements_;
465+
++j, p += *fixedStride_) {
466+
compDesc.set_base_addr(p);
467+
++elementAt_;
468+
if (int status{workQueue.BeginDestroy(
469+
compDesc, compType, /*finalize=*/false)};
470+
status != StatOk) {
471+
return status;
472+
}
473+
}
474+
SkipToNextComponent();
455475
} else {
456476
SubscriptValue extents[maxRank];
457477
GetComponentExtents(extents, *component_, instance_);
@@ -461,8 +481,8 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
461481
instance_.ElementComponent<char>(subscripts_, component_->offset()),
462482
component_->rank(), extents);
463483
Advance();
464-
if (int status{workQueue.BeginDestroy(
465-
compDesc, *componentDerived, /*finalize=*/false)};
484+
if (int status{
485+
workQueue.BeginDestroy(compDesc, compType, /*finalize=*/false)};
466486
status != StatOk) {
467487
return status;
468488
}

flang-rt/lib/runtime/descriptor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ RT_API_ATTRS void Descriptor::Establish(const typeInfo::DerivedType &dt,
100100
new (Addendum()) DescriptorAddendum{&dt};
101101
}
102102

103+
RT_API_ATTRS void Descriptor::UncheckedScalarEstablish(
104+
const typeInfo::DerivedType &dt, void *p) {
105+
auto elementBytes{static_cast<std::size_t>(dt.sizeInBytes())};
106+
ISO::EstablishDescriptor(
107+
&raw_, p, CFI_attribute_other, CFI_type_struct, elementBytes, 0, nullptr);
108+
SetHasAddendum();
109+
new (Addendum()) DescriptorAddendum{&dt};
110+
}
111+
103112
RT_API_ATTRS OwningPtr<Descriptor> Descriptor::Create(TypeCode t,
104113
std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent,
105114
ISO::CFI_attribute_t attribute, bool addendum,

flang-rt/unittests/Runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
1717
Complex.cpp
1818
CrashHandlerFixture.cpp
1919
Derived.cpp
20+
Descriptor.cpp
2021
ExternalIOTest.cpp
2122
Format.cpp
2223
InputExtensions.cpp
flang-rt/unittests/Runtime/Descriptor.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang-rt/runtime/descriptor.h"
10+
#include "tools.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace Fortran::runtime;
14+
15+
TEST(Descriptor, FixedStride) {
16+
StaticDescriptor<4> staticDesc[2];
17+
Descriptor &descriptor{staticDesc[0].descriptor()};
18+
using Type = std::int32_t;
19+
Type data[8][8][8];
20+
constexpr int four{static_cast<int>(sizeof data[0][0][0])};
21+
TypeCode integer{TypeCategory::Integer, four};
22+
// Scalar
23+
descriptor.Establish(integer, four, data, 0);
24+
EXPECT_TRUE(descriptor.IsContiguous());
25+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
26+
// Empty vector
27+
SubscriptValue extent[3]{0, 0, 0};
28+
descriptor.Establish(integer, four, data, 1, extent);
29+
EXPECT_TRUE(descriptor.IsContiguous());
30+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
31+
// Contiguous vector (0:7:1)
32+
extent[0] = 8;
33+
descriptor.Establish(integer, four, data, 1, extent);
34+
ASSERT_EQ(descriptor.rank(), 1);
35+
ASSERT_EQ(descriptor.Elements(), 8);
36+
ASSERT_EQ(descriptor.ElementBytes(), four);
37+
ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
38+
ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
39+
ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
40+
EXPECT_TRUE(descriptor.IsContiguous());
41+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
42+
// Contiguous reverse vector (7:0:-1)
43+
descriptor.GetDimension(0).SetByteStride(-four);
44+
EXPECT_FALSE(descriptor.IsContiguous());
45+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
46+
// Discontiguous vector (0:6:2)
47+
descriptor.GetDimension(0).SetExtent(4);
48+
descriptor.GetDimension(0).SetByteStride(2 * four);
49+
EXPECT_FALSE(descriptor.IsContiguous());
50+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
51+
// Empty matrix
52+
extent[0] = 0;
53+
descriptor.Establish(integer, four, data, 2, extent);
54+
EXPECT_TRUE(descriptor.IsContiguous());
55+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
56+
// Contiguous matrix (0:7, 0:7)
57+
extent[0] = extent[1] = 8;
58+
descriptor.Establish(integer, four, data, 2, extent);
59+
EXPECT_TRUE(descriptor.IsContiguous());
60+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
61+
// Contiguous row (0:7, 0)
62+
descriptor.GetDimension(1).SetExtent(1);
63+
EXPECT_TRUE(descriptor.IsContiguous());
64+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
65+
// Contiguous column (0, 0:6)
66+
descriptor.GetDimension(0).SetExtent(1);
67+
descriptor.GetDimension(1).SetExtent(7);
68+
descriptor.GetDimension(1).SetByteStride(8 * four);
69+
EXPECT_FALSE(descriptor.IsContiguous());
70+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
71+
// Contiguous reverse row (7:0:-1, 0)
72+
descriptor.GetDimension(0).SetExtent(8);
73+
descriptor.GetDimension(0).SetByteStride(-four);
74+
descriptor.GetDimension(1).SetExtent(1);
75+
EXPECT_FALSE(descriptor.IsContiguous());
76+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
77+
// Contiguous reverse column (0, 6:0:-1)
78+
descriptor.GetDimension(0).SetExtent(1);
79+
descriptor.GetDimension(0).SetByteStride(four);
80+
descriptor.GetDimension(1).SetExtent(7);
81+
descriptor.GetDimension(1).SetByteStride(8 * -four);
82+
EXPECT_FALSE(descriptor.IsContiguous());
83+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
84+
// Discontiguous row (0:6:2, 0)
85+
descriptor.GetDimension(0).SetExtent(4);
86+
descriptor.GetDimension(0).SetByteStride(2 * four);
87+
descriptor.GetDimension(1).SetExtent(1);
88+
descriptor.GetDimension(1).SetByteStride(four);
89+
EXPECT_FALSE(descriptor.IsContiguous());
90+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
91+
// Discontiguous column (0, 0:6:2)
92+
descriptor.GetDimension(0).SetExtent(1);
93+
descriptor.GetDimension(0).SetByteStride(four);
94+
descriptor.GetDimension(1).SetExtent(4);
95+
descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
96+
EXPECT_FALSE(descriptor.IsContiguous());
97+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
98+
// Discontiguous reverse row (7:1:-2, 0)
99+
descriptor.GetDimension(0).SetExtent(4);
100+
descriptor.GetDimension(0).SetByteStride(-2 * four);
101+
descriptor.GetDimension(1).SetExtent(1);
102+
descriptor.GetDimension(1).SetByteStride(four);
103+
EXPECT_FALSE(descriptor.IsContiguous());
104+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
105+
// Discontiguous reverse column (0, 7:1:-2)
106+
descriptor.GetDimension(0).SetExtent(1);
107+
descriptor.GetDimension(0).SetByteStride(four);
108+
descriptor.GetDimension(1).SetExtent(4);
109+
descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
110+
EXPECT_FALSE(descriptor.IsContiguous());
111+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
112+
// Discontiguous rows (0:6:2, 0:1)
113+
descriptor.GetDimension(0).SetExtent(4);
114+
descriptor.GetDimension(0).SetByteStride(2 * four);
115+
descriptor.GetDimension(1).SetExtent(2);
116+
descriptor.GetDimension(1).SetByteStride(8 * four);
117+
EXPECT_FALSE(descriptor.IsContiguous());
118+
EXPECT_FALSE(descriptor.FixedStride().has_value());
119+
// Discontiguous columns (0:1, 0:6:2)
120+
descriptor.GetDimension(0).SetExtent(2);
121+
descriptor.GetDimension(0).SetByteStride(four);
122+
descriptor.GetDimension(1).SetExtent(4);
123+
descriptor.GetDimension(1).SetByteStride(8 * four);
124+
EXPECT_FALSE(descriptor.IsContiguous());
125+
EXPECT_FALSE(descriptor.FixedStride().has_value());
126+
// Empty 3-D array
127+
extent[0] = extent[1] = extent[2] = 0;
128+
;
129+
descriptor.Establish(integer, four, data, 3, extent);
130+
EXPECT_TRUE(descriptor.IsContiguous());
131+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
132+
// Contiguous 3-D array (0:7, 0:7, 0:7)
133+
extent[0] = extent[1] = extent[2] = 8;
134+
descriptor.Establish(integer, four, data, 3, extent);
135+
EXPECT_TRUE(descriptor.IsContiguous());
136+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
137+
// Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
138+
descriptor.GetDimension(1).SetExtent(4);
139+
descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
140+
descriptor.GetDimension(2).SetExtent(4);
141+
descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
142+
EXPECT_FALSE(descriptor.IsContiguous());
143+
EXPECT_FALSE(descriptor.FixedStride().has_value());
144+
// Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
145+
descriptor.GetDimension(0).SetExtent(0);
146+
EXPECT_TRUE(descriptor.IsContiguous());
147+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
148+
// Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
149+
descriptor.GetDimension(0).SetExtent(4);
150+
descriptor.GetDimension(0).SetByteStride(2 * four);
151+
descriptor.GetDimension(1).SetExtent(0);
152+
EXPECT_TRUE(descriptor.IsContiguous());
153+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
154+
// Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
155+
descriptor.GetDimension(1).SetExtent(4);
156+
descriptor.GetDimension(1).SetExtent(8 * 2 * four);
157+
descriptor.GetDimension(2).SetExtent(0);
158+
EXPECT_TRUE(descriptor.IsContiguous());
159+
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
160+
}

0 commit comments

Comments
 (0)