Skip to content

[flang][runtime] Optimize Descriptor::FixedStride() #151755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions flang-rt/include/flang-rt/runtime/descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ class Descriptor {
const SubscriptValue *extent = nullptr,
ISO::CFI_attribute_t attribute = CFI_attribute_other);

// Establishes this descriptor as a scalar (rank 0) of the given derived
// type whose base address is the given pointer; the element size is taken
// from the derived type and an addendum referring to it is attached.
// NOTE(review): "Unchecked" presumably means that the arguments are not
// validated -- confirm before relying on it.
RT_API_ATTRS void UncheckedScalarEstablish(
const typeInfo::DerivedType &, void *);

// To create a descriptor for a derived type the caller
// must provide non-null dt argument.
// The addendum argument is only used for testing purposes,
Expand Down Expand Up @@ -433,7 +436,9 @@ class Descriptor {
bool stridesAreContiguous{true};
for (int j{0}; j < leadingDimensions; ++j) {
const Dimension &dim{GetDimension(j)};
stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
if (bytes != dim.ByteStride() && dim.Extent() != 1) {
stridesAreContiguous = false;
}
bytes *= dim.Extent();
}
// One and zero element arrays are contiguous even if the descriptor
Expand All @@ -448,28 +453,44 @@ class Descriptor {

// The result, if any, is a fixed stride value that can be used to
// address all elements. It generalizes contiguity by also allowing
// the case of an array with extent 1 on all but one dimension.
// the case of an array with extent 1 on all dimensions but one.
// Returns 0 for an empty array, a byte stride if one is well-defined
// for the array, or nullopt otherwise.
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
auto rank{static_cast<std::size_t>(raw_.rank)};
common::optional<SubscriptValue> stride;
for (std::size_t j{0}; j < rank; ++j) {
const Dimension &dim{GetDimension(j)};
auto extent{dim.Extent()};
if (extent == 0) {
break; // empty array
} else if (extent == 1) { // ok
} else if (stride) {
// Extent > 1 on multiple dimensions
if (IsContiguous()) {
return ElementBytes();
int rank{raw_.rank};
auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
if (rank == 0) {
return elementBytes;
} else if (rank == 1) {
const Dimension &dim{GetDimension(0)};
return dim.Extent() == 0 ? 0 : dim.ByteStride();
} else {
common::optional<SubscriptValue> stride;
auto bytes{elementBytes};
for (int j{0}; j < rank; ++j) {
const Dimension &dim{GetDimension(j)};
auto extent{dim.Extent()};
if (extent == 0) {
return 0; // empty array
} else if (extent == 1) { // ok
} else {
return common::nullopt;
if (stride) { // Extent > 1 on multiple dimensions
if (bytes != dim.ByteStride()) { // discontiguity
while (++j < rank) {
if (GetDimension(j).Extent() == 0) {
return 0; // empty array
}
}
return common::nullopt; // nonempty, discontiguous
}
} else {
stride = dim.ByteStride();
}
bytes *= extent;
}
} else {
stride = dim.ByteStride();
}
return stride.value_or(elementBytes /*for singleton*/);
}
return stride.value_or(0); // 0 for scalars and empty arrays
}

// Establishes a pointer to a section or element.
Expand Down
24 changes: 22 additions & 2 deletions flang-rt/lib/runtime/derived.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
} else if (component_->genre() == typeInfo::Component::Genre::Data &&
component_->derivedType() &&
!component_->derivedType()->noFinalizationNeeded()) {
// todo: calculate and use fixedStride_ here as in DestroyTicket to
// avoid subscripts and repeated descriptor establishment.
SubscriptValue extents[maxRank];
GetComponentExtents(extents, *component_, instance_);
Descriptor &compDesc{componentDescriptor_.descriptor()};
Expand Down Expand Up @@ -452,6 +454,24 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
} else if (component_->genre() == typeInfo::Component::Genre::Data) {
if (!componentDerived || componentDerived->noDestructionNeeded()) {
SkipToNextComponent();
} else if (fixedStride_) {
// faster path, no need for subscripts, can reuse descriptor
char *p{instance_.OffsetElement<char>(
elementAt_ * *fixedStride_ + component_->offset())};
Descriptor &compDesc{componentDescriptor_.descriptor()};
const typeInfo::DerivedType &compType{*componentDerived};
compDesc.UncheckedScalarEstablish(compType, p);
for (std::size_t j{elementAt_}; j < elements_;
++j, p += *fixedStride_) {
compDesc.set_base_addr(p);
++elementAt_;
if (int status{workQueue.BeginDestroy(
compDesc, compType, /*finalize=*/false)};
status != StatOk) {
return status;
}
}
SkipToNextComponent();
} else {
SubscriptValue extents[maxRank];
GetComponentExtents(extents, *component_, instance_);
Expand All @@ -461,8 +481,8 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
instance_.ElementComponent<char>(subscripts_, component_->offset()),
component_->rank(), extents);
Advance();
if (int status{workQueue.BeginDestroy(
compDesc, *componentDerived, /*finalize=*/false)};
if (int status{
workQueue.BeginDestroy(compDesc, compType, /*finalize=*/false)};
status != StatOk) {
return status;
}
Expand Down
9 changes: 9 additions & 0 deletions flang-rt/lib/runtime/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ RT_API_ATTRS void Descriptor::Establish(const typeInfo::DerivedType &dt,
new (Addendum()) DescriptorAddendum{&dt};
}

// Sets this descriptor up as a CFI_type_struct object with attribute
// CFI_attribute_other, base address p, and an element size of
// dt.sizeInBytes(); no extents are supplied.  The descriptor is then
// flagged as having an addendum, and a DescriptorAddendum referring to dt
// is constructed in place in the addendum storage.
RT_API_ATTRS void Descriptor::UncheckedScalarEstablish(
    const typeInfo::DerivedType &dt, void *p) {
  ISO::EstablishDescriptor(&raw_, p, CFI_attribute_other, CFI_type_struct,
      static_cast<std::size_t>(dt.sizeInBytes()), /*rank=*/0, nullptr);
  SetHasAddendum();
  new (Addendum()) DescriptorAddendum{&dt};
}

RT_API_ATTRS OwningPtr<Descriptor> Descriptor::Create(TypeCode t,
std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent,
ISO::CFI_attribute_t attribute, bool addendum,
Expand Down
1 change: 1 addition & 0 deletions flang-rt/unittests/Runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
Complex.cpp
CrashHandlerFixture.cpp
Derived.cpp
Descriptor.cpp
ExternalIOTest.cpp
Format.cpp
InputExtensions.cpp
Expand Down
160 changes: 160 additions & 0 deletions flang-rt/unittests/Runtime/Descriptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang-rt/runtime/descriptor.h"
#include "tools.h"
#include "gtest/gtest.h"

using namespace Fortran::runtime;

// Exercises Descriptor::FixedStride() together with Descriptor::IsContiguous()
// on a single descriptor that is re-established or patched in place for each
// scenario: a scalar, vectors, matrices, and 3-D arrays that are empty,
// contiguous, reversed, or strided.
//
// The scenarios are cumulative: unless a dimension is explicitly modified or
// the descriptor is re-established, its extent and byte stride carry over
// from the preceding scenario.
//
// FixedStride() returns an optional; value_or(-666) substitutes a sentinel
// that matches none of the expected strides, so that an unexpectedly absent
// result shows up as a plain value mismatch.
TEST(Descriptor, FixedStride) {
  StaticDescriptor<4> staticDesc;
  Descriptor &descriptor{staticDesc.descriptor()};
  using Type = std::int32_t;
  // Backing storage only; the test inspects descriptor metadata and never
  // reads or writes the elements themselves.
  Type data[8][8][8];
  constexpr int four{static_cast<int>(sizeof data[0][0][0])};
  TypeCode integer{TypeCategory::Integer, four};
  // Scalar
  descriptor.Establish(integer, four, data, 0);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
  // Empty vector
  SubscriptValue extent[3]{0, 0, 0};
  descriptor.Establish(integer, four, data, 1, extent);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
  // Contiguous vector (0:7:1)
  extent[0] = 8;
  descriptor.Establish(integer, four, data, 1, extent);
  // Sanity-check what Establish() produced before building on it.
  ASSERT_EQ(descriptor.rank(), 1);
  ASSERT_EQ(descriptor.Elements(), 8);
  ASSERT_EQ(descriptor.ElementBytes(), four);
  ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
  ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
  ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
  // Contiguous reverse vector (7:0:-1)
  descriptor.GetDimension(0).SetByteStride(-four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
  // Discontiguous vector (0:6:2)
  descriptor.GetDimension(0).SetExtent(4);
  descriptor.GetDimension(0).SetByteStride(2 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
  // Empty matrix
  extent[0] = 0;
  descriptor.Establish(integer, four, data, 2, extent);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
  // Contiguous matrix (0:7, 0:7)
  extent[0] = extent[1] = 8;
  descriptor.Establish(integer, four, data, 2, extent);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
  // Contiguous row (0:7, 0)
  descriptor.GetDimension(1).SetExtent(1);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
  // Contiguous column (0, 0:7)
  descriptor.GetDimension(0).SetExtent(1);
  descriptor.GetDimension(1).SetExtent(8);
  descriptor.GetDimension(1).SetByteStride(8 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
  // Contiguous reverse row (7:0:-1, 0)
  descriptor.GetDimension(0).SetExtent(8);
  descriptor.GetDimension(0).SetByteStride(-four);
  descriptor.GetDimension(1).SetExtent(1);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
  // Contiguous reverse column (0, 7:0:-1)
  descriptor.GetDimension(0).SetExtent(1);
  descriptor.GetDimension(0).SetByteStride(four);
  descriptor.GetDimension(1).SetExtent(8);
  descriptor.GetDimension(1).SetByteStride(8 * -four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
  // Discontiguous row (0:6:2, 0)
  descriptor.GetDimension(0).SetExtent(4);
  descriptor.GetDimension(0).SetByteStride(2 * four);
  descriptor.GetDimension(1).SetExtent(1);
  descriptor.GetDimension(1).SetByteStride(four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
  // Discontiguous column (0, 0:6:2)
  descriptor.GetDimension(0).SetExtent(1);
  descriptor.GetDimension(0).SetByteStride(four);
  descriptor.GetDimension(1).SetExtent(4);
  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
  // Discontiguous reverse row (7:1:-2, 0)
  descriptor.GetDimension(0).SetExtent(4);
  descriptor.GetDimension(0).SetByteStride(-2 * four);
  descriptor.GetDimension(1).SetExtent(1);
  descriptor.GetDimension(1).SetByteStride(four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
  // Discontiguous reverse column (0, 7:1:-2)
  descriptor.GetDimension(0).SetExtent(1);
  descriptor.GetDimension(0).SetByteStride(four);
  descriptor.GetDimension(1).SetExtent(4);
  descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
  // Discontiguous rows (0:6:2, 0:1)
  descriptor.GetDimension(0).SetExtent(4);
  descriptor.GetDimension(0).SetByteStride(2 * four);
  descriptor.GetDimension(1).SetExtent(2);
  descriptor.GetDimension(1).SetByteStride(8 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_FALSE(descriptor.FixedStride().has_value());
  // Discontiguous columns (0:1, 0:6:2)
  descriptor.GetDimension(0).SetExtent(2);
  descriptor.GetDimension(0).SetByteStride(four);
  descriptor.GetDimension(1).SetExtent(4);
  descriptor.GetDimension(1).SetByteStride(8 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_FALSE(descriptor.FixedStride().has_value());
  // Empty 3-D array
  extent[0] = extent[1] = extent[2] = 0;
  descriptor.Establish(integer, four, data, 3, extent);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
  // Contiguous 3-D array (0:7, 0:7, 0:7)
  extent[0] = extent[1] = extent[2] = 8;
  descriptor.Establish(integer, four, data, 3, extent);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
  // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
  descriptor.GetDimension(1).SetExtent(4);
  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
  descriptor.GetDimension(2).SetExtent(4);
  descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
  EXPECT_FALSE(descriptor.IsContiguous());
  EXPECT_FALSE(descriptor.FixedStride().has_value());
  // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
  descriptor.GetDimension(0).SetExtent(0);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
  // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
  descriptor.GetDimension(0).SetExtent(4);
  descriptor.GetDimension(0).SetByteStride(2 * four);
  descriptor.GetDimension(1).SetExtent(0);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
  // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
  // Dimension 1 gets its strided shape back (extent 4, stride of two rows)
  // so that only the zero extent of the last dimension makes the array
  // empty.
  descriptor.GetDimension(1).SetExtent(4);
  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
  descriptor.GetDimension(2).SetExtent(0);
  EXPECT_TRUE(descriptor.IsContiguous());
  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
}
Loading