@@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6060 int32_t num_custom_embeddings = 0 ;
6161 int32_t num_custom_embeddings_2 = 0 ;
6262 std::vector<uint8_t > token_embed_custom;
63- std::vector <std::string> readed_embeddings ;
63+ std::map <std::string, std::pair< int , int >> embedding_pos_map ;
6464
6565 FrozenCLIPEmbedderWithCustomWords (ggml_backend_t backend,
6666 bool offload_params_to_cpu,
@@ -123,14 +123,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
123123 }
124124
125125 bool load_embedding (std::string embd_name, std::string embd_path, std::vector<int32_t >& bpe_tokens) {
126- // the order matters
127126 ModelLoader model_loader;
128127 if (!model_loader.init_from_file_and_convert_name (embd_path)) {
129128 LOG_ERROR (" embedding '%s' failed" , embd_name.c_str ());
130129 return false ;
131130 }
132- if (std::find (readed_embeddings.begin (), readed_embeddings.end (), embd_name) != readed_embeddings.end ()) {
131+ auto iter = embedding_pos_map.find (embd_name);
132+ if (iter != embedding_pos_map.end ()) {
133133 LOG_DEBUG (" embedding already read in: %s" , embd_name.c_str ());
134+ for (int i = iter->second .first ; i < iter->second .second ; i++) {
135+ bpe_tokens.push_back (text_model->model .vocab_size + i);
136+ }
134137 return true ;
135138 }
136139 struct ggml_init_params params;
@@ -161,7 +164,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
161164 return true ;
162165 };
163166 model_loader.load_tensors (on_load, 1 );
164- readed_embeddings. push_back (embd_name) ;
167+ int pos_start = num_custom_embeddings ;
165168 if (embd) {
166169 int64_t hidden_size = text_model->model .hidden_size ;
167170 token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd));
@@ -188,6 +191,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
188191 }
189192 LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i (text model 2)" , embd_name.c_str (), num_custom_embeddings_2);
190193 }
194+ int pos_end = num_custom_embeddings;
195+ if (pos_end == pos_start) {
196+ return false ;
197+ }
198+ embedding_pos_map[embd_name] = std::pair{pos_start, pos_end};
191199 return true ;
192200 }
193201
0 commit comments