Skip to content

Commit d96b415

Browse files
authored
perf: optimize ggml_ext_chunk (#1084)
1 parent 8f05f5b commit d96b415

File tree

2 files changed

+14
-24
lines changed

2 files changed

+14
-24
lines changed

common.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,12 @@ class GEGLU : public UnaryBlock {
194194
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
195195

196196
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
197-
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
197+
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
198198
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
199199
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
200200

201+
gate = ggml_cont(ctx->ggml_ctx, gate);
202+
201203
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
202204

203205
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]

ggml_extend.hpp

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -732,34 +732,22 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
732732
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
733733
struct ggml_tensor* x,
734734
int num,
735-
int64_t dim) {
735+
int64_t dim,
736+
bool cont = true) {
736737
GGML_ASSERT(dim >= 0 && dim < 4);
737738
GGML_ASSERT(x->ne[dim] % num == 0);
738739

739-
int perm[4] = {0, 1, 2, 3};
740-
for (int i = dim; i < 3; ++i)
741-
perm[i] = perm[i + 1];
742-
perm[3] = dim;
743-
744-
int inv_perm[4];
745-
for (int i = 0; i < 4; ++i)
746-
inv_perm[perm[i]] = i;
747-
748-
if (dim != 3) {
749-
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
750-
x = ggml_cont(ctx, x);
751-
}
752-
753740
std::vector<struct ggml_tensor*> chunks;
754-
int64_t chunk_size = x->ne[3] / num;
741+
int64_t chunk_size = x->ne[dim] / num;
742+
int64_t stride = chunk_size * x->nb[dim];
743+
int64_t chunk_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
744+
chunk_ne[dim] = chunk_size;
755745
for (int i = 0; i < num; i++) {
756746
auto chunk = ggml_view_4d(
757747
ctx, x,
758-
x->ne[0], x->ne[1], x->ne[2], chunk_size,
759-
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size);
760-
761-
if (dim != 3) {
762-
chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
748+
chunk_ne[0], chunk_ne[1], chunk_ne[2], chunk_ne[3],
749+
x->nb[1], x->nb[2], x->nb[3], stride * i);
750+
if (cont) {
763751
chunk = ggml_cont(ctx, chunk);
764752
}
765753
chunks.push_back(chunk);
@@ -772,7 +760,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
772760
// x: [ne3, ne2, ne1, ne0]
773761
// return: [ne3, ne2, ne1, ne0/2]
774762

775-
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
763+
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0, false);
776764
ggml_tensor* gate;
777765
if (gate_first) {
778766
gate = x_vec[0];
@@ -781,7 +769,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
781769
x = x_vec[0];
782770
gate = x_vec[1];
783771
}
784-
772+
gate = ggml_cont(ctx, gate);
785773
gate = ggml_silu_inplace(ctx, gate);
786774

787775
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]

0 commit comments

Comments
 (0)