diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 51bcce1ebb0..686ede9171a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -65,6 +65,13 @@ struct clip_hparams { int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox + // audio-to-mel preprocessor params + int32_t audio_chunk_len = -1; // in seconds + int32_t audio_sample_rate = -1; + int32_t audio_n_fft = -1; + int32_t audio_window_len = -1; + int32_t audio_hop_len = -1; + // legacy bool has_llava_projector = false; int minicpmv_version = 0; @@ -277,3 +284,5 @@ struct clip_model { || proj_type == PROJECTOR_TYPE_VOXTRAL; } }; + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index bb922e30b43..f8615759841 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1141,11 +1141,15 @@ struct clip_model_loader { bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || model.proj_type == PROJECTOR_TYPE_VOXTRAL; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); - if (hparams.n_mel_bins != 128) { - throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); - } hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging + + // audio preprocessing params + hparams.audio_chunk_len = 30; // in seconds + hparams.audio_sample_rate = 16000; + hparams.audio_n_fft = 400; + hparams.audio_window_len = 400; + hparams.audio_hop_len = 160; } break; default: break; @@ -1183,6 +1187,11 @@ struct clip_model_loader { LOG_INF("\n--- audio hparams ---\n"); LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor); + LOG_INF("%s: audio_chunk_len: %d\n", __func__, hparams.audio_chunk_len); + LOG_INF("%s: audio_sample_rate: %d\n", __func__, hparams.audio_sample_rate); + LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft); + LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len); + LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len); } LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); @@ -3416,3 +3425,7 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel batch->entries.push_back(clip_image_f32_ptr(audio)); batch->is_audio = true; } + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) { + return &ctx->model.hparams; +} diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index 4d053895cda..f68829a61a4 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -11,63 +11,149 @@ // most of the code here is copied from whisper.cpp -// align x to upper multiple of n -#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) - -namespace whisper_preprocessor { - -#define SIN_COS_N_COUNT WHISPER_N_FFT -namespace { -struct whisper_global_cache { - // In FFT, we frequently use sine and cosine operations with the same values. - // We can use precalculated values to speed up the process. - float sin_vals[SIN_COS_N_COUNT]; - float cos_vals[SIN_COS_N_COUNT]; - - // Hann window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - float hann_window[WHISPER_N_FFT]; - - whisper_global_cache() { - fill_sin_cos_table(); - fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); - } - - void fill_sin_cos_table() { - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; +constexpr bool DEBUG = false; + +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// note: this global cache is shared among all preprocessors +// if we want to use multiple preprocessors at the same time, +// we will need to enclose it in the preprocessor class in the future +static struct mtmd_audio_global_cache { + // precomputed sin/cos table for FFT + std::vector sin_vals; + std::vector cos_vals; + + // hann window + std::vector hann_window; + + // mel filter bank + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; sin_vals[i] = sinf(theta); cos_vals[i] = cosf(theta); } } - void fill_hann_window(int length, bool periodic, float * output) { + void fill_hann_window(int length, bool periodic) { + hann_window.resize(length); int offset = -1; if (periodic) { offset = 0; } for (int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); } } -} global_cache; -} + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix( + int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code + ) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; + } + + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); + } + + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } + + const int n_fft_bins = n_fft / 2 + 1; + + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; + + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; + + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; + } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + } + } + + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); + + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); + } + } + } + } +} g_cache; // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -static void dft(const float* in, int N, float* out) { - const int sin_cos_step = SIN_COS_N_COUNT / N; +static void dft(const float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*global_cache.cos_vals[idx]; // cos(t) - im -= in[n]*global_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N + re += in[n] * g_cache.cos_vals[idx]; // cos(t) + im -= in[n] * g_cache.sin_vals[idx]; // sin(t) } out[k*2 + 0] = re; @@ -79,7 +165,8 @@ static void dft(const float* in, int N, float* out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -static void fft(float* in, int N, float* out) { +static void fft(float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); if (N == 1) { out[0] = in[0]; out[1] = 0; @@ -106,11 +193,11 @@ static void fft(float* in, int N, float* out) { float* odd_fft = even_fft + N; fft(odd, half_N, odd_fft); - const int sin_cos_step = SIN_COS_N_COUNT / N; + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < half_N; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = global_cache.cos_vals[idx]; // cos(t) - float im = -global_cache.sin_vals[idx]; // sin(t) + float re = g_cache.cos_vals[idx]; // cos(t) + float im = -g_cache.sin_vals[idx]; // sin(t) float re_odd = odd_fft[2*k + 0]; float im_odd = odd_fft[2*k + 1]; @@ -123,20 +210,34 @@ static void fft(float* in, int N, float* out) { } } +struct filter_params { + int32_t n_mel; + int32_t n_fft_bins; + int32_t hann_window_size; + int32_t hop_length; + int32_t sample_rate; + bool center_padding = false; + float preemph = 0.f; + bool use_natural_log = false; + bool norm_per_feature = false; +}; + static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int frame_size, int frame_step, int n_threads, - const whisper_filters & filters, whisper_mel & mel) { + const filter_params & params, mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); - int n_fft = filters.n_fft; + int n_fft_bins = params.n_fft_bins; int i = ith; - // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - WHISPER_ASSERT(n_fft == 1 + (frame_size / 2)); + const auto & filters = g_cache.filters; + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); + GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; // apply Hann window (~10% faster) @@ -154,36 +255,39 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < n_fft_bins; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { + for (int j = 0; j < out.n_mel; j++) { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; - for (k = 0; k < n_fft - 3; k += 4) { + for (k = 0; k < n_fft_bins - 3; k += 4) { + size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k); sum += - fft_out[k + 0] * filters.data[j * n_fft + k + 0] + - fft_out[k + 1] * filters.data[j * n_fft + k + 1] + - fft_out[k + 2] * filters.data[j * n_fft + k + 2] + - fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + fft_out[k + 0] * filters.data[idx + 0] + + fft_out[k + 1] * filters.data[idx + 1] + + fft_out[k + 2] * filters.data[idx + 2] + + fft_out[k + 3] * filters.data[idx + 3]; } // handle n_fft remainder - for (; k < n_fft; k++) { - sum += fft_out[k] * filters.data[j * n_fft + k]; + for (; k < n_fft_bins; k++) { + sum += fft_out[k] * filters.data[j * n_fft_bins + k]; } - sum = log10(std::max(sum, 1e-10)); - mel.data[j * mel.n_len + i] = sum; + sum = params.use_natural_log + ? log(sum + 5.960464477539063e-08) + : log10(std::max(sum, 1e-10)); + out.data[j * out.n_len + i] = sum; } } // Otherwise fft_out are all zero - double sum = log10(1e-10); - for (; i < mel.n_len; i += n_threads) { - for (int j = 0; j < mel.n_mel; j++) { - mel.data[j * mel.n_len + i] = sum; + double sum = params.use_natural_log ? log(1e-10) : log10(1e-10); + for (; i < out.n_len; i += n_threads) { + for (int j = 0; j < out.n_mel; j++) { + out.data[j * out.n_len + i] = sum; } } } @@ -191,115 +295,212 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 static bool log_mel_spectrogram( const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int frame_size, - const int frame_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool debug, - whisper_mel & mel) { + const int n_samples_in, + const int n_threads, + const filter_params & params, + mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); - // Hann window - WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); - const float * hann = global_cache.hann_window; + out.n_len_org = n_samples_in; + int n_samples = n_samples_in; - // Calculate the length of padding - int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; - int64_t stage_2_pad = frame_size / 2; + // Hann window + const float * hann = g_cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; - // Initialize a vector and copy data from C array to it. + // Padding std::vector samples_padded; - samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); - std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + if (params.center_padding) { + const auto pad_amount = frame_size / 2; + samples_padded = std::vector(n_samples + 2 * pad_amount, 0); + std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount); + samples = samples_padded.data(); + n_samples = samples_padded.size(); + } else { + // existing padding logic + int64_t stage_1_pad = params.sample_rate * 30; + int64_t stage_2_pad = frame_size / 2; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // reflective pad 200 samples at the beginning of audio + if (n_samples < stage_2_pad + 1) { + // TODO: Handle short audio differently or return error + return false; + } + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + } - // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio - std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // preemphasis + if (params.preemph) { + const int pad_amount = frame_size / 2; + const float preemph = 0.97f; + float prev = samples_padded[pad_amount]; + for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { + float cur = samples_padded[i]; + samples_padded[i] = cur - preemph * prev; + prev = cur; + } + } + + // pad hann window if it's smaller than frame_size + // TODO: probably unnecessary here? (or better doing it in g_cache?) + std::vector hann_window_padded; + if (params.hann_window_size < frame_size) { + hann_window_padded.resize(frame_size); + const int padding = (frame_size - params.hann_window_size) / 2; + std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]); + hann = hann_window_padded.data(); + } - // reflective pad 200 samples at the beginning of audio - std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); - mel.n_mel = n_mel; - // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 - // Calculate number of frames + remove the last frame - mel.n_len = (samples_padded.size() - frame_size) / frame_step; - // Calculate semi-padded sample length to ensure compatibility - mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; - mel.data.resize(mel.n_mel * mel.n_len); + out.n_mel = params.n_mel; + out.n_len = (n_samples - frame_size) / frame_step + 1; + // TODO: handle these checks better + if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) { + LOG_ERR("%s: size overflow\n", __func__); + return false; + } + if (n_samples < frame_size) { + LOG_ERR("%s: not enough samples after padding\n", __func__); + return false; + } + out.data.resize(out.n_mel * out.n_len); { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples + stage_2_pad, frame_size, frame_step, n_threads, - std::cref(filters), std::ref(mel)); + n_samples, frame_size, frame_step, n_threads, + std::cref(params), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); - + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } } - // clamping and normalization - double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { - mmax = mel.data[i]; - } - } + const int effective_n_len = n_samples_in / frame_step; + if (params.norm_per_feature) { + for (int i = 0; i < out.n_mel; i++) { + double mean = 0; + for (int j = 0; j < effective_n_len; ++j) { + mean += out.data[i * out.n_len + j]; + } + mean /= effective_n_len; - mmax -= 8.0; + double var = 0.0; + for (int j = 0; j < effective_n_len; ++j) { + const double value = out.data[i * out.n_len + j] - mean; + var += value * value; + } + var /= effective_n_len - 1; // unbiased + const double mstd = std::sqrt(var + 1e-5); + + for (int j = 0; j < effective_n_len; ++j) { + auto &value = out.data[i * out.n_len + j]; + value = (value - mean) / mstd; + } - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { - mel.data[i] = mmax; + // pad the rest with zeros + for (int j = effective_n_len; j < out.n_len; ++j) { + out.data[i * out.n_len + j] = 0.0; + } + } + } else { + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] > mmax) { + mmax = out.data[i]; + } } - mel.data[i] = (mel.data[i] + 4.0)/4.0; + mmax -= 8.0; + + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] < mmax) { + out.data[i] = mmax; + } + out.data[i] = (out.data[i] + 4.0)/4.0; + } } // Dump log_mel_spectrogram - if (debug) { + if (DEBUG) { std::ofstream outFile("log_mel_spectrogram.json"); outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { - outFile << mel.data[i] << ", "; + for (uint64_t i = 0; i < out.data.size() - 1; i++) { + outFile << out.data[i] << ", "; } - outFile << mel.data[mel.data.size() - 1] << "]"; + outFile << out.data[out.data.size() - 1] << "]"; outFile.close(); } return true; } -bool preprocess_audio( +// +// mtmd_audio_preprocessor_whisper +// + +void mtmd_audio_preprocessor_whisper::initialize() { + g_cache.fill_sin_cos_table(hparams.audio_n_fft); + g_cache.fill_hann_window(hparams.audio_window_len, true); + g_cache.fill_mel_filterbank_matrix( + hparams.n_mel_bins, + hparams.audio_n_fft, + hparams.audio_sample_rate); +} + +bool mtmd_audio_preprocessor_whisper::preprocess( const float * samples, size_t n_samples, - const whisper_filters & filters, - std::vector & output) { - + std::vector & output) { if (n_samples == 0) { // empty audio return false; } - whisper_mel out_full; + std::vector smpl; + // if input is too short, pad with zeros + // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram + // TODO: maybe handle this better + size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + if (n_samples < min_samples) { + smpl.resize(min_samples, 0.0f); + std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); + samples = smpl.data(); + n_samples = smpl.size(); + } + + filter_params params; + params.n_mel = hparams.n_mel_bins; + params.n_fft_bins = 1 + (hparams.audio_n_fft / 2); + params.hann_window_size = hparams.audio_window_len; + params.hop_length = hparams.audio_hop_len; + params.sample_rate = hparams.audio_sample_rate; + params.center_padding = false; + params.preemph = 0.0f; // disabled + params.use_natural_log = false; + params.norm_per_feature = false; + + // make sure the global cache is initialized + GGML_ASSERT(!g_cache.sin_vals.empty()); + GGML_ASSERT(!g_cache.cos_vals.empty()); + GGML_ASSERT(!g_cache.filters.data.empty()); + + mtmd_audio_mel out_full; bool ok = log_mel_spectrogram( samples, n_samples, - COMMON_SAMPLE_RATE, - WHISPER_N_FFT, - WHISPER_HOP_LENGTH, - filters.n_mel, 4, // n_threads - filters, - false, // debug + params, out_full); if (!ok) { return false; @@ -307,7 +508,9 @@ bool preprocess_audio( // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel // we always expect the mel to have 3000 silent frames at the end - // printf("n_len %d\n", out_full.n_len); + if (DEBUG) { + printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); + } const size_t frames_per_chunk = 3000; GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { @@ -316,7 +519,7 @@ bool preprocess_audio( break; // last uncomplete chunk will always be a padded chunk, safe to ignore } - whisper_mel out_chunk; + mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; out_chunk.n_len_org = out_full.n_mel; // unused @@ -332,438 +535,3 @@ bool preprocess_audio( return true; } - -} // namespace whisper_preprocessor - - -// precalculated mel filter banks -// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function -// -// generated from python code: -// -// from numpy import load -// data = load('mel_filters.npz') -// lst = data.files -// for item in lst: -// print(item) -// print(data[item].shape) -// n_mel = data[item].shape[0] -// n_fft = data[item].shape[1] -// for i, row in enumerate(data[item]): -// for j, val in enumerate(row): -// val = val * 1000.0 -// if val != 0: -// print(f"data[{i*n_fft + j}] = {val:.6f};") - -namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins() { - whisper_preprocessor::whisper_filters filters; - filters.n_mel = 128; - filters.n_fft = 201; - std::vector data(filters.n_mel * filters.n_fft, 0.0f); - - data[1] = 12.37398665; - data[202] = 30.39256483; - data[404] = 24.74797331; - data[605] = 18.01857911; - data[807] = 37.12195903; - data[1008] = 5.64459199; - data[1009] = 6.72939420; - data[1210] = 36.03715822; - data[1412] = 19.10337992; - data[1613] = 23.66316877; - data[1815] = 31.47736564; - data[2016] = 11.28918398; - data[2017] = 1.08480197; - data[2218] = 41.68175161; - data[2420] = 13.45878839; - data[2621] = 29.30776216; - data[2823] = 25.83277412; - data[3024] = 16.93377644; - data[3226] = 38.20675984; - data[3427] = 4.55979025; - data[3428] = 7.81419594; - data[3629] = 34.95235741; - data[3831] = 20.18818259; - data[4032] = 22.57836796; - data[4234] = 32.56217018; - data[4435] = 10.20438317; - data[4436] = 2.16960395; - data[4637] = 40.59694707; - data[4839] = 14.54358920; - data[5040] = 28.22295949; - data[5242] = 26.91757679; - data[5443] = 15.84897563; - data[5645] = 39.29156065; - data[5846] = 3.47498828; - data[5847] = 8.89899861; - data[6048] = 33.86755288; - data[6250] = 21.27298526; - data[6451] = 21.49356715; - data[6653] = 33.64697099; - data[6854] = 9.11958050; - data[6855] = 3.25440569; - data[7056] = 39.51214626; - data[7258] = 15.62839188; - data[7459] = 27.13815868; - data[7661] = 28.00237760; - data[7862] = 14.76417296; - data[8064] = 40.37636518; - data[8265] = 2.38068704; - data[8266] = 10.20263787; - data[8467] = 31.61146119; - data[8669] = 24.54700135; - data[8870] = 15.32919332; - data[8871] = 1.66583748; - data[9072] = 36.72905266; - data[9274] = 20.09709924; - data[9475] = 16.93102531; - data[9476] = 2.90265540; - data[9677] = 32.84499049; - data[9879] = 23.52004871; - data[10080] = 11.03894413; - data[10081] = 10.72582975; - data[10282] = 22.71829173; - data[10484] = 32.27872774; - data[10685] = 0.11626833; - data[10686] = 22.85348251; - data[10887] = 8.56344029; - data[10888] = 14.97978810; - data[11089] = 15.51398356; - data[11090] = 8.51490628; - data[11291] = 21.10680379; - data[11292] = 3.32652032; - data[11493] = 25.47064796; - data[11695] = 27.35907957; - data[11896] = 0.65853616; - data[11897] = 23.83812517; - data[12098] = 3.44359246; - data[12099] = 21.22455277; - data[12300] = 5.35842171; - data[12301] = 19.42555793; - data[12502] = 6.49324711; - data[12503] = 18.35542172; - data[12704] = 6.93138083; - data[12705] = 17.93504693; - data[12906] = 6.74968259; - data[12907] = 18.09151843; - data[13108] = 6.01899112; - data[13109] = 18.75767298; - data[13310] = 4.80452832; - data[13311] = 19.87172849; - data[13512] = 3.16627859; - data[13513] = 21.37690969; - data[13514] = 1.25317345; - data[13714] = 1.15934468; - data[13715] = 20.80361731; - data[13716] = 4.04486805; - data[13917] = 17.55363122; - data[13918] = 7.08320038; - data[14119] = 14.07538634; - data[14120] = 10.32655034; - data[14321] = 10.40921453; - data[14322] = 13.73696327; - data[14523] = 6.59187697; - data[14524] = 17.27988198; - data[14525] = 1.46804214; - data[14725] = 2.65681883; - data[14726] = 18.09193194; - data[14727] = 5.85655728; - data[14928] = 13.34277913; - data[14929] = 10.28267574; - data[15130] = 8.56800377; - data[15131] = 14.72230814; - data[15132] = 1.04039861; - data[15332] = 3.79085587; - data[15333] = 17.14678481; - data[15334] = 6.11609267; - data[15535] = 11.75929047; - data[15536] = 11.13393717; - data[15737] = 6.43857848; - data[15738] = 16.07806236; - data[15739] = 4.23917221; - data[15939] = 1.19989377; - data[15940] = 12.75671553; - data[15941] = 9.65298992; - data[16142] = 7.06935255; - data[16143] = 14.94054683; - data[16144] = 4.19024844; - data[16344] = 1.51483389; - data[16345] = 12.00899947; - data[16346] = 9.84823331; - data[16547] = 6.10224018; - data[16548] = 15.33857174; - data[16549] = 5.57676842; - data[16749] = 0.36827257; - data[16750] = 9.89749376; - data[16751] = 11.35340426; - data[16752] = 2.05122307; - data[16952] = 3.89297144; - data[16953] = 12.97352277; - data[16954] = 8.06631614; - data[17155] = 6.74493238; - data[17156] = 13.85874674; - data[17157] = 5.41190524; - data[17357] = 0.74220158; - data[17358] = 8.98779090; - data[17359] = 11.37871388; - data[17360] = 3.32958088; - data[17560] = 2.82313535; - data[17561] = 10.68049297; - data[17562] = 9.43340641; - data[17563] = 1.76325557; - data[17763] = 4.39018616; - data[17764] = 11.87758986; - data[17765] = 7.97005836; - data[17766] = 0.66104700; - data[17966] = 5.49466675; - data[17967] = 12.62953598; - data[17968] = 6.93987962; - data[18169] = 6.18401915; - data[18170] = 12.93473132; - data[18171] = 6.29778765; - data[18371] = 0.02325210; - data[18372] = 6.50206627; - data[18373] = 12.32661773; - data[18374] = 6.00216538; - data[18574] = 0.31548753; - data[18575] = 6.48925547; - data[18576] = 12.04130240; - data[18577] = 6.01462880; - data[18777] = 0.29979556; - data[18778] = 6.18288014; - data[18779] = 12.04272825; - data[18780] = 6.29981188; - data[18781] = 0.55689598; - data[18980] = 0.01120471; - data[18981] = 5.61729167; - data[18982] = 11.22337859; - data[18983] = 6.82516303; - data[18984] = 1.35264499; - data[19184] = 4.82410006; - data[19185] = 10.16623247; - data[19186] = 7.56075513; - data[19187] = 2.34590308; - data[19387] = 3.83235747; - data[19388] = 8.92296247; - data[19389] = 8.47910438; - data[19390] = 3.50978645; - data[19590] = 2.66873185; - data[19591] = 7.51965167; - data[19592] = 9.55500547; - data[19593] = 4.81966138; - data[19594] = 0.08431751; - data[19793] = 1.35767367; - data[19794] = 5.98019501; - data[19795] = 10.60271543; - data[19796] = 6.25298498; - data[19797] = 1.74059917; - data[19997] = 4.32644226; - data[19998] = 8.73131864; - data[19999] = 7.78916525; - data[20000] = 3.48923868; - data[20200] = 2.57835095; - data[20201] = 6.77582854; - data[20202] = 9.40941647; - data[20203] = 5.31194592; - data[20204] = 1.21447595; - data[20403] = 0.75411191; - data[20404] = 4.75395704; - data[20405] = 8.75380263; - data[20406] = 7.19209015; - data[20407] = 3.28754401; - data[20607] = 2.68179690; - data[20608] = 6.49331464; - data[20609] = 9.11457930; - data[20610] = 5.39387390; - data[20611] = 1.67316827; - data[20810] = 0.57394296; - data[20811] = 4.20600036; - data[20812] = 7.83805829; - data[20813] = 7.52023002; - data[20814] = 3.97470826; - data[20815] = 0.42918732; - data[21014] = 1.90464477; - data[21015] = 5.36569161; - data[21016] = 8.82673822; - data[21017] = 6.27609482; - data[21018] = 2.89750961; - data[21218] = 2.89885257; - data[21219] = 6.19694078; - data[21220] = 8.56699049; - data[21221] = 5.34748193; - data[21222] = 2.12797290; - data[21421] = 0.44750227; - data[21422] = 3.59030394; - data[21423] = 6.73310598; - data[21424] = 7.77023612; - data[21425] = 4.70231380; - data[21426] = 1.63439126; - data[21625] = 1.01536023; - data[21626] = 4.01018746; - data[21627] = 7.00501446; - data[21628] = 7.23442994; - data[21629] = 4.31095669; - data[21630] = 1.38748321; - data[21829] = 1.33348850; - data[21830] = 4.18730825; - data[21831] = 7.04112789; - data[21832] = 6.93188375; - data[21833] = 4.14605811; - data[21834] = 1.36023236; - data[22033] = 1.42879714; - data[22034] = 4.14824858; - data[22035] = 6.86769979; - data[22036] = 6.83705276; - data[22037] = 4.18239459; - data[22038] = 1.52773573; - data[22237] = 1.32610439; - data[22238] = 3.91751388; - data[22239] = 6.50892360; - data[22240] = 6.92639686; - data[22241] = 4.39672917; - data[22242] = 1.86706171; - data[22441] = 1.04827771; - data[22442] = 3.51767405; - data[22443] = 5.98707050; - data[22444] = 7.17824046; - data[22445] = 4.76767914; - data[22446] = 2.35711760; - data[22645] = 0.61636406; - data[22646] = 2.96949223; - data[22647] = 5.32262027; - data[22648] = 7.57265091; - data[22649] = 5.27558755; - data[22650] = 2.97852419; - data[22651] = 0.68146095; - data[22849] = 0.04971400; - data[22850] = 2.29204819; - data[22851] = 4.53438237; - data[22852] = 6.77671656; - data[22853] = 5.90240723; - data[22854] = 3.71349836; - data[22855] = 1.52458926; - data[23054] = 1.50285335; - data[23055] = 3.63961048; - data[23056] = 5.77636715; - data[23057] = 6.63159089; - data[23058] = 4.54574358; - data[23059] = 2.45989650; - data[23060] = 0.37404924; - data[23258] = 0.61795861; - data[23259] = 2.65410915; - data[23260] = 4.69025923; - data[23261] = 6.72641024; - data[23262] = 5.46034705; - data[23263] = 3.47270933; - data[23264] = 1.48507138; - data[23463] = 1.59233576; - data[23464] = 3.53261665; - data[23465] = 5.47289755; - data[23466] = 6.44368259; - data[23467] = 4.54962999; - data[23468] = 2.65557761; - data[23469] = 0.76152512; - data[23667] = 0.46749352; - data[23668] = 2.31641904; - data[23669] = 4.16534441; - data[23670] = 6.01426978; - data[23671] = 5.67844696; - data[23672] = 3.87357362; - data[23673] = 2.06870004; - data[23674] = 0.26382666; - data[23872] = 1.05349103; - data[23873] = 2.81536230; - data[23874] = 4.57723346; - data[23875] = 6.33910485; - data[23876] = 5.12815686; - data[23877] = 3.40826320; - data[23878] = 1.68837002; - data[24077] = 1.43350090; - data[24078] = 3.11241671; - data[24079] = 4.79133241; - data[24080] = 6.40943693; - data[24081] = 4.77052201; - data[24082] = 3.13160778; - data[24083] = 1.49269309; - data[24281] = 0.02932359; - data[24282] = 1.62918994; - data[24283] = 3.22905602; - data[24284] = 4.82892245; - data[24285] = 6.14671456; - data[24286] = 4.58496623; - data[24287] = 3.02321767; - data[24288] = 1.46146910; - data[24486] = 0.13601698; - data[24487] = 1.66055572; - data[24488] = 3.18509457; - data[24489] = 4.70963307; - data[24490] = 6.04072399; - data[24491] = 4.55250870; - data[24492] = 3.06429295; - data[24493] = 1.57607743; - data[24494] = 0.08786193; - data[24691] = 0.09328097; - data[24692] = 1.54603878; - data[24693] = 2.99879676; - data[24694] = 4.45155473; - data[24695] = 5.90431225; - data[24696] = 4.65566106; - data[24697] = 3.23751615; - data[24698] = 1.81937125; - data[24699] = 0.40122634; - data[24897] = 1.30262633; - data[24898] = 2.68698297; - data[24899] = 4.07133950; - data[24900] = 5.45569602; - data[24901] = 4.87832492; - data[24902] = 3.52695142; - data[24903] = 2.17557792; - data[24904] = 0.82420459; - data[25102] = 0.94595028; - data[25103] = 2.26512621; - data[25104] = 3.58430226; - data[25105] = 4.90347855; - data[25106] = 5.20569785; - data[25107] = 3.91795207; - data[25108] = 2.63020652; - data[25109] = 1.34246063; - data[25110] = 0.05471494; - data[25307] = 0.49037894; - data[25308] = 1.74744334; - data[25309] = 3.00450763; - data[25310] = 4.26157191; - data[25311] = 5.51863620; - data[25312] = 4.39707236; - data[25313] = 3.16995848; - data[25314] = 1.94284460; - data[25315] = 0.71573065; - data[25513] = 1.14698056; - data[25514] = 2.34485767; - data[25515] = 3.54273478; - data[25516] = 4.74061165; - data[25517] = 4.95198462; - data[25518] = 3.78264743; - data[25519] = 2.61331047; - data[25520] = 1.44397374; - data[25521] = 0.27463681; - data[25718] = 0.47569509; - data[25719] = 1.61717169; - data[25720] = 2.75864848; - data[25721] = 3.90012516; - data[25722] = 5.04160160; - data[25723] = 4.45712078; - data[25724] = 3.34284059; - data[25725] = 2.22856039; - data[25726] = 1.11428020; - - for (auto & val : data) { - val /= 1000.0f; - } - - filters.data = std::move(data); - return filters; -} - -} // namespace whisper_precalc_filters diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index 0e552347a0a..1b454337cbe 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "clip-model.h" #include #include @@ -8,18 +9,7 @@ #define MTMD_INTERNAL_HEADER -#define WHISPER_ASSERT GGML_ASSERT - -#define WHISPER_SAMPLE_RATE 16000 -#define WHISPER_N_FFT 400 -#define WHISPER_HOP_LENGTH 160 -#define WHISPER_CHUNK_SIZE 30 - -#define COMMON_SAMPLE_RATE 16000 - -namespace whisper_preprocessor { - -struct whisper_mel { +struct mtmd_audio_mel { int n_len; int n_len_org; int n_mel; @@ -27,23 +17,18 @@ struct whisper_mel { std::vector data; }; -struct whisper_filters { - int32_t n_mel; - int32_t n_fft; +struct mtmd_audio_preprocessor { + const clip_hparams & hparams; - std::vector data; -}; + mtmd_audio_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {} -bool preprocess_audio( - const float * samples, - size_t n_samples, - const whisper_filters & filters, - std::vector & output); - -} // namespace whisper_preprocessor - -namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins(); + virtual ~mtmd_audio_preprocessor() = default; + virtual void initialize() = 0; // NOT thread-safe + virtual bool preprocess(const float * samples, size_t n_samples, std::vector & output) = 0; +}; -} // namespace whisper_precalc_filters +struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { + mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} + void initialize() override; + bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; +}; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d06fa42e616..c63f299cd90 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -151,8 +151,7 @@ struct mtmd_context { // string template for slice image delimiters with row/col (idefics3) std::string sli_img_start_tmpl; - // for whisper, we pre-calculate the mel filter bank - whisper_preprocessor::whisper_filters w_filters; + std::unique_ptr audio_preproc; // TODO @ngxson : add timings @@ -317,14 +316,25 @@ struct mtmd_context { GGML_ASSERT(ctx_a != nullptr); projector_type proj = clip_get_projector_type(ctx_a); - if (clip_has_whisper_encoder(ctx_a)) { - // TODO @ngxson : check if model n_mel is 128 or 80 - w_filters = whisper_precalc_filters::get_128_bins(); - } - LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n" " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__); + // set preprocessor + switch (proj) { + case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN25O: + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: + audio_preproc = std::make_unique(ctx_a); + break; + default: + GGML_ABORT("unsupported audio projector type"); + } + + // initialize audio preprocessor + audio_preproc->initialize(); + + // set special tokens if (proj == PROJECTOR_TYPE_QWEN2A) { // <|audio_bos|> ... (embeddings) ... <|audio_eos|> aud_beg = "<|audio_bos|>"; @@ -653,11 +663,10 @@ struct mtmd_tokenizer { } // preprocess audio - GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded - std::vector mel_spec_chunks; + std::vector mel_spec_chunks; const float * samples = (const float *)bitmap->data.data(); size_t n_samples = bitmap->data.size() / sizeof(float); - bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks); + bool ok = ctx->audio_preproc->preprocess(samples, n_samples, mel_spec_chunks); if (!ok) { LOG_ERR("Unable to preprocess audio\n"); return 2; @@ -863,8 +872,7 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } - // for now, we assume that all audio models have the same bitrate - return 16000; // 16kHz + return clip_get_hparams(ctx->ctx_a)->audio_sample_rate; } //