-
Notifications
You must be signed in to change notification settings - Fork 67
Bitonic_Sort #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Bitonic_Sort #943
Conversation
|
[CI]: Can one of the admins verify this patch? |
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareExchangeWithPartner( | ||
| bool takeLarger, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loPair.first, partnerLoPair.first); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; | ||
| loPair.first = takePartnerLo ? partnerLoPair.first : loPair.first; | ||
| loPair.second = takePartnerLo ? partnerLoPair.second : loPair.second; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiPair.first, partnerHiPair.first); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; | ||
| hiPair.first = takePartnerHi ? partnerHiPair.first : hiPair.first; | ||
| hiPair.second = takePartnerHi ? partnerHiPair.second : hiPair.second; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Define an operator= for pair, then this becomes
if(takePartnerLo)
loPair = pLoPair;
if(takePartnerHi)
hiPair = pHiPair;
Don't worry about branching: since everything's an assignment, it'll just be OpSelects under the hood.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can't define an operator= in HLSL. None of the operators which should return references can be defined in HLSL, which rules out assignment and array indexing (as well as compound assignment).
However, all structs in HLSL are trivial, so they can be assigned directly.
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareExchangeWithPartner( | ||
| bool takeLarger, | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_CONST_REF_ARG(KeyType) partnerLoKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_CONST_REF_ARG(KeyType) partnerHiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_CONST_REF_ARG(ValueType) partnerLoVal, | ||
| NBL_REF_ARG(ValueType) hiVal, | ||
| NBL_CONST_REF_ARG(ValueType) partnerHiVal, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loKey, partnerLoKey); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; | ||
| loKey = takePartnerLo ? partnerLoKey : loKey; | ||
| loVal = takePartnerLo ? partnerLoVal : loVal; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiKey, partnerHiKey); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; | ||
| hiKey = takePartnerHi ? partnerHiKey : hiKey; | ||
| hiVal = takePartnerHi ? partnerHiVal : hiVal; | ||
| } | ||
|
|
||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_REF_ARG(ValueType) hiVal, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool shouldSwap = comp(hiKey, loKey); | ||
|
|
||
| const bool doSwap = (shouldSwap == ascending); | ||
|
|
||
| KeyType tempKey = loKey; | ||
| loKey = doSwap ? hiKey : loKey; | ||
| hiKey = doSwap ? tempKey : hiKey; | ||
|
|
||
| ValueType tempVal = loVal; | ||
| loVal = doSwap ? hiVal : loVal; | ||
| hiVal = doSwap ? tempVal : hiVal; | ||
| } | ||
|
|
||
| template<typename KeyType, typename ValueType> | ||
| void swap( | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_REF_ARG(ValueType) hiVal) | ||
| { | ||
| KeyType tempKey = loKey; | ||
| loKey = hiKey; | ||
| hiKey = tempKey; | ||
|
|
||
| ValueType tempVal = loVal; | ||
| loVal = hiVal; | ||
| hiVal = tempVal; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Delete all of these and only use the versions that use pairs. It's more consistent.
| template<typename KeyType, typename ValueType> | ||
| void swap( | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair) | ||
| { | ||
| pair<KeyType, ValueType> temp = loPair; | ||
| loPair = hiPair; | ||
| hiPair = temp; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work without a definition of operator= for pair. Your code compiles only because you're not using it right now. We want to keep this version and rewrite all the swaps to use it, operating on pairs.
The definition for pair, the overload for operator= and this swap method all belong in https://github.com/Devsh-Graphics-Programming/Nabla/blob/master/include/nbl/builtin/hlsl/utility.hlsl, mimicking how std::pair is in the <utility> header in cpp. Move them over there.
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool shouldSwap = comp(hiPair.first, loPair.first); | ||
| const bool doSwap = (shouldSwap == ascending); | ||
|
|
||
| KeyType tempKey = loPair.first; | ||
| ValueType tempVal = loPair.second; | ||
| loPair.first = doSwap ? hiPair.first : loPair.first; | ||
| loPair.second = doSwap ? hiPair.second : loPair.second; | ||
| hiPair.first = doSwap ? tempKey : hiPair.first; | ||
| hiPair.second = doSwap ? tempVal : hiPair.second; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the definition for the swap below, this just becomes
if (doSwap)
swap(loPair, hiPair);
| //template<typename T1, typename T2> | ||
| //struct pair | ||
| //{ | ||
| // using first_type = T1; | ||
| // using second_type = T2; | ||
| // | ||
| // first_type first; | ||
| // second_type second; | ||
| //}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this
include/nbl/builtin/hlsl/pair.hlsl
Outdated
| template<typename T1, typename T2> | ||
| struct pair | ||
| { | ||
| using first_type = T1; | ||
| using second_type = T2; | ||
|
|
||
| first_type first; | ||
| second_type second; | ||
| }; | ||
|
|
||
|
|
||
| // Helper to make a pair (similar to std::make_pair) | ||
| template<typename T1, typename T2> | ||
| pair<T1, T2> make_pair(T1 f, T2 s) | ||
| { | ||
| pair<T1, T2> p; | ||
| p.first = f; | ||
| p.second = s; | ||
| return p; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of this belongs in https://github.com/Devsh-Graphics-Programming/Nabla/blob/master/include/nbl/builtin/hlsl/utility.hlsl. Move the code over there and remove this file.
| static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, | ||
| NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mergeStage(..., loPair, hiPair)
| const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride); | ||
| const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride); | ||
| const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride); | ||
| const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you make this code work with pairs, you might want to do it like
pair pLoPair, pHiPair;
pLoPair.key = glsl::subgroupShuffleXor<key_t>(loPair.key, threadStride);
pHiPair.key = glsl::subgroupShuffleXor<key_t>(hiPair.key, threadStride);
pLoPair.value = glsl::subgroupShuffleXor<value_t>(loPair.value, threadStride);
pHiPair.value = glsl::subgroupShuffleXor<value_t>(hiPair.value, threadStride);
| } | ||
| } | ||
|
|
||
| static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__call(ascending, loPair, hiPair)
| static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID, | ||
| NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) loPair, NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) hiPair) | ||
| { | ||
| const uint32_t WorkgroupSize = config_t::WorkgroupSize; | ||
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
| comparator_t comp; | ||
|
|
||
| [unroll] | ||
| for (uint32_t pass = 0; pass <= stage; pass++) | ||
| { | ||
| if (pass) | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| const uint32_t stridePower = (stage - pass + 1) + subgroupSizeLog2; | ||
| const uint32_t stride = 1u << stridePower; | ||
| const uint32_t threadStride = stride >> 1; | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> pLoPair = loPair; | ||
| shuffleXor(pLoPair, threadStride, sharedmemAccessor); | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair; | ||
| shuffleXor(pHiPair, threadStride, sharedmemAccessor); | ||
|
|
||
| const bool isUpper = (invocationID & threadStride) != 0; | ||
| const bool takeLarger = isUpper == bitonicAscending; | ||
|
|
||
| nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, pLoPair, hiPair, pHiPair, comp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're using the shared memory accessor here, when you should be using an adaptor. The user only has to pass a generic sharedmem accessor, then the adaptor ensures its accesses are optimal. If switching from the accessor to an adaptor would break your code, we can look into making a specialized sharedmem adaptor for pairs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise, we're passing the burden of writing an optimal accessor to the user, which requires the user to know the underlying implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold this one until Matt replies because the adaptor needs to change here
Removed commented-out template struct for pair.
| [unroll] | ||
| for (uint32_t strideLog = simpleLog - 1u; strideLog + 1u > 0u; strideLog--) | ||
| { | ||
| const uint32_t stride = 1u << strideLog; | ||
| [unroll] | ||
| for (uint32_t virtualThreadID = threadID; virtualThreadID < ElementsPerSimpleSort / 2; virtualThreadID += WorkgroupSize) | ||
| { | ||
| const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset; | ||
| const uint32_t hiIx = loIx | stride; | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> lopair, hipair; | ||
| accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair); | ||
| accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair); | ||
|
|
||
| swap(lopair, hipair); | ||
|
|
||
| accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair); | ||
| accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair); | ||
| } | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels like overkill. Instead of having to go over half the array to swap it around, just make the E=2 struct specialization take an ascending parameter, like you did for the subgroup sort. Set it to true by default, and in the calls made above from this E > 2 struct, pass ascending = !(WorkgroupID & 1) to ensure every even workgroup is sorted in ascending order and every odd workgroup in descending order. It's the same thing you did for the subgroup case.
Description
Testing
TODO list: