@@ -575,6 +575,8 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
575575 oData.data (), oData.sizes (), oData.strides ());
576576}
577577
578+ namespace {
579+
578580class id {
579581public:
580582 id (size_t dims) : _values(dims) {}
@@ -609,45 +611,38 @@ class id {
609611 return id (std::move (new_values));
610612 }
611613
/// @brief Advance this index in place to the following position within shape.
/// The last component is incremented first; a component that reaches
/// shape[i] is reset to 0 and the increment carries into the component before
/// it. If every component wraps, the index ends up as all zeros.
/// @param shape upper bounds per dimension; read at indices
///              0 .. _values.size()-1, so it must hold at least that many
///              entries.
614+ void next (const int64_t *shape) {
615+ size_t i = _values.size ();
// Walk from the last dimension towards the first.
616+ while (i--) {
617+ ++_values[i];
618+ if (_values[i] < shape[i]) {
// No carry needed: the index is now valid.
619+ return ;
620+ }
// Overflowed this dimension: reset it and carry into dimension i-1.
621+ _values[i] = 0 ;
622+ }
623+ }
624+
/// @brief Number of components stored in this index (_values.size()).
612625 size_t size () { return _values.size (); }
613626
614627private:
615628 std::vector<int64_t > _values;
616629};
617630
618- id &next_idx (id &idx, const int64_t *shape) {
619- size_t i = idx.size ();
620- while (i--) {
621- ++idx[i];
622- if (idx[i] < shape[i]) {
623- return idx;
624- }
625- idx[i] = 0 ;
626- }
627- return idx;
628- }
629-
630631template <typename T> class ndarray {
631632public:
/// @brief Construct an array view from raw shape/offset/data pointers.
/// Each member is initialized directly from the matching argument; lData is
/// cast from void* to T*. Nothing is dereferenced or copied here.
/// NOTE(review): the members look like non-owning pointers, so the caller
/// presumably has to keep all buffers alive for the lifetime of the object --
/// confirm against the member declarations.
/// @param nDims    number of dimensions
/// @param gShape   presumably the global shape -- confirm with caller
/// @param gOffsets presumably the global offsets of the local part -- confirm
/// @param lData    local data buffer, reinterpreted as T*
/// @param lShape   presumably the local shape -- confirm with caller
/// @param lStrides presumably the local strides -- confirm with caller
632633 ndarray (int64_t nDims, int64_t *gShape , int64_t *gOffsets , void *lData,
633634 int64_t *lShape, int64_t *lStrides)
634635 : _nDims(nDims), _gShape(gShape ), _gOffsets(gOffsets ), _lData((T *)lData),
635636 _lShape (lShape), _lStrides(lStrides) {}
636- // ndarray(std::vector<T> input, std::vector<int64_t> dims,
637- // std::vector<int64_t> strides);
638-
639- // id ids();
640- // id local_ids();
641637
/// @brief Return the index at which iteration over the local part starts.
/// Built from _nDims and _gOffsets; presumably the id constructor copies the
/// first _nDims entries of _gOffsets -- confirm against class id.
642638 id firstLocalIndex () const { return id (_nDims, _gOffsets); }
643639
/// @brief Invoke callback once for each of lSize() consecutive indices.
/// Iteration starts at firstLocalIndex() and the index is advanced after each
/// call, wrapping its components against _gShape (not the local shape).
/// @param callback function called with the current index on every step
644640 void localIndices (const std::function<void (const id &)> &callback) const {
// Number of callback invocations to perform.
645641 size_t size = lSize ();
646642 id idx = firstLocalIndex ();
647643 while (size--) {
648- std::cout << " idx: " << idx[0 ] << " ," << idx[1 ] << std::endl;
649644 callback (idx);
// Step to the following index, bounded by _gShape.
650- next_idx ( idx, _gShape);
645+ idx. next ( _gShape);
651646 }
652647 }
653648
@@ -658,7 +653,6 @@ template <typename T> class ndarray {
658653 offset = (offset + localIdx[i]) * _lShape[i + 1 ];
659654 }
660655 offset += localIdx[_nDims - 1 ];
661- std::cout << " offset: " << offset << std::endl;
662656 return offset;
663657 }
664658
@@ -714,47 +708,7 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
714708 return 0 ;
715709}
716710
717- // template <typename T>
718- // void permute(const ndarray<T> &input, const ndarray<T> &output, uint64_t
719- // nRanks,
720- // const std::vector<Parts> &parts,
721- // const std::vector<int64_t> &axes, ) {
722- // std::vector<std::vector<T>> sendBuffer(nRanks); // alltoall
723-
724- // input.permutedLocalIds(
725- // [&](const id &idx) {
726- // auto rank = getOutputRank(parts, idx[0]);
727- // sendBuffer[rank].push_back(input[idx]);
728- // },
729- // axes);
730-
731- // std::vector<int> receiveSizes(nRanks);
732- // std::vector<int> receiveOffsets(nRanks);
733-
734- // output.permutedLocalIds(
735- // [&](const id &idx) {
736- // auto rank = getInputRank(parts, idx[0]);
737- // ++receiveSizes[rank];
738- // },
739- // axes);
740- // for (size_t i = 1; i < nRanks; i++) {
741- // receiveOffsets[i] = receiveOffsets[i - 1] + receiveSizes[i - 1];
742- // }
743-
744- // return sendBuffer;
745- // }
746-
747- // template <typename T>
748- // void detranspose(std::vector<std::vector<T>> sendBuffer, ndarray<T> output,
749- // std::vector<int64_t> axes, uint64_t nRank) {
750- // std::vector<size_t> sendBufferIndex(sendBuffer.size());
751- // for (auto idx : output) {
752- // id in_idx = idx.permute(axes);
753- // auto i = sendBufferIndex[in_idx[0]];
754- // output[idx] = sendBuffer[in_idx[0]][i];
755- // sendBufferIndex[in_idx[0]] = i + 1;
756- // }
757- // }
711+ } // namespace
758712
759713// / @brief permute array
760714// / We assume array is partitioned along the first dimension (only) and
@@ -826,7 +780,6 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
826780 }
827781
828782 // First we allgather the current and target partitioning
829-
830783 std::vector<Parts> parts (nRanks);
831784 parts[rank].iStart = iOffsPtr[0 ];
832785 parts[rank].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
@@ -840,7 +793,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
840793 tc->gather (parts.data (), counts.data (), dspl.data (), SHARPY::INT64,
841794 SHARPY::REPLICATED);
842795
843- // transpose
796+ // Transpose
844797 ndarray<T> input (iNDims, iGShapePtr, iOffsPtr, iDataPtr, iDataShapePtr,
845798 iDataStridesPtr);
846799 ndarray<T> output (oNDims, oGShapePtr, oOffsPtr, oDataPtr, oDataShapePtr,
@@ -882,28 +835,11 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
882835 }
883836 }
884837
885- std::cout << " sendSizes: " << sendSizes[0 ] << std::endl;
886- std::cout << " sendOffsets: " << sendOffsets[0 ] << std::endl;
887- std::cout << " sendBuffer: " ;
888- for (int i = 0 ; i < sendBuffer.size (); ++i) {
889- std::cout << sendBuffer[i] << " ," ;
890- }
891- std::cout << std::endl;
892-
893- std::cout << " receiveSizes: " << receiveSizes[0 ] << std::endl;
894- std::cout << " receiveOffsets: " << receiveOffsets[0 ] << std::endl;
895-
896838 auto hdl = tc->alltoall (sendBuffer.data (), sendSizes.data (),
897839 sendOffsets.data (), sharpytype, receiveBuffer.data (),
898840 receiveSizes.data (), receiveOffsets.data ());
899841 tc->wait (hdl);
900842
901- std::cout << " receiveBuffer: " ;
902- for (int i = 0 ; i < receiveBuffer.size (); ++i) {
903- std::cout << receiveBuffer[i] << " ," ;
904- }
905- std::cout << std::endl;
906-
907843 {
908844 std::vector<std::vector<T>> receiveRankBuffer (nRanks);
909845 for (int64_t rank = 0 ; rank < nRanks; ++rank) {
0 commit comments