@@ -718,20 +718,26 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
718718 SHARPY::Transceiver *tc, int64_t iNDims,
719719 int64_t *iGShapePtr, int64_t *iOffsPtr,
720720 void *iDataPtr, int64_t *iDataShapePtr,
721- int64_t *iDataStridesPtr, int64_t oNDims,
722- int64_t *oGShapePtr, int64_t *oOffsPtr,
721+ int64_t *iDataStridesPtr, int64_t *oOffsPtr,
723722 void *oDataPtr, int64_t *oDataShapePtr,
724723 int64_t *oDataStridesPtr, int64_t *axesPtr) {
725724#ifdef NO_TRANSCEIVER
726725 initMPIRuntime ();
727726 tc = SHARPY::getTransceiver ();
728727#endif
729728 if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr ||
730- !iDataStridesPtr || !oGShapePtr || !oOffsPtr || !oDataPtr ||
731- !oDataShapePtr || ! oDataStridesPtr || !tc) {
729+ !iDataStridesPtr || !oOffsPtr || !oDataPtr || !oDataShapePtr ||
730+ !oDataStridesPtr || !tc) {
732731 throw std::invalid_argument (" Fatal: received nullptr in reshape" );
733732 }
734733
734+ std::vector<int64_t > oGShape (iNDims);
735+ for (int64_t i = 0 ; i < iNDims; ++i) {
736+ oGShape[i] = iGShapePtr[axesPtr[i]];
737+ }
738+ auto *oGShapePtr = oGShape.data ();
739+ const auto oNDims = iNDims;
740+
735741 assert (std::accumulate (&iGShapePtr[0 ], &iGShapePtr[iNDims], 1 ,
736742 std::multiplies<int64_t >()) ==
737743 std::accumulate (&oGShapePtr[0 ], &oGShapePtr[oNDims], 1 ,
@@ -817,21 +823,21 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
817823 });
818824
819825 int lastOffset = 0 ;
820- for (size_t i = 0 ; i < nRanks; i ++) {
821- sendSizes[i ] = sendRankBuffer[i ].size ();
822- sendOffsets[i ] = lastOffset;
823- sendBuffer.insert (sendBuffer.end (), sendRankBuffer[i ].begin (),
824- sendRankBuffer[i ].end ());
825- lastOffset += sendSizes[i ];
826+ for (size_t rank = 0 ; rank < nRanks; rank ++) {
827+ sendSizes[rank ] = sendRankBuffer[rank ].size ();
828+ sendOffsets[rank ] = lastOffset;
829+ sendBuffer.insert (sendBuffer.end (), sendRankBuffer[rank ].begin (),
830+ sendRankBuffer[rank ].end ());
831+ lastOffset += sendSizes[rank ];
826832 }
827833
828834 output.localIndices ([&](const id &outputIndex) {
829835 id inputIndex = outputIndex.permute (axes);
830836 auto rank = getInputRank (parts, inputIndex[0 ]);
831837 ++receiveSizes[rank];
832838 });
833- for (size_t i = 1 ; i < nRanks; i ++) {
834- receiveOffsets[i ] = receiveOffsets[i - 1 ] + receiveSizes[i - 1 ];
839+ for (size_t rank = 1 ; rank < nRanks; rank ++) {
840+ receiveOffsets[rank ] = receiveOffsets[rank - 1 ] + receiveSizes[rank - 1 ];
835841 }
836842 }
837843
@@ -842,7 +848,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
842848
843849 {
844850 std::vector<std::vector<T>> receiveRankBuffer (nRanks);
845- for (int64_t rank = 0 ; rank < nRanks; ++rank) {
851+ for (size_t rank = 0 ; rank < nRanks; ++rank) {
846852 auto &rankBuffer = receiveRankBuffer[rank];
847853 rankBuffer.insert (
848854 rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
@@ -866,12 +872,12 @@ template <typename T>
866872WaitHandleBase *
867873_idtr_copy_permute (SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
868874 int64_t iNOffs, void *iLOffsDescr, int64_t iNDims,
869- void *iDataDescr, int64_t oNSzs , void *oGShapeDescr ,
870- int64_t oNOffs , void *oLOffsDescr , int64_t oNDims ,
871- void *oDataDescr, int64_t axesSzs, void * axesDescr) {
875+ void *iDataDescr, int64_t oNOffs , void *oLOffsDescr ,
876+ int64_t oNDims , void *oDataDescr , int64_t axesSzs ,
877+ void *axesDescr) {
872878
873- if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oGShapeDescr ||
874- !oLOffsDescr || ! oDataDescr || !axesDescr) {
879+ if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oLOffsDescr ||
880+ !oDataDescr || !axesDescr) {
875881 throw std::invalid_argument (
876882 " Fatal error: received nullptr in update_halo." );
877883 }
@@ -882,15 +888,14 @@ _idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
882888 MRIdx1d iGShape (iNSzs, iGShapeDescr);
883889 MRIdx1d iOffs (iNOffs, iLOffsDescr);
884890 SHARPY::UnrankedMemRefType<T> iData (iNDims, iDataDescr);
885- MRIdx1d oGShape (oNSzs, oGShapeDescr);
886891 MRIdx1d oOffs (oNOffs, oLOffsDescr);
887892 SHARPY::UnrankedMemRefType<T> oData (oNDims, oDataDescr);
888893 MRIdx1d axes (axesSzs, axesDescr);
889894
890- return _idtr_copy_permute<T>(
891- sharpyType, tc, iNDims, iGShape .data (), iOffs .data (), iData.data (),
892- iData. sizes (), iData.strides (), oNDims, oGShape .data (), oOffs .data (),
893- oData. data (), oData.sizes (), oData.strides (), axes.data ());
895+ return _idtr_copy_permute<T>(sharpyType, tc, iNDims, iGShape. data (),
896+ iOffs .data (), iData .data (), iData.sizes (),
897+ iData.strides (), oOffs .data (), oData .data (),
898+ oData.sizes (), oData.strides (), axes.data ());
894899}
895900
896901extern " C" {
@@ -919,12 +924,11 @@ TYPED_COPY_RESHAPE(i1, bool);
919924 void *_idtr_copy_permute_##_sfx( \
920925 SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \
921926 int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \
922- int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, \
923- int64_t oNDims, void *oLDescr, int64_t axesSzs, void *axesDescr) { \
924- return _idtr_copy_permute<_typ>(tc, iNSzs, iGShapeDescr, iNOffs, \
925- iLOffsDescr, iNDims, iLDescr, oNSzs, \
926- oGShapeDescr, oNOffs, oLOffsDescr, oNDims, \
927- oLDescr, axesSzs, axesDescr); \
927+ int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oLDescr, \
928+ int64_t axesSzs, void *axesDescr) { \
929+ return _idtr_copy_permute<_typ>( \
930+ tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNOffs, \
931+ oLOffsDescr, oNDims, oLDescr, axesSzs, axesDescr); \
928932 } \
929933 _Pragma (STRINGIFY(weak _mlir_ciface__idtr_copy_permute_##_sfx = \
930934 _idtr_copy_permute_##_sfx))
0 commit comments