Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 104 additions & 36 deletions tree/dataframe/src/RNTupleDS.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <TSystem.h>

#include <cassert>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
Expand All @@ -49,32 +50,27 @@
// clang-format on

namespace ROOT::Internal::RDF {
/// An artificial field that transforms an RNTuple column that contains the offset of collections into
/// collection sizes. It is used to provide the "number of" RDF columns for collections, e.g.
/// `R_rdf_sizeof_jets` for a collection named `jets`.
///
/// This field owns the collection offset field but instead of exposing the collection offsets it exposes
/// the collection sizes (offset(N+1) - offset(N)). For the time being, we offer this functionality only in RDataFrame.
/// TODO(jblomer): consider providing a general set of useful virtual fields as part of RNTuple.
class RRDFCardinalityField final : public ROOT::RFieldBase {
class RRDFCardinalityFieldBase : public ROOT::RFieldBase {
protected:
std::unique_ptr<ROOT::RFieldBase> CloneImpl(std::string_view newName) const final
{
return std::make_unique<RRDFCardinalityField>(newName);
}
void ConstructValue(void *where) const final { *static_cast<std::size_t *>(where) = 0; }

// We construct these fields and know that they match the page source
void ReconcileOnDiskField(const RNTupleDescriptor &) final {}

public:
RRDFCardinalityField(std::string_view name)
: ROOT::RFieldBase(name, "std::size_t", ROOT::ENTupleStructure::kPlain, false /* isSimple */)
RRDFCardinalityFieldBase(std::string_view name, std::string_view type)
: ROOT::RFieldBase(name, type, ROOT::ENTupleStructure::kPlain, false /* isSimple */)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, by using the type name, the field is exposed and readable by the user as the integer type and not just as RNTupleCardinality, correct?

{
}
RRDFCardinalityField(RRDFCardinalityField &&other) = default;
RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default;
~RRDFCardinalityField() override = default;

// Field is only used for reading
void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); }
void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final
{
GenerateColumnsImpl<ROOT::Internal::RColumnIndex>(desc);
}

public:
RRDFCardinalityFieldBase(RRDFCardinalityFieldBase &&other) = default;
RRDFCardinalityFieldBase &operator=(RRDFCardinalityFieldBase &&other) = default;
~RRDFCardinalityFieldBase() override = default;
Comment on lines +71 to +73
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For completeness

Suggested change
RRDFCardinalityFieldBase(RRDFCardinalityFieldBase &&other) = default;
RRDFCardinalityFieldBase &operator=(RRDFCardinalityFieldBase &&other) = default;
~RRDFCardinalityFieldBase() override = default;
RRDFCardinalityFieldBase(const RRDFCardinalityFieldBase &other) = delete;
RRDFCardinalityFieldBase &operator=(const RRDFCardinalityFieldBase &other) = delete;
RRDFCardinalityFieldBase(RRDFCardinalityFieldBase &&other) = default;
RRDFCardinalityFieldBase &operator=(RRDFCardinalityFieldBase &&other) = default;
~RRDFCardinalityFieldBase() override = default;


const RColumnRepresentations &GetColumnRepresentations() const final
{
Expand All @@ -85,23 +81,52 @@ class RRDFCardinalityField final : public ROOT::RFieldBase {
{});
return representations;
}
// Field is only used for reading
void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); }
void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final
};

/// An artificial field that transforms an RNTuple column that contains the offset of collections into
/// collection sizes. It is used to provide the "number of" RDF columns for collections, e.g.
/// `R_rdf_sizeof_jets` for a collection named `jets`.
///
/// This is similar to the RCardinalityField but it presents itself as an integer type.
/// The template argument T must be an integral type.
template <typename T>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we could already SFINAE-out the non-integral types

class RRDFCardinalityField final : public RRDFCardinalityFieldBase {
inline void CheckSize(ROOT::NTupleSize_t size) const
{
GenerateColumnsImpl<ROOT::Internal::RColumnIndex>(desc);
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, std::uint64_t>)
return;
if (size > std::numeric_limits<T>::max()) {
throw RException(R__FAIL(std::string("integer overflow in field ") + GetFieldName()));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
throw RException(R__FAIL(std::string("integer overflow in field ") + GetFieldName()));
throw RException(R__FAIL(std::string("integer overflow in field ") + GetFieldName() + ". Please read the column with a larger-sized integral type."));

}
}

