Skip to content

Commit 6209211

Browse files
committed
DPL: fix DataModelViews behavior
Also add tests to validate this is the case.
1 parent 056b5f4 commit 6209211

File tree

3 files changed

+435
-17
lines changed

3 files changed

+435
-17
lines changed

Framework/Core/include/Framework/DataModelViews.h

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct count_parts {
7070
count += 1;
7171
mi += header->splitPayloadParts + 1;
7272
} else {
73-
count += header->splitPayloadParts;
73+
count += header->splitPayloadParts ? header->splitPayloadParts : 1;
7474
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
7575
}
7676
}
@@ -104,11 +104,11 @@ struct get_pair {
104104
}
105105
mi += header->splitPayloadParts + 1;
106106
} else {
107-
count += header->splitPayloadParts ? header->splitPayloadParts : 1;
108-
if (self.pairId < count) {
109-
return {mi, mi + 2 * diff + 1};
107+
if (self.pairId == count) {
108+
return {mi, mi + 1};
110109
}
111-
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
110+
count += 1;
111+
mi += 2;
112112
}
113113
}
114114
throw std::runtime_error("Payload not found");
@@ -138,10 +138,10 @@ struct get_dataref_indices {
138138
mi += header->splitPayloadParts + 1;
139139
} else {
140140
if (self.part == count) {
141-
return {mi, mi + 2 * self.subPart + 1};
141+
return {mi, mi + self.subPart + 1};
142142
}
143143
count += 1;
144-
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
144+
mi += 2;
145145
}
146146
}
147147
throw std::runtime_error("Payload not found");
@@ -172,32 +172,41 @@ struct get_payload {
172172
};
173173

