50 changes: 50 additions & 0 deletions ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
GGML_OP_SOFT_MAX,
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_COMP,
GGML_OP_ROPE_BACK,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
@@ -1858,6 +1859,55 @@ extern "C" {
GGML_API void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);


enum ggml_rope_ordering {
GGML_ROPE_ORDERING_NORMAL,
GGML_ROPE_ORDERING_NEOX,
};

// RoPE composable API
// note:
// theta_scale is usually powf(freq_base, -2.0f / (float)n_rot)
// each dimension i_dim is rotated by angle theta as follows:
// theta = pos[i_token] * theta_scale^i_dim
GGML_API struct ggml_tensor * ggml_rope_comp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int32_t n_rot,
float theta_scale,
enum ggml_rope_ordering ordering);

GGML_API struct ggml_tensor * ggml_rope_comp_set_freq_factors(
struct ggml_context * ctx,
struct ggml_tensor * node,
struct ggml_tensor * freq_factors);

// set YaRN parameters
// note:
// freq_scale == 1.0f / scale_factor
// ramp_factor is usually 1.0f
// n_dims is usually n_rot, but can also be different because it is not used for indexing
GGML_API struct ggml_tensor * ggml_rope_comp_set_yarn(
struct ggml_context * ctx,
struct ggml_tensor * node,
int n_ctx_orig,
int n_dims,
float freq_base,
float freq_scale,
float ramp_factor,
float attn_factor,
float beta_fast,
float beta_slow);

// set M-RoPE mode
// pos tensor must have shape [n_tokens, 4]
GGML_API struct ggml_tensor * ggml_rope_comp_set_multi(
struct ggml_context * ctx,
struct ggml_tensor * node,
int mode,
int sections[GGML_MROPE_SECTIONS]);

// rotary position embedding backward, i.e. compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_ext_back(
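As a usage illustration (not part of this diff): a minimal sketch of how the composable calls above might be chained to build a NeoX-ordered RoPE with YaRN scaling. The context ctx, the activation tensor cur, the F32 position tensor pos, and all hyperparameter values below are hypothetical placeholders; only the ggml_rope_comp_* signatures come from the header.

#include <math.h>
#include "ggml.h"

// assumptions: ctx is a valid ggml_context, cur holds the pre-rotation activations,
// and pos is a GGML_TYPE_F32 tensor of positions (the CPU kernel asserts F32 positions)
struct ggml_tensor * build_rope_comp_example(struct ggml_context * ctx,
                                             struct ggml_tensor  * cur,
                                             struct ggml_tensor  * pos) {
    const int   n_rot       = 128;                                    // placeholder
    const float freq_base   = 10000.0f;                               // placeholder
    const float freq_scale  = 1.0f/4.0f;                              // == 1.0f / scale_factor
    const float theta_scale = powf(freq_base, -2.0f/(float) n_rot);   // per the note above

    struct ggml_tensor * rot = ggml_rope_comp(ctx, cur, pos, n_rot, theta_scale, GGML_ROPE_ORDERING_NEOX);

    // optional: enable YaRN with placeholder parameters
    rot = ggml_rope_comp_set_yarn(ctx, rot,
            /*n_ctx_orig  =*/ 4096,
            /*n_dims      =*/ n_rot,
            /*freq_base   =*/ freq_base,
            /*freq_scale  =*/ freq_scale,
            /*ramp_factor =*/ 1.0f,
            /*attn_factor =*/ 1.0f,
            /*beta_fast   =*/ 32.0f,
            /*beta_slow   =*/ 1.0f);

    return rot;
}
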
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1863,6 +1863,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rope(params, tensor);
} break;
case GGML_OP_ROPE_COMP:
{
ggml_compute_forward_rope_comp(params, tensor);
} break;
case GGML_OP_ROPE_BACK:
{
ggml_compute_forward_rope_back(params, tensor);
@@ -2294,6 +2298,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
case GGML_OP_ADD_REL_POS:
{
@@ -2812,6 +2817,7 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
{
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
204 changes: 204 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -5817,6 +5817,210 @@ void ggml_compute_forward_rope(
}
}

// ggml_compute_forward_rope_comp

enum ggml_rope_comp_mode {
GGML_ROPE_COMP_MODE_NORMAL,
GGML_ROPE_COMP_MODE_MROPE,
GGML_ROPE_COMP_MODE_IMROPE,
GGML_ROPE_COMP_MODE_VISION,
};

template<typename T, ggml_rope_comp_mode mode> // T = float or ggml_fp16_t
static void ggml_compute_forward_rope_comp_flt(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

int32_t n_rot, idx_pair, idx_scale, idx_offset;
float theta_scale, yarn_high, yarn_low, freq_scale, ramp_factor, attn_factor;
int32_t sections[4];

memcpy(&n_rot, (int32_t *)dst->op_params + 0, sizeof(int32_t));
memcpy(&idx_pair, (int32_t *)dst->op_params + 1, sizeof(int32_t));
memcpy(&idx_scale, (int32_t *)dst->op_params + 2, sizeof(int32_t));
memcpy(&idx_offset, (int32_t *)dst->op_params + 3, sizeof(int32_t));
memcpy(&theta_scale, (int32_t *)dst->op_params + 4, sizeof(float));
memcpy(&yarn_high, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&yarn_low, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&ramp_factor, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&sections[0], (int32_t *)dst->op_params + 10, sizeof(int32_t));
memcpy(&sections[1], (int32_t *)dst->op_params + 11, sizeof(int32_t));
memcpy(&sections[2], (int32_t *)dst->op_params + 12, sizeof(int32_t));
memcpy(&sections[3], (int32_t *)dst->op_params + 13, sizeof(int32_t));

GGML_TENSOR_UNARY_OP_LOCALS

GGML_ASSERT(nb0 == nb00);
GGML_ASSERT(nb0 == sizeof(T));

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(dst);

GGML_ASSERT(n_rot <= ne0);
GGML_ASSERT(n_rot % 2 == 0);

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

// row index used to determine which thread to use
int ir = 0;

int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[0] + sections[1];

const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_rot / 2);
freq_factors = (const float *) src2->data;
}

// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;

const float * pos = (const float *) src1->data;

auto init_cache = [&](float * cache, int64_t i2) -> void {
for (int64_t i0 = 0; i0 < n_rot; i0 += 2) {
int64_t i_dim = i0/2;
float th_base = pos[i2]; // theta_base

// handle m-rope and vision rope
// dim: t = time, h = height, w = width, e = extra
if constexpr (mode == GGML_ROPE_COMP_MODE_MROPE) {
int sector = (i0 / 2) % sect_dims;
if (sector >= sections[0] && sector < sec_w) {
th_base = pos[i2 + ne2]; // h
} else if (sector >= sec_w && sector < sec_w + sections[2]) {
th_base = pos[i2 + ne2 * 2]; // w
} else if (sector >= sec_w + sections[2]) {
th_base = pos[i2 + ne2 * 3]; // e
} else {
th_base = pos[i2]; // t
}
} else if constexpr (mode == GGML_ROPE_COMP_MODE_IMROPE) {
int sector = (i0 / 2) % sect_dims;
if (sector % 3 == 1 && sector < 3 * sections[1]) {
th_base = pos[i2 + ne2]; // h
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
th_base = pos[i2 + ne2 * 2]; // w
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
th_base = pos[i2]; // t
} else {
th_base = pos[i2 + ne2 * 3]; // e
}
} else if constexpr (mode == GGML_ROPE_COMP_MODE_VISION) {
// for vision, we reset the dim index for each section
// it is equivalent to running 2 rope ops separately
int sector = (i0 / 2) % sec_w;

// only 2 dims are supported for vision rope
if (sector < sections[0]) {
th_base = pos[i2];
i_dim = sector;
} else {
th_base = pos[i2 + ne2];
i_dim = sector - sections[0];
}
}

const float freq_factor = freq_factors ? freq_factors[i0/2] : 1.0f;

float theta = th_base * powf(theta_scale, i_dim) / freq_factor;
const float theta_extrap = theta;
const float theta_interp = freq_scale * theta;

if (ramp_factor != 0.0f) {
const float ramp_mix = rope_yarn_ramp(yarn_low, yarn_high, i0) * ramp_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
} else {
theta = theta_interp;
}

cache[i0 + 0] = cosf(theta) * attn_factor;
cache[i0 + 1] = sinf(theta) * attn_factor * sin_sign;
}
};

for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
init_cache(cache, i2);

for (int64_t i1 = idx_offset; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;

T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);

rotate_pairs<T>(n_rot, idx_pair, cache, src, dst_data, idx_scale);

// fill the remaining channels with data from the src tensor
for (int64_t i0 = n_rot; i0 < ne0; i0 += 2) {
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
} // attn-heads
}
}
}

void ggml_compute_forward_rope_comp(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const int mode = ((int32_t *) dst->op_params)[14];

bool is_mrope = mode == GGML_ROPE_TYPE_MROPE;
bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
bool is_vision = mode == GGML_ROPE_TYPE_VISION;

switch (src0->type) {
case GGML_TYPE_F16:
{
/**/ if (is_vision) ggml_compute_forward_rope_comp_flt<ggml_fp16_t, GGML_ROPE_COMP_MODE_VISION>(params, dst, true);
else if (is_imrope) ggml_compute_forward_rope_comp_flt<ggml_fp16_t, GGML_ROPE_COMP_MODE_IMROPE>(params, dst, true);
else if (is_mrope) ggml_compute_forward_rope_comp_flt<ggml_fp16_t, GGML_ROPE_COMP_MODE_MROPE> (params, dst, true);
else ggml_compute_forward_rope_comp_flt<ggml_fp16_t, GGML_ROPE_COMP_MODE_NORMAL>(params, dst, true);
} break;
case GGML_TYPE_F32:
{
/**/ if (is_vision) ggml_compute_forward_rope_comp_flt<float, GGML_ROPE_COMP_MODE_VISION>(params, dst, true);
else if (is_imrope) ggml_compute_forward_rope_comp_flt<float, GGML_ROPE_COMP_MODE_IMROPE>(params, dst, true);
else if (is_mrope) ggml_compute_forward_rope_comp_flt<float, GGML_ROPE_COMP_MODE_MROPE> (params, dst, true);
else ggml_compute_forward_rope_comp_flt<float, GGML_ROPE_COMP_MODE_NORMAL>(params, dst, true);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_rope_back

void ggml_compute_forward_rope_back(
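For reference only, the per-pair angle computed by init_cache in the normal (non-multi) mode can be restated as a small standalone function. This is an illustrative re-derivation of the patch's math, not code from the diff; rope_yarn_ramp is approximated here by a local helper.

#include <math.h>

// local stand-in for ops.cpp's rope_yarn_ramp: linear ramp over the pair index, clamped to [0, 1]
static float ramp_approx(float low, float high, int i0) {
    const float y = (i0/2 - low)/fmaxf(0.001f, high - low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

// angle for channel pair i0 (0, 2, 4, ...) at position pos, following the normal-mode path above
static float rope_comp_theta(float pos, int i0,
                             float theta_scale, float freq_factor,
                             float freq_scale, float ramp_factor,
                             float yarn_low, float yarn_high) {
    const float theta_extrap = pos*powf(theta_scale, i0/2)/freq_factor;
    const float theta_interp = freq_scale*theta_extrap;
    if (ramp_factor != 0.0f) {
        const float ramp_mix = ramp_approx(yarn_low, yarn_high, i0)*ramp_factor;
        return theta_interp*(1.0f - ramp_mix) + theta_extrap*ramp_mix;
    }
    return theta_interp;
}

The kernel then caches cosf(theta)*attn_factor and sinf(theta)*attn_factor*sin_sign for each pair, as in the loop above.
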
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -61,6 +61,7 @@ void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * para
void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_comp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -261,6 +261,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
31 changes: 31 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1475,6 +1475,37 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ROPE_COMP);

const int mode = ((const int32_t *) op->op_params)[14];

const bool is_mrope = mode == GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
const char * mode_name = "norm";
if (is_mrope) {
mode_name = "mrope";
} else if (is_imrope) {
mode_name = "imrope";
} else if (is_vision) {
mode_name = "vision";
}

char base[256];
char name[256];

snprintf(base, 256, "kernel_rope_comp_%s_%s", mode_name, ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_IM2COL);

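For illustration, the snprintf pattern above combined with ggml_type_name resolves to kernel names like the following. The list assumes Metal kernels exist for the F32 and F16 source types accepted by the CPU path; the actual kernel set is defined by the Metal sources in this PR.

// illustrative list only, derived from "kernel_rope_comp_%s_%s" with mode name and type name
static const char * const rope_comp_kernel_names[] = {
    "kernel_rope_comp_norm_f32",   "kernel_rope_comp_norm_f16",
    "kernel_rope_comp_mrope_f32",  "kernel_rope_comp_mrope_f16",
    "kernel_rope_comp_imrope_f32", "kernel_rope_comp_imrope_f16",
    "kernel_rope_comp_vision_f32", "kernel_rope_comp_vision_f16",
};
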
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -136,6 +136,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
@@ -1029,6 +1029,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);