size_t GetValueSize() const final { return sizeof(std::size_t); }
size_t GetAlignment() const final { return alignof(std::size_t); }
protected:
std::unique_ptr<ROOT::RFieldBase> CloneImpl(std::string_view newName) const final
{
return std::make_unique<RRDFCardinalityField>(newName);
}
void ConstructValue(void *where) const final { *static_cast<T *>(where) = 0; }

public:
RRDFCardinalityField(std::string_view name)
: RRDFCardinalityFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(T)))
{
}
RRDFCardinalityField(RRDFCardinalityField &&other) = default;
RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default;
~RRDFCardinalityField() override = default;
Comment on lines +115 to +117
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here

Suggested change
RRDFCardinalityField(RRDFCardinalityField &&other) = default;
RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default;
~RRDFCardinalityField() override = default;
RRDFCardinalityField(const RRDFCardinalityField &other) = delete;
RRDFCardinalityField &operator=(const RRDFCardinalityField &other) = delete;
RRDFCardinalityField(RRDFCardinalityField &&other) = default;
RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default;
~RRDFCardinalityField() override = default;


size_t GetValueSize() const final { return sizeof(T); }
size_t GetAlignment() const final { return alignof(T); }
Comment on lines +119 to +120
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
size_t GetValueSize() const final { return sizeof(T); }
size_t GetAlignment() const final { return alignof(T); }
std::size_t GetValueSize() const final { return sizeof(T); }
std::size_t GetAlignment() const final { return alignof(T); }


/// Get the number of elements of the collection identified by globalIndex
void ReadGlobalImpl(ROOT::NTupleSize_t globalIndex, void *to) final
{
RNTupleLocalIndex collectionStart;
ROOT::NTupleSize_t size;
fPrincipalColumn->GetCollectionInfo(globalIndex, &collectionStart, &size);
*static_cast<std::size_t *>(to) = size;
CheckSize(size);
*static_cast<T *>(to) = size;
}

/// Get the number of elements of the collection identified by clusterIndex
Expand All @@ -110,7 +135,8 @@ class RRDFCardinalityField final : public ROOT::RFieldBase {
RNTupleLocalIndex collectionStart;
ROOT::NTupleSize_t size;
fPrincipalColumn->GetCollectionInfo(localIndex, &collectionStart, &size);
*static_cast<std::size_t *>(to) = size;
CheckSize(size);
*static_cast<T *>(to) = size;
}
};

