@@ -641,6 +641,7 @@ template <typename T> class ndarray {
641641 size_t size = lSize ();
642642 id idx = firstLocalIndex ();
643643 while (size--) {
644+ std::cout << " idx: " << idx[0 ] << " , " << idx[1 ] << std::endl;
644645 callback (idx);
645646 idx.next (_gShape);
646647 }
@@ -708,6 +709,52 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
708709 return 0 ;
709710}
710711
712+ template <typename T> class WaitPermute {
713+ public:
714+ WaitPermute (SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
715+ SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
716+ std::vector<int64_t > &&axes, ndarray<T> &&output,
717+ std::vector<T> &&receiveBuffer, std::vector<int > &&receiveOffsets,
718+ std::vector<int > &&receiveSizes)
719+ : tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
720+ axes (std::move(axes)), output(std::move(output)),
721+ receiveBuffer(std::move(receiveBuffer)),
722+ receiveOffsets(std::move(receiveOffsets)),
723+ receiveSizes(std::move(receiveSizes)) {}
724+
725+ void operator ()() {
726+ tc->wait (hdl);
727+ std::vector<std::vector<T>> receiveRankBuffer (nRanks);
728+ for (size_t rank = 0 ; rank < nRanks; ++rank) {
729+ auto &rankBuffer = receiveRankBuffer[rank];
730+ rankBuffer.insert (
731+ rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
732+ receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
733+ }
734+
735+ std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
736+ output.localIndices ([&](const id &outputIndex) {
737+ id inputIndex = outputIndex.permute (axes);
738+ std::cout << " inputIndex: " << inputIndex[0 ] << " , " << inputIndex[1 ]
739+ << std::endl;
740+ auto rank = getInputRank (parts, inputIndex[0 ]);
741+ auto &count = receiveRankBufferCount[rank];
742+ output[outputIndex] = receiveRankBuffer[rank][count++];
743+ });
744+ }
745+
746+ private:
747+ SHARPY::Transceiver *tc;
748+ SHARPY::Transceiver::WaitHandle hdl;
749+ SHARPY::rank_type nRanks;
750+ std::vector<Parts> parts;
751+ std::vector<int64_t > axes;
752+ ndarray<T> output;
753+ std::vector<T> receiveBuffer;
754+ std::vector<int > receiveOffsets;
755+ std::vector<int > receiveSizes;
756+ };
757+
711758} // namespace
712759
713760// / @brief permute array
@@ -844,27 +891,20 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
844891 auto hdl = tc->alltoall (sendBuffer.data (), sendSizes.data (),
845892 sendOffsets.data (), sharpytype, receiveBuffer.data (),
846893 receiveSizes.data (), receiveOffsets.data ());
847- tc->wait (hdl);
848894
849- {
850- std::vector<std::vector<T>> receiveRankBuffer (nRanks);
851- for (size_t rank = 0 ; rank < nRanks; ++rank) {
852- auto &rankBuffer = receiveRankBuffer[rank];
853- rankBuffer.insert (
854- rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
855- receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
856- }
895+ auto wait = WaitPermute (tc, hdl, nRanks, std::move (parts), std::move (axes),
896+ std::move (output), std::move (receiveBuffer),
897+ std::move (receiveOffsets), std::move (receiveSizes));
857898
858- std::vector<size_t > receiveRankBufferCount (nRanks);
859- output.localIndices ([&](const id &outputIndex) {
860- id inputIndex = outputIndex.permute (axes);
861- auto rank = getInputRank (parts, inputIndex[0 ]);
862- auto &count = receiveRankBufferCount[rank];
863- output[outputIndex] = receiveRankBuffer[rank][count++];
864- });
899+ assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
900+ receiveOffsets.empty () && receiveSizes.empty ());
901+
902+ if (no_async) {
903+ wait ();
904+ return nullptr ;
865905 }
866906
867- return nullptr ;
907+ return mkWaitHandle ( std::move (wait)) ;
868908}
869909
870910// / @brief permute array
0 commit comments