Skip to content

Commit 11ab095

Browse files
authored
fix: resolve embedding loading issue when calling generate_image multiple times (#1078)
1 parent a3a88fc commit 11ab095

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

conditioner.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6060
int32_t num_custom_embeddings = 0;
6161
int32_t num_custom_embeddings_2 = 0;
6262
std::vector<uint8_t> token_embed_custom;
63-
std::vector<std::string> readed_embeddings;
63+
std::map<std::string, std::pair<int, int>> embedding_pos_map;
6464

6565
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
6666
bool offload_params_to_cpu,
@@ -123,14 +123,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
123123
}
124124

125125
bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
126-
// the order matters
127126
ModelLoader model_loader;
128127
if (!model_loader.init_from_file_and_convert_name(embd_path)) {
129128
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
130129
return false;
131130
}
132-
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
131+
auto iter = embedding_pos_map.find(embd_name);
132+
if (iter != embedding_pos_map.end()) {
133133
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
134+
for (int i = iter->second.first; i < iter->second.second; i++) {
135+
bpe_tokens.push_back(text_model->model.vocab_size + i);
136+
}
134137
return true;
135138
}
136139
struct ggml_init_params params;
@@ -161,7 +164,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
161164
return true;
162165
};
163166
model_loader.load_tensors(on_load, 1);
164-
readed_embeddings.push_back(embd_name);
167+
int pos_start = num_custom_embeddings;
165168
if (embd) {
166169
int64_t hidden_size = text_model->model.hidden_size;
167170
token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
@@ -188,6 +191,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
188191
}
189192
LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2);
190193
}
194+
int pos_end = num_custom_embeddings;
195+
if (pos_end == pos_start) {
196+
return false;
197+
}
198+
embedding_pos_map[embd_name] = std::pair{pos_start, pos_end};
191199
return true;
192200
}
193201

0 commit comments

Comments
 (0)