Expand Down Expand Up @@ -144,7 +170,8 @@ class RArraySizeField final : public ROOT::RFieldBase {

public:
RArraySizeField(std::string_view name, std::size_t arrayLength)
: ROOT::RFieldBase(name, "std::size_t", ROOT::ENTupleStructure::kPlain, false /* isSimple */),
: ROOT::RFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(std::size_t)),
ROOT::ENTupleStructure::kPlain, false /* isSimple */),
fArrayLength(arrayLength)
{
}
Expand Down Expand Up @@ -325,6 +352,18 @@ void ROOT::RDF::RNTupleDS::AddField(const ROOT::RNTupleDescriptor &desc, std::st
if (!fieldOrException)
return;
auto valueField = fieldOrException.Unwrap();
if (const auto cardinalityField = dynamic_cast<const ROOT::RCardinalityField *>(valueField.get())) {
// Cardinality fields in RDataFrame are presented as integers
if (cardinalityField->As32Bit()) {
valueField =
std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint32_t>>(fieldDesc.GetFieldName());
} else if (cardinalityField->As64Bit()) {
valueField =
std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint64_t>>(fieldDesc.GetFieldName());
} else {
R__ASSERT(false);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would a message like the following make sense?

Suggested change
R__ASSERT(false);
R__ASSERT(false && "cardinality field stored with an incompatible integer type");

}
}
valueField->SetOnDiskId(fieldId);
for (auto &f : *valueField) {
f.SetOnDiskId(desc.FindFieldId(f.GetFieldName(), f.GetParent()->GetOnDiskId()));
Expand All @@ -337,7 +376,7 @@ void ROOT::RDF::RNTupleDS::AddField(const ROOT::RNTupleDescriptor &desc, std::st
if (info.fNRepetitions > 0) {
cardinalityField = std::make_unique<ROOT::Internal::RDF::RArraySizeField>(name, info.fNRepetitions);
} else {
cardinalityField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField>(name);
cardinalityField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::size_t>>(name);
}
cardinalityField->SetOnDiskId(info.fFieldId);
}
Expand Down Expand Up @@ -475,7 +514,7 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view
// If the field corresponding to the provided name is not a cardinality column and the requested type is different
// from the proto field that was created when the data source was constructed, we first have to create an
// alternative proto field for the column reader. Otherwise, we can directly use the existing proto field.
if (fieldName.substr(0, 13) != "R_rdf_sizeof_" && requestedType != fColumnTypes[index]) {
if (requestedType != fColumnTypes[index]) {
auto &altProtoFields = fAlternativeProtoFields[index];

// If we can find the requested type in the registered alternative protofields, return the corresponding field
Expand All @@ -488,12 +527,41 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view
}

// Otherwise, create a new protofield and register it in the alternatives before returning
auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(std::string(fieldName), requestedType);
if (!newAltProtoFieldOrException) {
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
"\" for column \"" + std::string(fieldName) + "\"");
std::unique_ptr<RFieldBase> newAltProtoField;
const std::string strName = std::string(fieldName);
if (dynamic_cast<ROOT::Internal::RDF::RRDFCardinalityFieldBase *>(fProtoFields[index].get())) {
if (requestedType == "bool") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<bool>>(strName);
} else if (requestedType == "char") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<char>>(strName);
} else if (requestedType == "std::int8_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int8_t>>(strName);
} else if (requestedType == "std::uint8_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint8_t>>(strName);
} else if (requestedType == "std::int16_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int16_t>>(strName);
} else if (requestedType == "std::uint16_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint16_t>>(strName);
} else if (requestedType == "std::int32_t") {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure, when we get to this point we have already normalized the requested type, right? e.g. int will already be std::int32_t here.

newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int32_t>>(strName);
} else if (requestedType == "std::uint32_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint32_t>>(strName);
} else if (requestedType == "std::int64_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int64_t>>(strName);
} else if (requestedType == "std::uint64_t") {
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint64_t>>(strName);
} else {
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
"\" for column \"" + std::string(fieldName) + "\"");
}
} else {
auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(strName, requestedType);
if (!newAltProtoFieldOrException) {
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
"\" for column \"" + std::string(fieldName) + "\"");
}
newAltProtoField = newAltProtoFieldOrException.Unwrap();
}
auto newAltProtoField = newAltProtoFieldOrException.Unwrap();
newAltProtoField->SetOnDiskId(fProtoFields[index]->GetOnDiskId());
auto *newField = newAltProtoField.get();
altProtoFields.emplace_back(std::move(newAltProtoField));
Expand Down
5 changes: 1 addition & 4 deletions tree/dataframe/test/dataframe_snapshot_ntuple.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,7 @@ TEST(RDFSnapshotRNTuple, CardinalityColumns)
opts.fMode = "UPDATE";
opts.fOutputFormat = ROOT::RDF::ESnapshotOutputFormat::kRNTuple;
ROOT::RDataFrame df("ntuple", fileGuard.GetPath());

ROOT_EXPECT_WARNING(df.Snapshot("ntuple_snap", fileGuard.GetPath(), "", opts), "Snapshot",
"Column \"nElectrons\" is a read-only \"ROOT::RNTupleCardinality<std::uint32_t>\" column. It "
"will be snapshot as its inner type \"std::uint32_t\" instead.");
df.Snapshot("ntuple_snap", fileGuard.GetPath(), "", opts);

ROOT::RDataFrame sdf("ntuple_snap", fileGuard.GetPath());
EXPECT_EQ("std::uint32_t", sdf.GetColumnType("nElectrons"));
Expand Down
37 changes: 32 additions & 5 deletions tree/dataframe/test/datasource_ntuple.cxx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <ROOT/RDataFrame.hxx>
#include <ROOT/RFieldUtils.hxx>
#include <ROOT/RNTupleDS.hxx>
#include <ROOT/RVec.hxx>

Expand All @@ -14,6 +15,7 @@
#include "ClassWithArrays.h"

#include <limits>
#include <typeinfo>

#include <TFile.h>

