diff --git a/tree/dataframe/src/RNTupleDS.cxx b/tree/dataframe/src/RNTupleDS.cxx index 261327e29794e..312baa948669d 100644 --- a/tree/dataframe/src/RNTupleDS.cxx +++ b/tree/dataframe/src/RNTupleDS.cxx @@ -27,11 +27,13 @@ #include #include +#include #include #include #include #include #include +#include #include // clang-format off @@ -49,32 +51,29 @@ // clang-format on namespace ROOT::Internal::RDF { -/// An artificial field that transforms an RNTuple column that contains the offset of collections into -/// collection sizes. It is used to provide the "number of" RDF columns for collections, e.g. -/// `R_rdf_sizeof_jets` for a collection named `jets`. -/// -/// This field owns the collection offset field but instead of exposing the collection offsets it exposes -/// the collection sizes (offset(N+1) - offset(N)). For the time being, we offer this functionality only in RDataFrame. -/// TODO(jblomer): consider providing a general set of useful virtual fields as part of RNTuple. -class RRDFCardinalityField final : public ROOT::RFieldBase { +class RRDFCardinalityFieldBase : public ROOT::RFieldBase { protected: - std::unique_ptr CloneImpl(std::string_view newName) const final - { - return std::make_unique(newName); - } - void ConstructValue(void *where) const final { *static_cast(where) = 0; } - // We construct these fields and know that they match the page source void ReconcileOnDiskField(const RNTupleDescriptor &) final {} -public: - RRDFCardinalityField(std::string_view name) - : ROOT::RFieldBase(name, "std::size_t", ROOT::ENTupleStructure::kPlain, false /* isSimple */) + RRDFCardinalityFieldBase(std::string_view name, std::string_view type) + : ROOT::RFieldBase(name, type, ROOT::ENTupleStructure::kPlain, false /* isSimple */) { } - RRDFCardinalityField(RRDFCardinalityField &&other) = default; - RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default; - ~RRDFCardinalityField() override = default; + + // Field is only used for reading + void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); } + void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final + { + GenerateColumnsImpl(desc); + } + +public: + RRDFCardinalityFieldBase(const RRDFCardinalityFieldBase &other) = delete; + RRDFCardinalityFieldBase &operator=(const RRDFCardinalityFieldBase &other) = delete; + RRDFCardinalityFieldBase(RRDFCardinalityFieldBase &&other) = default; + RRDFCardinalityFieldBase &operator=(RRDFCardinalityFieldBase &&other) = default; + ~RRDFCardinalityFieldBase() override = default; const RColumnRepresentations &GetColumnRepresentations() const final { @@ -85,15 +84,48 @@ class RRDFCardinalityField final : public ROOT::RFieldBase { {}); return representations; } - // Field is only used for reading - void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); } - void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final +}; + +/// An artificial field that transforms an RNTuple column that contains the offset of collections into +/// collection sizes. It is used to provide the "number of" RDF columns for collections, e.g. +/// `R_rdf_sizeof_jets` for a collection named `jets`. +/// +/// This is similar to the RCardinalityField but it presents itself as an integer type. +/// The template argument T must be an integral type. +template +class RRDFCardinalityField final : public RRDFCardinalityFieldBase { + static_assert(std::is_integral_v, "T must be an integral type"); + + inline void CheckSize(ROOT::NTupleSize_t size) const + { + if constexpr (std::is_same_v || std::is_same_v) + return; + if (size > std::numeric_limits::max()) { + throw RException(R__FAIL(std::string("integer overflow in field ") + GetFieldName() + + ". Please read the column with a larger-sized integral type.")); + } + } + +protected: + std::unique_ptr CloneImpl(std::string_view newName) const final + { + return std::make_unique(newName); + } + void ConstructValue(void *where) const final { *static_cast(where) = 0; } + +public: + RRDFCardinalityField(std::string_view name) + : RRDFCardinalityFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(T))) { - GenerateColumnsImpl(desc); } + RRDFCardinalityField(const RRDFCardinalityField &other) = delete; + RRDFCardinalityField &operator=(const RRDFCardinalityField &other) = delete; + RRDFCardinalityField(RRDFCardinalityField &&other) = default; + RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default; + ~RRDFCardinalityField() override = default; - size_t GetValueSize() const final { return sizeof(std::size_t); } - size_t GetAlignment() const final { return alignof(std::size_t); } + std::size_t GetValueSize() const final { return sizeof(T); } + std::size_t GetAlignment() const final { return alignof(T); } /// Get the number of elements of the collection identified by globalIndex void ReadGlobalImpl(ROOT::NTupleSize_t globalIndex, void *to) final @@ -101,7 +133,8 @@ class RRDFCardinalityField final : public ROOT::RFieldBase { RNTupleLocalIndex collectionStart; ROOT::NTupleSize_t size; fPrincipalColumn->GetCollectionInfo(globalIndex, &collectionStart, &size); - *static_cast(to) = size; + CheckSize(size); + *static_cast(to) = size; } /// Get the number of elements of the collection identified by clusterIndex @@ -110,7 +143,8 @@ class RRDFCardinalityField final : public ROOT::RFieldBase { RNTupleLocalIndex collectionStart; ROOT::NTupleSize_t size; fPrincipalColumn->GetCollectionInfo(localIndex, &collectionStart, &size); - *static_cast(to) = size; + CheckSize(size); + *static_cast(to) = size; } }; @@ -144,7 +178,8 @@ class RArraySizeField final : public ROOT::RFieldBase { public: RArraySizeField(std::string_view name, std::size_t arrayLength) - : ROOT::RFieldBase(name, "std::size_t", ROOT::ENTupleStructure::kPlain, false /* isSimple */), + : ROOT::RFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(std::size_t)), + ROOT::ENTupleStructure::kPlain, false /* isSimple */), fArrayLength(arrayLength) { } @@ -325,6 +360,18 @@ void ROOT::RDF::RNTupleDS::AddField(const ROOT::RNTupleDescriptor &desc, std::st if (!fieldOrException) return; auto valueField = fieldOrException.Unwrap(); + if (const auto cardinalityField = dynamic_cast(valueField.get())) { + // Cardinality fields in RDataFrame are presented as integers + if (cardinalityField->As32Bit()) { + valueField = + std::make_unique>(fieldDesc.GetFieldName()); + } else if (cardinalityField->As64Bit()) { + valueField = + std::make_unique>(fieldDesc.GetFieldName()); + } else { + R__ASSERT(false && "cardinality field stored with an unexpected integer type"); + } + } valueField->SetOnDiskId(fieldId); for (auto &f : *valueField) { f.SetOnDiskId(desc.FindFieldId(f.GetFieldName(), f.GetParent()->GetOnDiskId())); @@ -337,7 +384,7 @@ void ROOT::RDF::RNTupleDS::AddField(const ROOT::RNTupleDescriptor &desc, std::st if (info.fNRepetitions > 0) { cardinalityField = std::make_unique(name, info.fNRepetitions); } else { - cardinalityField = std::make_unique(name); + cardinalityField = std::make_unique>(name); } cardinalityField->SetOnDiskId(info.fFieldId); } @@ -475,7 +522,7 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view // If the field corresponding to the provided name is not a cardinality column and the requested type is different // from the proto field that was created when the data source was constructed, we first have to create an // alternative proto field for the column reader. Otherwise, we can directly use the existing proto field. - if (fieldName.substr(0, 13) != "R_rdf_sizeof_" && requestedType != fColumnTypes[index]) { + if (requestedType != fColumnTypes[index]) { auto &altProtoFields = fAlternativeProtoFields[index]; // If we can find the requested type in the registered alternative protofields, return the corresponding field @@ -488,12 +535,41 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view } // Otherwise, create a new protofield and register it in the alternatives before returning - auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(std::string(fieldName), requestedType); - if (!newAltProtoFieldOrException) { - throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType + - "\" for column \"" + std::string(fieldName) + "\""); + std::unique_ptr newAltProtoField; + const std::string strName = std::string(fieldName); + if (dynamic_cast(fProtoFields[index].get())) { + if (requestedType == "bool") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "char") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::int8_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::uint8_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::int16_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::uint16_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::int32_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::uint32_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::int64_t") { + newAltProtoField = std::make_unique>(strName); + } else if (requestedType == "std::uint64_t") { + newAltProtoField = std::make_unique>(strName); + } else { + throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType + + "\" for column \"" + std::string(fieldName) + "\""); + } + } else { + auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(strName, requestedType); + if (!newAltProtoFieldOrException) { + throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType + + "\" for column \"" + std::string(fieldName) + "\""); + } + newAltProtoField = newAltProtoFieldOrException.Unwrap(); } - auto newAltProtoField = newAltProtoFieldOrException.Unwrap(); newAltProtoField->SetOnDiskId(fProtoFields[index]->GetOnDiskId()); auto *newField = newAltProtoField.get(); altProtoFields.emplace_back(std::move(newAltProtoField)); diff --git a/tree/dataframe/test/dataframe_snapshot_ntuple.cxx b/tree/dataframe/test/dataframe_snapshot_ntuple.cxx index 0d5fdd6acf7ed..6784d4f559bb4 100644 --- a/tree/dataframe/test/dataframe_snapshot_ntuple.cxx +++ b/tree/dataframe/test/dataframe_snapshot_ntuple.cxx @@ -563,10 +563,7 @@ TEST(RDFSnapshotRNTuple, CardinalityColumns) opts.fMode = "UPDATE"; opts.fOutputFormat = ROOT::RDF::ESnapshotOutputFormat::kRNTuple; ROOT::RDataFrame df("ntuple", fileGuard.GetPath()); - - ROOT_EXPECT_WARNING(df.Snapshot("ntuple_snap", fileGuard.GetPath(), "", opts), "Snapshot", - "Column \"nElectrons\" is a read-only \"ROOT::RNTupleCardinality\" column. It " - "will be snapshot as its inner type \"std::uint32_t\" instead."); + df.Snapshot("ntuple_snap", fileGuard.GetPath(), "", opts); ROOT::RDataFrame sdf("ntuple_snap", fileGuard.GetPath()); EXPECT_EQ("std::uint32_t", sdf.GetColumnType("nElectrons")); diff --git a/tree/dataframe/test/datasource_ntuple.cxx b/tree/dataframe/test/datasource_ntuple.cxx index 9e25a83f1a060..a264614ed78b1 100644 --- a/tree/dataframe/test/datasource_ntuple.cxx +++ b/tree/dataframe/test/datasource_ntuple.cxx @@ -1,4 +1,5 @@ #include +#include #include #include @@ -14,6 +15,7 @@ #include "ClassWithArrays.h" #include +#include #include @@ -66,8 +68,10 @@ class RNTupleDSTest : public ::testing::Test { auto fldElectron = model->MakeField("electron"); fldElectron->pt = 137.0; auto fldVecElectron = model->MakeField>("VecElectron"); - fldVecElectron->push_back(*fldElectron); - fldVecElectron->push_back(*fldElectron); + for (int i = 0; i < 128; ++i) + fldVecElectron->push_back(*fldElectron); + auto fldNElectron = std::make_unique>>("nElectron"); + model->AddProjectedField(std::move(fldNElectron), [](const std::string &) { return "VecElectron"; }); { auto ntuple = RNTupleWriter::Recreate(std::move(model), fNtplName, fFileName); ntuple->Fill(); @@ -84,7 +88,7 @@ TEST_F(RNTupleDSTest, ColTypeNames) RNTupleDS ds(fNtplName, fFileName); auto colNames = ds.GetColumnNames(); - ASSERT_EQ(15, colNames.size()); + ASSERT_EQ(16, colNames.size()); EXPECT_TRUE(ds.HasColumn("pt")); EXPECT_TRUE(ds.HasColumn("energy")); @@ -96,12 +100,14 @@ TEST_F(RNTupleDSTest, ColTypeNames) EXPECT_TRUE(ds.HasColumn("R_rdf_sizeof_VecElectron")); EXPECT_TRUE(ds.HasColumn("VecElectron.pt")); EXPECT_TRUE(ds.HasColumn("R_rdf_sizeof_VecElectron.pt")); + EXPECT_TRUE(ds.HasColumn("nElectron")); EXPECT_FALSE(ds.HasColumn("Address")); EXPECT_STREQ("std::string", ds.GetTypeName("tag").c_str()); EXPECT_STREQ("float", ds.GetTypeName("energy").c_str()); - EXPECT_STREQ("std::size_t", ds.GetTypeName("R_rdf_sizeof_jets").c_str()); + EXPECT_EQ(ROOT::Internal::GetRenormalizedTypeName(typeid(std::size_t)), ds.GetTypeName("R_rdf_sizeof_jets")); EXPECT_STREQ("ROOT::VecOps::RVec", ds.GetTypeName("rvec").c_str()); + EXPECT_STREQ("std::uint64_t", ds.GetTypeName("nElectron").c_str()); try { ds.GetTypeName("Address"); @@ -142,6 +148,27 @@ TEST_F(RNTupleDSTest, CardinalityColumn) EXPECT_EQ(3, *max_rvec2); } +TEST_F(RNTupleDSTest, ProjectedCardinalityColumn) +{ + auto df = ROOT::RDF::FromRNTuple(fNtplName, fFileName); + + EXPECT_EQ(128u, *df.Filter("nElectron == 128").Max("nElectron")); + + EXPECT_EQ(128u, *df.Filter([](std::uint64_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](std::int32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](std::uint32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](std::int16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](std::uint16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](std::uint8_t x) { return x == 128; }, {"nElectron"}).Max("nElectron")); + EXPECT_EQ(128u, *df.Filter([](bool x) { return x; }, {"nElectron"}).Max("nElectron")); + try { + *df.Filter([](std::int8_t x) { return x == 0; }, {"nElectron"}).Count(); + FAIL() << "integer overflow should fail"; + } catch (const ROOT::RException &e) { + EXPECT_THAT(e.what(), ::testing::HasSubstr("integer overflow")); + } +} + static void ReadTest(const std::string &name, const std::string &fname) { auto df = ROOT::RDF::FromRNTuple(name, fname); @@ -183,7 +210,7 @@ static void ReadTest(const std::string &name, const std::string &fname) EXPECT_TRUE(All(rvec->at(0) == ROOT::RVecI{1, 2, 3})); EXPECT_TRUE(All(vectorasrvec->at(0) == ROOT::RVecF{1.f, 2.f})); EXPECT_FLOAT_EQ(137.0, sumElectronPt.GetValue()); - EXPECT_FLOAT_EQ(2. * 137.0, sumVecElectronPt.GetValue()); + EXPECT_FLOAT_EQ(128. * 137.0, sumVecElectronPt.GetValue()); } static void ChainTest(const std::string &name, const std::string &fname)