diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h index bc5a5b5f14697..abd668e27f18f 100644 --- a/flang-rt/include/flang-rt/runtime/descriptor.h +++ b/flang-rt/include/flang-rt/runtime/descriptor.h @@ -181,6 +181,9 @@ class Descriptor { const SubscriptValue *extent = nullptr, ISO::CFI_attribute_t attribute = CFI_attribute_other); + RT_API_ATTRS void UncheckedScalarEstablish( + const typeInfo::DerivedType &, void *); + // To create a descriptor for a derived type the caller // must provide non-null dt argument. // The addendum argument is only used for testing purposes, @@ -433,7 +436,9 @@ class Descriptor { bool stridesAreContiguous{true}; for (int j{0}; j < leadingDimensions; ++j) { const Dimension &dim{GetDimension(j)}; - stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1; + if (bytes != dim.ByteStride() && dim.Extent() != 1) { + stridesAreContiguous = false; + } bytes *= dim.Extent(); } // One and zero element arrays are contiguous even if the descriptor @@ -448,28 +453,44 @@ class Descriptor { // The result, if any, is a fixed stride value that can be used to // address all elements. It generalizes contiguity by also allowing - // the case of an array with extent 1 on all but one dimension. + // the case of an array with extent 1 on all dimensions but one. + // Returns 0 for an empty array, a byte stride if one is well-defined + // for the array, or nullopt otherwise. 
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const { - auto rank{static_cast<std::size_t>(raw_.rank)}; - common::optional<SubscriptValue> stride; - for (std::size_t j{0}; j < rank; ++j) { - const Dimension &dim{GetDimension(j)}; - auto extent{dim.Extent()}; - if (extent == 0) { - break; // empty array - } else if (extent == 1) { // ok - } else if (stride) { - // Extent > 1 on multiple dimensions - if (IsContiguous()) { - return ElementBytes(); + int rank{raw_.rank}; + auto elementBytes{static_cast<SubscriptValue>(ElementBytes())}; + if (rank == 0) { + return elementBytes; + } else if (rank == 1) { + const Dimension &dim{GetDimension(0)}; + return dim.Extent() == 0 ? 0 : dim.ByteStride(); + } else { + common::optional<SubscriptValue> stride; + auto bytes{elementBytes}; + for (int j{0}; j < rank; ++j) { + const Dimension &dim{GetDimension(j)}; + auto extent{dim.Extent()}; + if (extent == 0) { + return 0; // empty array + } else if (extent == 1) { // ok } else { - return common::nullopt; + if (stride) { // Extent > 1 on multiple dimensions + if (*stride != elementBytes || bytes != dim.ByteStride()) { // discontiguity; the first non-unit dimension must itself be contiguous (rejects e.g. a reversed first dimension) + while (++j < rank) { + if (GetDimension(j).Extent() == 0) { + return 0; // empty array + } + } + return common::nullopt; // nonempty, discontiguous + } + } else { + stride = dim.ByteStride(); + } + bytes *= extent; } - } else { - stride = dim.ByteStride(); } + return stride.value_or(elementBytes /*for singleton*/); } - return stride.value_or(0); // 0 for scalars and empty arrays } // Establishes a pointer to a section or element.
diff --git a/flang-rt/lib/runtime/derived.cpp b/flang-rt/lib/runtime/derived.cpp index 4ed0baaa3d108..2dddf079f91db 100644 --- a/flang-rt/lib/runtime/derived.cpp +++ b/flang-rt/lib/runtime/derived.cpp @@ -360,6 +360,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) { } else if (component_->genre() == typeInfo::Component::Genre::Data && component_->derivedType() && !component_->derivedType()->noFinalizationNeeded()) { + // todo: calculate and use fixedStride_ here as in DestroyTicket to + // avoid subscripts and repeated descriptor establishment. SubscriptValue extents[maxRank]; GetComponentExtents(extents, *component_, instance_); Descriptor &compDesc{componentDescriptor_.descriptor()}; @@ -452,6 +454,24 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) { } else if (component_->genre() == typeInfo::Component::Genre::Data) { if (!componentDerived || componentDerived->noDestructionNeeded()) { SkipToNextComponent(); + } else if (fixedStride_) { + // faster path, no need for subscripts, can reuse descriptor + char *p{instance_.OffsetElement( + elementAt_ * *fixedStride_ + component_->offset())}; + Descriptor &compDesc{componentDescriptor_.descriptor()}; + const typeInfo::DerivedType &compType{*componentDerived}; + compDesc.UncheckedScalarEstablish(compType, p); + for (std::size_t j{elementAt_}; j < elements_; + ++j, p += *fixedStride_) { + compDesc.set_base_addr(p); + ++elementAt_; + if (int status{workQueue.BeginDestroy( + compDesc, compType, /*finalize=*/false)}; + status != StatOk) { + return status; + } + } + SkipToNextComponent(); } else { SubscriptValue extents[maxRank]; GetComponentExtents(extents, *component_, instance_); @@ -461,8 +481,8 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) { instance_.ElementComponent(subscripts_, component_->offset()), component_->rank(), extents); Advance(); - if (int status{workQueue.BeginDestroy( - compDesc, *componentDerived, /*finalize=*/false)}; + if (int status{ + 
workQueue.BeginDestroy(compDesc, compType, /*finalize=*/false)}; status != StatOk) { return status; } diff --git a/flang-rt/lib/runtime/descriptor.cpp b/flang-rt/lib/runtime/descriptor.cpp index 021440cbdd0f6..fde4baa6a317c 100644 --- a/flang-rt/lib/runtime/descriptor.cpp +++ b/flang-rt/lib/runtime/descriptor.cpp @@ -100,6 +100,15 @@ RT_API_ATTRS void Descriptor::Establish(const typeInfo::DerivedType &dt, new (Addendum()) DescriptorAddendum{&dt}; } +RT_API_ATTRS void Descriptor::UncheckedScalarEstablish( + const typeInfo::DerivedType &dt, void *p) { + auto elementBytes{static_cast(dt.sizeInBytes())}; + ISO::EstablishDescriptor( + &raw_, p, CFI_attribute_other, CFI_type_struct, elementBytes, 0, nullptr); + SetHasAddendum(); + new (Addendum()) DescriptorAddendum{&dt}; +} + RT_API_ATTRS OwningPtr Descriptor::Create(TypeCode t, std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent, ISO::CFI_attribute_t attribute, bool addendum, diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt index cf1e15ddfa3e7..e51bc24415773 100644 --- a/flang-rt/unittests/Runtime/CMakeLists.txt +++ b/flang-rt/unittests/Runtime/CMakeLists.txt @@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests Complex.cpp CrashHandlerFixture.cpp Derived.cpp + Descriptor.cpp ExternalIOTest.cpp Format.cpp InputExtensions.cpp diff --git a/flang-rt/unittests/Runtime/Descriptor.cpp b/flang-rt/unittests/Runtime/Descriptor.cpp new file mode 100644 index 0000000000000..3a4a7670fc62e --- /dev/null +++ b/flang-rt/unittests/Runtime/Descriptor.cpp @@ -0,0 +1,160 @@ +//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang-rt/runtime/descriptor.h" +#include "tools.h" +#include "gtest/gtest.h" + +using namespace Fortran::runtime; + +TEST(Descriptor, FixedStride) { + StaticDescriptor<4> staticDesc[2]; + Descriptor &descriptor{staticDesc[0].descriptor()}; + using Type = std::int32_t; + Type data[8][8][8]; + constexpr int four{static_cast<int>(sizeof data[0][0][0])}; + TypeCode integer{TypeCategory::Integer, four}; + // Scalar + descriptor.Establish(integer, four, data, 0); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), four); + // Empty vector + SubscriptValue extent[3]{0, 0, 0}; + descriptor.Establish(integer, four, data, 1, extent); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); + // Contiguous vector (0:7:1) + extent[0] = 8; + descriptor.Establish(integer, four, data, 1, extent); + ASSERT_EQ(descriptor.rank(), 1); + ASSERT_EQ(descriptor.Elements(), 8); + ASSERT_EQ(descriptor.ElementBytes(), four); + ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0); + ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four); + ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), four); + // Contiguous reverse vector (7:0:-1) + descriptor.GetDimension(0).SetByteStride(-four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four); + // Discontiguous vector (0:6:2) + descriptor.GetDimension(0).SetExtent(4); + descriptor.GetDimension(0).SetByteStride(2 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four); + // Empty matrix + extent[0] = 0; + descriptor.Establish(integer, four, data, 2, extent); + EXPECT_TRUE(descriptor.IsContiguous()); +
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); + // Contiguous matrix (0:7, 0:7) + extent[0] = extent[1] = 8; + descriptor.Establish(integer, four, data, 2, extent); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), four); + // Contiguous row (0:7, 0) + descriptor.GetDimension(1).SetExtent(1); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), four); + // Contiguous column (0, 0:6) + descriptor.GetDimension(0).SetExtent(1); + descriptor.GetDimension(1).SetExtent(7); + descriptor.GetDimension(1).SetByteStride(8 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four); + // Contiguous reverse row (7:0:-1, 0) + descriptor.GetDimension(0).SetExtent(8); + descriptor.GetDimension(0).SetByteStride(-four); + descriptor.GetDimension(1).SetExtent(1); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four); + // Contiguous reverse column (0, 6:0:-1) + descriptor.GetDimension(0).SetExtent(1); + descriptor.GetDimension(0).SetByteStride(four); + descriptor.GetDimension(1).SetExtent(7); + descriptor.GetDimension(1).SetByteStride(8 * -four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four); + // Discontiguous row (0:6:2, 0) + descriptor.GetDimension(0).SetExtent(4); + descriptor.GetDimension(0).SetByteStride(2 * four); + descriptor.GetDimension(1).SetExtent(1); + descriptor.GetDimension(1).SetByteStride(four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four); + // Discontiguous column (0, 0:6:2) + descriptor.GetDimension(0).SetExtent(1); + descriptor.GetDimension(0).SetByteStride(four); + descriptor.GetDimension(1).SetExtent(4); + descriptor.GetDimension(1).SetByteStride(8 * 2 * four); + EXPECT_FALSE(descriptor.IsContiguous()); +
EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four); + // Discontiguous reverse row (7:1:-2, 0) + descriptor.GetDimension(0).SetExtent(4); + descriptor.GetDimension(0).SetByteStride(-2 * four); + descriptor.GetDimension(1).SetExtent(1); + descriptor.GetDimension(1).SetByteStride(four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four); + // Discontiguous reverse column (0, 7:1:-2) + descriptor.GetDimension(0).SetExtent(1); + descriptor.GetDimension(0).SetByteStride(four); + descriptor.GetDimension(1).SetExtent(4); + descriptor.GetDimension(1).SetByteStride(8 * -2 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four); + // Discontiguous rows (0:6:2, 0:1) + descriptor.GetDimension(0).SetExtent(4); + descriptor.GetDimension(0).SetByteStride(2 * four); + descriptor.GetDimension(1).SetExtent(2); + descriptor.GetDimension(1).SetByteStride(8 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_FALSE(descriptor.FixedStride().has_value()); + // Discontiguous columns (0:1, 0:6:2) + descriptor.GetDimension(0).SetExtent(2); + descriptor.GetDimension(0).SetByteStride(four); + descriptor.GetDimension(1).SetExtent(4); + descriptor.GetDimension(1).SetByteStride(8 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_FALSE(descriptor.FixedStride().has_value()); + // Empty 3-D array + extent[0] = extent[1] = extent[2] = 0; + ; + descriptor.Establish(integer, four, data, 3, extent); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); + // Contiguous 3-D array (0:7, 0:7, 0:7) + extent[0] = extent[1] = extent[2] = 8; + descriptor.Establish(integer, four, data, 3, extent); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), four); + // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2) + descriptor.GetDimension(1).SetExtent(4); +
descriptor.GetDimension(1).SetByteStride(8 * 2 * four); + descriptor.GetDimension(2).SetExtent(4); + descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four); + EXPECT_FALSE(descriptor.IsContiguous()); + EXPECT_FALSE(descriptor.FixedStride().has_value()); + // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2) + descriptor.GetDimension(0).SetExtent(0); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); + // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2) + descriptor.GetDimension(0).SetExtent(4); + descriptor.GetDimension(0).SetByteStride(2 * four); + descriptor.GetDimension(1).SetExtent(0); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); + // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1) + descriptor.GetDimension(1).SetExtent(4); + descriptor.GetDimension(1).SetByteStride(8 * 2 * four); + descriptor.GetDimension(2).SetExtent(0); + EXPECT_TRUE(descriptor.IsContiguous()); + EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0); +}