@@ -732,34 +732,22 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
732732__STATIC_INLINE__ std::vector<struct ggml_tensor *> ggml_ext_chunk (struct ggml_context * ctx,
733733 struct ggml_tensor * x,
734734 int num,
735- int64_t dim) {
735+ int64_t dim,
736+ bool cont = true ) {
736737 GGML_ASSERT (dim >= 0 && dim < 4 );
737738 GGML_ASSERT (x->ne [dim] % num == 0 );
738739
739- int perm[4 ] = {0 , 1 , 2 , 3 };
740- for (int i = dim; i < 3 ; ++i)
741- perm[i] = perm[i + 1 ];
742- perm[3 ] = dim;
743-
744- int inv_perm[4 ];
745- for (int i = 0 ; i < 4 ; ++i)
746- inv_perm[perm[i]] = i;
747-
748- if (dim != 3 ) {
749- x = ggml_ext_torch_permute (ctx, x, perm[0 ], perm[1 ], perm[2 ], perm[3 ]);
750- x = ggml_cont (ctx, x);
751- }
752-
753740 std::vector<struct ggml_tensor *> chunks;
754- int64_t chunk_size = x->ne [3 ] / num;
741+ int64_t chunk_size = x->ne [dim] / num;
742+ int64_t stride = chunk_size * x->nb [dim];
743+ int64_t chunk_ne[4 ] = {x->ne [0 ], x->ne [1 ], x->ne [2 ], x->ne [3 ]};
744+ chunk_ne[dim] = chunk_size;
755745 for (int i = 0 ; i < num; i++) {
756746 auto chunk = ggml_view_4d (
757747 ctx, x,
758- x->ne [0 ], x->ne [1 ], x->ne [2 ], chunk_size,
759- x->nb [1 ], x->nb [2 ], x->nb [3 ], x->nb [3 ] * i * chunk_size);
760-
761- if (dim != 3 ) {
762- chunk = ggml_ext_torch_permute (ctx, chunk, inv_perm[0 ], inv_perm[1 ], inv_perm[2 ], inv_perm[3 ]);
748+ chunk_ne[0 ], chunk_ne[1 ], chunk_ne[2 ], chunk_ne[3 ],
749+ x->nb [1 ], x->nb [2 ], x->nb [3 ], stride * i);
750+ if (cont) {
763751 chunk = ggml_cont (ctx, chunk);
764752 }
765753 chunks.push_back (chunk);
@@ -772,7 +760,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
772760 // x: [ne3, ne2, ne1, ne0]
773761 // return: [ne3, ne2, ne1, ne0/2]
774762
775- auto x_vec = ggml_ext_chunk (ctx, x, 2 , 0 );
763+ auto x_vec = ggml_ext_chunk (ctx, x, 2 , 0 , false );
776764 ggml_tensor* gate;
777765 if (gate_first) {
778766 gate = x_vec[0 ];
@@ -781,7 +769,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
781769 x = x_vec[0 ];
782770 gate = x_vec[1 ];
783771 }
784-
772+ gate = ggml_cont (ctx, gate);
785773 gate = ggml_silu_inplace (ctx, gate);
786774
787775 x = ggml_mul (ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
0 commit comments