Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions include/svs/core/data/simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ class GenericSerializer {
}

template <typename T, lib::LazyInvocable<size_t, size_t> F>
static lib::lazy_result_t<F, size_t, size_t>
load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) {
static lib::lazy_result_t<F, size_t, size_t> load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const F& lazy
) {
auto datatype = lib::load_at<DataType>(table, "eltype");
if (datatype != datatype_v<T>) {
throw ANNEXCEPTION(
Expand All @@ -151,6 +155,10 @@ class GenericSerializer {
size_t num_vectors = lib::load_at<size_t>(table, "num_vectors");
size_t dims = lib::load_at<size_t>(table, "dims");

deserializer.read_name(is);
deserializer.read_size(is);
deserializer.read_binary<io::v1::Header>(is);

return io::load_dataset(is, lazy, num_vectors, dims);
}
};
Expand Down Expand Up @@ -474,13 +482,14 @@ class SimpleData {

static SimpleData load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const allocator_type& allocator = {}
)
requires(!is_view)
{
return GenericSerializer::load<T>(
table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
Expand Down Expand Up @@ -879,11 +888,15 @@ class SimpleData<T, Extent, Blocked<Alloc>> {

static SimpleData load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const Blocked<Alloc>& allocator = {}
) {
return GenericSerializer::load<T>(
table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
table,
deserializer,
is,
lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
Expand Down
66 changes: 48 additions & 18 deletions include/svs/core/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,27 +324,36 @@ class IDTranslator {
"external_to_internal_translation";
static constexpr lib::Version save_version = lib::Version(0, 0, 0);

lib::SaveTable save(const lib::SaveContext& ctx) const {
auto filename = ctx.generate_name("id_translation", "binary");
// Save the translations to a file.
auto stream = lib::open_write(filename);
for (auto i = begin(), iend = end(); i != iend; ++i) {
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
lib::write_binary(stream, i->first);
lib::write_binary(stream, i->second);
}
lib::SaveTable save_table() const {
return lib::SaveTable(
serialization_schema,
save_version,
{{"kind", kind},
{"num_points", lib::save(size())},
{"external_id_type", lib::save(datatype_v<external_id_type>)},
{"internal_id_type", lib::save(datatype_v<internal_id_type>)},
{"filename", lib::save(filename.filename())}}
{"internal_id_type", lib::save(datatype_v<internal_id_type>)}}
);
}

static IDTranslator load(const lib::LoadTable& table) {
void save(std::ostream& os) const {
for (auto i = begin(), iend = end(); i != iend; ++i) {
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
lib::write_binary(os, i->first);
lib::write_binary(os, i->second);
}
}

lib::SaveTable save(const lib::SaveContext& ctx) const {
auto filename = ctx.generate_name("id_translation", "binary");
// Save the translations to a file.
auto os = lib::open_write(filename);
save(os);
auto table = save_table();
table.insert("filename", lib::save(filename.filename()));
return table;
}

static void validate(const lib::ContextFreeLoadTable& table) {
if (kind != lib::load_at<std::string>(table, "kind")) {
throw ANNEXCEPTION("Mismatched kind!");
}
Expand All @@ -357,21 +366,42 @@ class IDTranslator {
if (internal_id_name != lib::load_at<std::string>(table, "internal_id_type")) {
throw ANNEXCEPTION("Mismatched internal id types!");
}
}

// Now that we've more-or-less validated the metadata, time to start loading
// the points.
static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) {
auto num_points = lib::load_at<size_t>(table, "num_points");

auto translator = IDTranslator{};
auto resolved = table.resolve_at("filename");
auto stream = lib::open_read(resolved);
for (size_t i = 0; i < num_points; ++i) {
auto external_id = lib::read_binary<external_id_type>(stream);
auto internal_id = lib::read_binary<internal_id_type>(stream);
auto external_id = lib::read_binary<external_id_type>(is);
auto internal_id = lib::read_binary<internal_id_type>(is);
translator.insert_translation(external_id, internal_id);
}
return translator;
}

static IDTranslator load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is
) {
IDTranslator::validate(table);
deserializer.read_name(is);
deserializer.read_size(is);

return IDTranslator::load(table, is);
}

static IDTranslator load(const lib::LoadTable& table) {
IDTranslator::validate(table);

// Now that we've more-or-less validated the metadata, time to start loading
// the points.
auto resolved = table.resolve_at("filename");
auto is = lib::open_read(resolved);
return IDTranslator::load(table, is);
}

private:
template <class Begin, class End, class Map, class Modifier = lib::identity>
void check(
Expand Down
84 changes: 84 additions & 0 deletions include/svs/index/flat/dynamic_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "svs/lib/invoke.h"
#include "svs/lib/misc.h"
#include "svs/lib/preprocessor.h"
#include "svs/lib/stream.h"
#include "svs/lib/threads.h"

namespace svs::index::flat {
Expand Down Expand Up @@ -403,6 +404,26 @@ template <typename Data, typename Dist> class DynamicFlatIndex {
// Save the dataset in the separate data directory
lib::save_to_disk(data_, data_directory);
}

void save(std::ostream& os) {
compact();

lib::begin_serialization(os);
// Save data structures and translation to config directory
lib::SaveTable save_table = lib::SaveTable(
"dynamic_flat_config",
save_version,
{
{"name", name()},
{"translation", lib::detail::exit_hook(translator_.save_table())},
}
);
lib::save_to_stream(save_table, os);
translator_.save(os);

lib::save_to_stream(data_, os);
}

constexpr std::string_view name() const { return "dynamic flat index"; }

///// Thread Pool Management
Expand Down Expand Up @@ -767,4 +788,67 @@ auto auto_dynamic_assemble(
);
}

auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) {
auto table = lib::detail::begin_deserialization(deserializer, is);
auto translator = IDTranslator::load(
table.template cast<toml::table>().at("translation").template cast<toml::table>(),
deserializer,
is
);
return translator;
}

template <typename LazyDataLoader, typename Distance, typename ThreadPoolProto>
auto auto_dynamic_assemble(
const lib::detail::Deserializer& deserializer,
std::istream& is,
LazyDataLoader&& data_loader,
Distance distance,
ThreadPoolProto threadpool_proto,
// Set this to `true` to use the identity map for ID translation.
// This allows us to read files generated by the static index construction routines
// to easily benchmark the static versus dynamic implementation.
//
// This is an internal API and should not be considered officially supported nor stable.
bool SVS_UNUSED(debug_load_from_static) = false,
svs::logging::logger_ptr logger = svs::logging::get()
) {
IDTranslator translator;
// In legacy deserialization the order of directories isn't determined.
auto name = deserializer.read_name_in_advance(is);

// We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy
// model was serialized
bool translator_before_data =
(name == "config/svs_config.toml") || deserializer.is_native();
if (translator_before_data) {
translator = load_translator(deserializer, is);
}

// Load the dataset
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
auto data = svs::detail::dispatch_load(data_loader(), threadpool);
auto datasize = data.size();

if (!translator_before_data) {
translator = load_translator(deserializer, is);
}

// Validate the translator
auto translator_size = translator.size();
if (translator_size != datasize) {
throw ANNEXCEPTION(
"Translator has {} IDs but should have {}", translator_size, datasize
);
}

return DynamicFlatIndex(
std::move(data),
std::move(translator),
std::move(distance),
std::move(threadpool),
std::move(logger)
);
}

} // namespace svs::index::flat
5 changes: 4 additions & 1 deletion include/svs/index/flat/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ class FlatIndex {
lib::save_to_disk(data_, data_directory);
}

void save(std::ostream& os) const { lib::save_to_stream(data_, os); }
void save(std::ostream& os) const {
lib::begin_serialization(os);
lib::save_to_stream(data_, os);
}
};

///
Expand Down
Loading
Loading