@@ -659,6 +659,7 @@ struct vk_device_struct {
659659 vk_pipeline pipeline_cos_f32;
660660 vk_pipeline pipeline_log[2];
661661 vk_pipeline pipeline_tri[2];
662+ vk_pipeline pipeline_diag[2];
662663 vk_pipeline pipeline_clamp_f32;
663664 vk_pipeline pipeline_pad_f32;
664665 vk_pipeline pipeline_roll_f32;
@@ -3924,6 +3925,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
39243925 ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39253926 ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39263927
3928+ ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3929+ ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3930+
39273931 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39283932
39293933 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8416,6 +8420,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84168420 return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
84178421 }
84188422 return nullptr;
8423+ case GGML_OP_DIAG:
8424+ if (src0->type == dst->type &&
8425+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8426+ return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
8427+ }
8428+ return nullptr;
84198429 case GGML_OP_CLAMP:
84208430 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
84218431 return ctx->device->pipeline_clamp_f32;
@@ -9109,6 +9119,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
91099119 case GGML_OP_COS:
91109120 case GGML_OP_LOG:
91119121 case GGML_OP_TRI:
9122+ case GGML_OP_DIAG:
91129123 case GGML_OP_CLAMP:
91139124 case GGML_OP_PAD:
91149125 case GGML_OP_ROLL:
@@ -9796,6 +9807,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
97969807 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
97979808}
97989809
9810+ static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9811+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
9812+
9813+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
9814+ }
9815+
97999816static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
98009817 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
98019818 p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11924,6 +11941,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1192411941 case GGML_OP_TRI:
1192511942 ggml_vk_tri(ctx, compute_ctx, src0, node);
1192611943
11944+ break;
11945+ case GGML_OP_DIAG:
11946+ ggml_vk_diag(ctx, compute_ctx, src0, node);
11947+
1192711948 break;
1192811949 case GGML_OP_CLAMP:
1192911950 ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -14067,6 +14088,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1406714088 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1406814089 case GGML_OP_LOG:
1406914090 case GGML_OP_TRI:
14091+ case GGML_OP_DIAG:
1407014092 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1407114093 op->type == op->src[0]->type;
1407214094 case GGML_OP_ARGSORT:
@@ -14657,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1465714679 tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1465814680 } else if (tensor->op == GGML_OP_TRI) {
1465914681 tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
14682+ } else if (tensor->op == GGML_OP_DIAG) {
14683+ tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
1466014684 } else if (tensor->op == GGML_OP_CLAMP) {
1466114685 const float * params = (const float *)tensor->op_params;
1466214686 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
0 commit comments