@@ -646,6 +646,15 @@ template <typename T> class ndarray {
646646 }
647647 }
648648
649+ void globalIndices (const std::function<void (const id &)> &callback) const {
650+ size_t size = gSize ();
651+ id idx (_nDims);
652+ while (size--) {
653+ callback (idx);
654+ idx.next (_gShape);
655+ }
656+ }
657+
649658 int64_t getLocalDataOffset (const id &idx) const {
650659 auto localIdx = idx - _gOffsets;
651660 int64_t offset = 0 ;
@@ -711,14 +720,16 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
711720template <typename T> class WaitPermute {
712721public:
713722 WaitPermute (SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
714- SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
715- std::vector<int64_t > &&axes, std::vector<int64_t > oGShape,
723+ SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
724+ std::vector<Parts> &&parts, std::vector<int64_t > &&axes,
725+ std::vector<int64_t > oGShape, ndarray<T> &&input,
716726 ndarray<T> &&output, std::vector<T> &&receiveBuffer,
717727 std::vector<int > &&receiveOffsets,
718728 std::vector<int > &&receiveSizes)
719- : tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
729+ : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
720730 axes (std::move(axes)), oGShape(std::move(oGShape)),
721- output(std::move(output)), receiveBuffer(std::move(receiveBuffer)),
731+ input(std::move(input)), output(std::move(output)),
732+ receiveBuffer(std::move(receiveBuffer)),
722733 receiveOffsets(std::move(receiveOffsets)),
723734 receiveSizes(std::move(receiveSizes)) {}
724735
@@ -733,9 +744,12 @@ template <typename T> class WaitPermute {
733744 }
734745
735746 std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
736- output.localIndices ([&](const id &outputIndex) {
737- id inputIndex = outputIndex.permute (axes);
738- auto rank = getInputRank (parts, inputIndex[0 ]);
747+ input.globalIndices ([&](const id &inputIndex) {
748+ id outputIndex = inputIndex.permute (axes);
749+ auto rank = getOutputRank (parts, outputIndex[0 ]);
750+ if (rank != cRank)
751+ return ;
752+ rank = getInputRank (parts, inputIndex[0 ]);
739753 auto &count = receiveRankBufferCount[rank];
740754 output[outputIndex] = receiveRankBuffer[rank][count++];
741755 });
@@ -744,10 +758,12 @@ template <typename T> class WaitPermute {
744758private:
745759 SHARPY::Transceiver *tc;
746760 SHARPY::Transceiver::WaitHandle hdl;
761+ SHARPY::rank_type cRank;
747762 SHARPY::rank_type nRanks;
748763 std::vector<Parts> parts;
749764 std::vector<int64_t > axes;
750765 std::vector<int64_t > oGShape;
766+ ndarray<T> input;
751767 ndarray<T> output;
752768 std::vector<T> receiveBuffer;
753769 std::vector<int > receiveOffsets;
@@ -791,9 +807,9 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
791807 assert (std::accumulate (&oOffsPtr[1 ], &oOffsPtr[oNDims], 0 ,
792808 std::plus<int64_t >()) == 0 );
793809
794- auto nRanks = tc->nranks ();
795- auto rank = tc->rank ();
796- if (nRanks <= rank ) {
810+ const auto nRanks = tc->nranks ();
811+ const auto cRank = tc->rank ();
812+ if (nRanks <= cRank ) {
797813 throw std::out_of_range (" Fatal: rank must be < number of ranks" );
798814 }
799815
@@ -833,10 +849,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
833849
834850 // First we allgather the current and target partitioning
835851 std::vector<Parts> parts (nRanks);
836- parts[rank ].iStart = iOffsPtr[0 ];
837- parts[rank ].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
838- parts[rank ].oStart = oOffsPtr[0 ];
839- parts[rank ].oEnd = oOffsPtr[0 ] + oDataShapePtr[0 ];
852+ parts[cRank ].iStart = iOffsPtr[0 ];
853+ parts[cRank ].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
854+ parts[cRank ].oStart = oOffsPtr[0 ];
855+ parts[cRank ].oEnd = oOffsPtr[0 ] + oDataShapePtr[0 ];
840856 std::vector<int > counts (nRanks, 4 );
841857 std::vector<int > dspl (nRanks);
842858 for (auto i = 0ul ; i < nRanks; ++i) {
@@ -891,10 +907,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
891907 sendOffsets.data (), sharpytype, receiveBuffer.data (),
892908 receiveSizes.data (), receiveOffsets.data ());
893909
894- auto wait = WaitPermute (tc, hdl, nRanks, std::move (parts) , std::move (axes ),
895- std::move (oGShape), std::move (output ),
896- std::move (receiveBuffer ), std::move (receiveOffsets ),
897- std::move (receiveSizes));
910+ auto wait = WaitPermute (tc, hdl, cRank, nRanks , std::move (parts ),
911+ std::move (axes), std::move ( oGShape), std::move (input ),
912+ std::move (output ), std::move (receiveBuffer ),
913+ std::move (receiveOffsets), std::move ( receiveSizes));
898914
899915 assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
900916 receiveOffsets.empty () && receiveSizes.empty ());
0 commit comments