Expand Down Expand Up @@ -66,8 +68,10 @@ class RNTupleDSTest : public ::testing::Test {
auto fldElectron = model->MakeField<Electron>("electron");
fldElectron->pt = 137.0;
auto fldVecElectron = model->MakeField<std::vector<Electron>>("VecElectron");
fldVecElectron->push_back(*fldElectron);
fldVecElectron->push_back(*fldElectron);
for (int i = 0; i < 128; ++i)
fldVecElectron->push_back(*fldElectron);
auto fldNElectron = std::make_unique<ROOT::RField<ROOT::RNTupleCardinality<std::uint64_t>>>("nElectron");
model->AddProjectedField(std::move(fldNElectron), [](const std::string &) { return "VecElectron"; });
{
auto ntuple = RNTupleWriter::Recreate(std::move(model), fNtplName, fFileName);
ntuple->Fill();
Expand All @@ -84,7 +88,7 @@ TEST_F(RNTupleDSTest, ColTypeNames)
RNTupleDS ds(fNtplName, fFileName);

auto colNames = ds.GetColumnNames();
ASSERT_EQ(15, colNames.size());
ASSERT_EQ(16, colNames.size());

EXPECT_TRUE(ds.HasColumn("pt"));
EXPECT_TRUE(ds.HasColumn("energy"));
Expand All @@ -96,12 +100,14 @@ TEST_F(RNTupleDSTest, ColTypeNames)
EXPECT_TRUE(ds.HasColumn("R_rdf_sizeof_VecElectron"));
EXPECT_TRUE(ds.HasColumn("VecElectron.pt"));
EXPECT_TRUE(ds.HasColumn("R_rdf_sizeof_VecElectron.pt"));
EXPECT_TRUE(ds.HasColumn("nElectron"));
EXPECT_FALSE(ds.HasColumn("Address"));

EXPECT_STREQ("std::string", ds.GetTypeName("tag").c_str());
EXPECT_STREQ("float", ds.GetTypeName("energy").c_str());
EXPECT_STREQ("std::size_t", ds.GetTypeName("R_rdf_sizeof_jets").c_str());
EXPECT_EQ(ROOT::Internal::GetRenormalizedTypeName(typeid(std::size_t)), ds.GetTypeName("R_rdf_sizeof_jets"));
EXPECT_STREQ("ROOT::VecOps::RVec<std::int32_t>", ds.GetTypeName("rvec").c_str());
EXPECT_STREQ("std::uint64_t", ds.GetTypeName("nElectron").c_str());

try {
ds.GetTypeName("Address");
Expand Down Expand Up @@ -142,6 +148,27 @@ TEST_F(RNTupleDSTest, CardinalityColumn)
EXPECT_EQ(3, *max_rvec2);
}

TEST_F(RNTupleDSTest, ProjectedCardinalityColumn)
{
auto df = ROOT::RDF::FromRNTuple(fNtplName, fFileName);

EXPECT_EQ(128u, *df.Filter("nElectron == 128").Max("nElectron"));

EXPECT_EQ(128u, *df.Filter([](std::uint64_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](std::int32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](std::uint32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](std::int16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](std::uint16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](std::uint8_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
EXPECT_EQ(128u, *df.Filter([](bool x) { return x; }, {"nElectron"}).Max("nElectron"));
try {
*df.Filter([](std::int8_t x) { return x == 0; }, {"nElectron"}).Count();
FAIL() << "integer overflow should fail";
} catch (const ROOT::RException &e) {
EXPECT_THAT(e.what(), ::testing::HasSubstr("integer overflow"));
}
}

static void ReadTest(const std::string &name, const std::string &fname)
{
auto df = ROOT::RDF::FromRNTuple(name, fname);
Expand Down Expand Up @@ -183,7 +210,7 @@ static void ReadTest(const std::string &name, const std::string &fname)
EXPECT_TRUE(All(rvec->at(0) == ROOT::RVecI{1, 2, 3}));
EXPECT_TRUE(All(vectorasrvec->at(0) == ROOT::RVecF{1.f, 2.f}));
EXPECT_FLOAT_EQ(137.0, sumElectronPt.GetValue());
EXPECT_FLOAT_EQ(2. * 137.0, sumVecElectronPt.GetValue());
EXPECT_FLOAT_EQ(128. * 137.0, sumVecElectronPt.GetValue());
}

static void ChainTest(const std::string &name, const std::string &fname)
Expand Down
Loading