174174
struct get_num_payloads {
175-
size_t id;
176-
// ends the pipeline, returns the number of parts
175+
size_t n;
176+
// ends the pipeline, returns the number of payloads which are associated
177+
// to the multipart n-th sequence of messages found in the range
177178
template <typename R>
178179
requires std::ranges::random_access_range<R> && std::ranges::sized_range<R>
179180
friend size_t operator|(R&& r, get_num_payloads self)
180181
{
181182
size_t count = 0;
182183
size_t mi = 0;
184+
// Un
183185
while (mi < r.size()) {
184186
auto* header = o2::header::get<o2::header::DataHeader*>(r[mi]->GetData());
185187
if (!header) {
186188
throw std::runtime_error("Not a DataHeader");
187189
}
188-
if (self.id == count) {
189-
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
190+
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
191+
// This is the case for the new multi payload messages where the number of parts
192+
// is as many as the splitPayloadParts number.
193+
if (self.n == count) {
190194
return header->splitPayloadParts;
191-
} else {
192-
return 1;
193195
}
194-
}
195-
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
196+
// For multipayload we skip all the parts and their associated header
196197
count += 1;
197198
mi += header->splitPayloadParts + 1;
198199
} else {
199-
count += 1;
200-
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
200+
// This is the case of a multipart (header, payload), (header, payload), ...
201+
// sequence where we know how many pairs are there.
202+
// When splitPayloadParts == 0, it means it is a non-multipart (header, payload)
203+
// pair. Each pair has exactly 1 payload.
204+
auto pairs = header->splitPayloadParts ? header->splitPayloadParts : 1;
205+
if (self.n < count + pairs) {
206+
return 1;
207+
}
208+
count += pairs;
209+
mi += 2 * pairs;
201210
}
202211
}
203212
return 0;

Framework/Core/test/test_DataRelayer.cxx

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
#include "Framework/WorkflowSpec.h"
2727
#include <Monitoring/Monitoring.h>
2828
#include <fairmq/TransportFactory.h>
29+
#include <fairmq/Channel.h>
30+
#include "Framework/FairMQDeviceProxy.h"
31+
#include "Framework/ExpirationHandler.h"
32+
#include "Framework/LifetimeHelpers.h"
2933
#include <array>
3034
#include <vector>
3135
#include <uv.h>
@@ -808,4 +812,159 @@ TEST_CASE("DataRelayer")
808812
}
809813
}
810814
}
815+
816+
SECTION("ProcessDanglingInputs")
817+
{
818+
InputSpec spec{"condition", "TST", "COND"};
819+
std::vector<InputRoute> inputs = {
820+
InputRoute{spec, 0, "from_source_to_self", 0}};
821+
822+
std::vector<InputChannelInfo> infos{1};
823+
TimesliceIndex index{1, infos};
824+
ref.registerService(ServiceRegistryHelpers::handleForService<TimesliceIndex>(&index));
825+
826+
// Bind a fake input channel so FairMQDeviceProxy::getInputChannelIndex works
827+
FairMQDeviceProxy proxy;
828+
std::vector<fair::mq::Channel> channels{fair::mq::Channel("from_source_to_self")};
829+
auto findChannel = [&channels](std::string const& name) -> fair::mq::Channel& {
830+
for (auto& ch : channels) {
831+
if (ch.GetName() == name) {
832+
return ch;
833+
}
834+
}
835+
throw std::runtime_error("Channel not found: " + name);
836+
};
837+
proxy.bind({}, inputs, {}, findChannel, [] { return false; });
838+
ref.registerService(ServiceRegistryHelpers::handleForService<FairMQDeviceProxy>(&proxy));
839+
840+
auto policy = CompletionPolicyHelpers::consumeWhenAny();
841+
DataRelayer relayer(policy, inputs, index, {registry}, -1);
842+
relayer.setPipelineLength(4);
843+
844+
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
845+
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
846+
847+
DataHeader dh{"COND", "TST", 0};
848+
dh.splitPayloadParts = 1;
849+
dh.splitPayloadIndex = 0;
850+
DataProcessingHeader dph{0, 1};
851+
852+
ExpirationHandler handler;
853+
handler.name = "test-condition";
854+
handler.routeIndex = RouteIndex{0};
855+
handler.lifetime = Lifetime::Condition;
856+
857+
// Creator: claim an empty slot and assign timeslice 0 to it
858+
handler.creator = [](ServiceRegistryRef services, ChannelIndex channelIndex) -> TimesliceSlot {
859+
auto& index = services.get<TimesliceIndex>();
860+
for (size_t si = 0; si < index.size(); si++) {
861+
TimesliceSlot slot{si};
862+
if (!index.isValid(slot)) {
863+
index.associate(TimesliceId{0}, slot);
864+
index.setOldestPossibleInput({1}, channelIndex);
865+
return slot;
866+
}
867+
}
868+
return TimesliceSlot{TimesliceSlot::INVALID};
869+
};
870+
871+
// Checker: always trigger expiration
872+
handler.checker = LifetimeHelpers::expireAlways();
873+
874+
// Handler: materialise a dummy header+payload into the PartRef
875+
handler.handler = [&transport, &channelAlloc, &dh, &dph](ServiceRegistryRef, PartRef& ref, data_matcher::VariableContext&) {
876+
ref.header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
877+
ref.payload = transport->CreateMessage(4);
878+
};
879+
880+
std::vector<ExpirationHandler> handlers{handler};
881+
auto activity = relayer.processDanglingInputs(handlers, {registry}, true);
882+
883+
REQUIRE(activity.newSlots == 1);
884+
REQUIRE(activity.expiredSlots == 1);
885+
886+
// The materialised data should now be ready to consume
887+
std::vector<RecordAction> ready;
888+
relayer.getReadyToProcess(ready);
889+
REQUIRE(ready.size() == 1);
890+
REQUIRE(ready[0].op == CompletionPolicy::CompletionOp::Consume);
891+
892+
auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot);
893+
REQUIRE(result.size() == 1);
894+
REQUIRE((result.at(0).messages | count_parts{}) == 1);
895+
}
896+
897+
SECTION("ProcessDanglingInputsSkipsWhenDataPresent")
898+
{
899+
// processDanglingInputs must not overwrite a slot that already has data.
900+
// This is guarded by the (part.messages | get_header{0}) != nullptr check.
901+
InputSpec spec{"condition", "TST", "COND"};
902+
std::vector<InputRoute> inputs = {
903+
InputRoute{spec, 0, "from_source_to_self", 0}};
904+
905+
std::vector<InputChannelInfo> infos{1};
906+
TimesliceIndex index{1, infos};
907+
ref.registerService(ServiceRegistryHelpers::handleForService<TimesliceIndex>(&index));
908+
909+
FairMQDeviceProxy proxy;
910+
std::vector<fair::mq::Channel> channels{fair::mq::Channel("from_source_to_self")};
911+
auto findChannel = [&channels](std::string const& name) -> fair::mq::Channel& {
912+
for (auto& ch : channels) {
913+
if (ch.GetName() == name) {
914+
return ch;
915+
}
916+
}
917+
throw std::runtime_error("Channel not found: " + name);
918+
};
919+
proxy.bind({}, inputs, {}, findChannel, [] { return false; });
920+
ref.registerService(ServiceRegistryHelpers::handleForService<FairMQDeviceProxy>(&proxy));
921+
922+
auto policy = CompletionPolicyHelpers::consumeWhenAny();
923+
DataRelayer relayer(policy, inputs, index, {registry}, -1);
924+
relayer.setPipelineLength(4);
925+
926+
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
927+
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
928+
929+
DataHeader dh{"COND", "TST", 0};
930+
dh.splitPayloadParts = 1;
931+
dh.splitPayloadIndex = 0;
932+
DataProcessingHeader dph{0, 1};
933+
934+
// Build an expiration handler that always tries to expire
935+
ExpirationHandler handler;
936+
handler.name = "test-condition";
937+
handler.routeIndex = RouteIndex{0};
938+
handler.lifetime = Lifetime::Condition;
939+
handler.creator = [](ServiceRegistryRef services, ChannelIndex channelIndex) -> TimesliceSlot {
940+
auto& index = services.get<TimesliceIndex>();
941+
for (size_t si = 0; si < index.size(); si++) {
942+
TimesliceSlot slot{si};
943+
if (!index.isValid(slot)) {
944+
index.associate(TimesliceId{0}, slot);
945+
index.setOldestPossibleInput({1}, channelIndex);
946+
return slot;
947+
}
948+
}
949+
return TimesliceSlot{TimesliceSlot::INVALID};
950+
};
951+
handler.checker = LifetimeHelpers::expireAlways();
952+
int handlerCallCount = 0;
953+
handler.handler = [&transport, &channelAlloc, &dh, &dph, &handlerCallCount](ServiceRegistryRef, PartRef& ref, data_matcher::VariableContext&) {
954+
ref.header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
955+
ref.payload = transport->CreateMessage(4);
956+
handlerCallCount++;
957+
};
958+
std::vector<ExpirationHandler> handlers{handler};
959+
960+
// First call: slot is empty, so the handler fires and materialises data
961+
auto activity1 = relayer.processDanglingInputs(handlers, {registry}, true);
962+
REQUIRE(activity1.expiredSlots == 1);
963+
REQUIRE(handlerCallCount == 1);
964+
965+
// Second call: slot already has data — the handler must NOT fire again
966+
auto activity2 = relayer.processDanglingInputs(handlers, {registry}, false);
967+
REQUIRE(activity2.expiredSlots == 0);
968+
REQUIRE(handlerCallCount == 1); // handler was not called a second time
969+
}
811970
}

0 commit comments

Comments
 (0)