Skip to content

Commit 3229a23

Browse files
authored
vulkan: support GGML_OP_DIAG (#17893)
1 parent 303f861 commit 3229a23

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ struct vk_device_struct {
659659
vk_pipeline pipeline_cos_f32;
660660
vk_pipeline pipeline_log[2];
661661
vk_pipeline pipeline_tri[2];
662+
vk_pipeline pipeline_diag[2];
662663
vk_pipeline pipeline_clamp_f32;
663664
vk_pipeline pipeline_pad_f32;
664665
vk_pipeline pipeline_roll_f32;
@@ -3924,6 +3925,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
39243925
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39253926
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39263927

3928+
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3929+
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3930+
39273931
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39283932

39293933
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8416,6 +8420,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84168420
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
84178421
}
84188422
return nullptr;
8423+
case GGML_OP_DIAG:
8424+
if (src0->type == dst->type &&
8425+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8426+
return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
8427+
}
8428+
return nullptr;
84198429
case GGML_OP_CLAMP:
84208430
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
84218431
return ctx->device->pipeline_clamp_f32;
@@ -9109,6 +9119,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
91099119
case GGML_OP_COS:
91109120
case GGML_OP_LOG:
91119121
case GGML_OP_TRI:
9122+
case GGML_OP_DIAG:
91129123
case GGML_OP_CLAMP:
91139124
case GGML_OP_PAD:
91149125
case GGML_OP_ROLL:
@@ -9796,6 +9807,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
97969807
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
97979808
}
97989809

9810+
static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9811+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
9812+
9813+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
9814+
}
9815+
97999816
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
98009817
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
98019818
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11924,6 +11941,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1192411941
case GGML_OP_TRI:
1192511942
ggml_vk_tri(ctx, compute_ctx, src0, node);
1192611943

11944+
break;
11945+
case GGML_OP_DIAG:
11946+
ggml_vk_diag(ctx, compute_ctx, src0, node);
11947+
1192711948
break;
1192811949
case GGML_OP_CLAMP:
1192911950
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -14067,6 +14088,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1406714088
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1406814089
case GGML_OP_LOG:
1406914090
case GGML_OP_TRI:
14091+
case GGML_OP_DIAG:
1407014092
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1407114093
op->type == op->src[0]->type;
1407214094
case GGML_OP_ARGSORT:
@@ -14657,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1465714679
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1465814680
} else if (tensor->op == GGML_OP_TRI) {
1465914681
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
14682+
} else if (tensor->op == GGML_OP_DIAG) {
14683+
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
1466014684
} else if (tensor->op == GGML_OP_CLAMP) {
1466114685
const float * params = (const float *)tensor->op_params;
1466214686
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#version 450
2+
3+
#include "rte.glsl"
4+
#include "types.glsl"
5+
#include "generic_unary_head.glsl"
6+
7+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
8+
9+
void main() {
10+
const uint idx = get_idx();
11+
12+
if (idx >= p.ne) {
13+
return;
14+
}
15+
16+
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
17+
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
18+
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
19+
const uint i12_offset = i12*p.ne11*p.ne10;
20+
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
21+
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
22+
23+
if (i10 == i11) {
24+
const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
25+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
26+
} else {
27+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
28+
}
29+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,8 @@ void process_shaders() {
854854

855855
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
856856
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
857+
string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
858+
string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
857859

858860
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
859861
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

0 commit comments

Comments
 (0)