diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index cf3c5df2..890cda37 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -136,8 +136,12 @@ class GenericSerializer { } template F> - static lib::lazy_result_t - load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) { + static lib::lazy_result_t load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is, + const F& lazy + ) { auto datatype = lib::load_at(table, "eltype"); if (datatype != datatype_v) { throw ANNEXCEPTION( @@ -151,6 +155,10 @@ class GenericSerializer { size_t num_vectors = lib::load_at(table, "num_vectors"); size_t dims = lib::load_at(table, "dims"); + deserializer.read_name(is); + deserializer.read_size(is); + deserializer.read_binary(is); + return io::load_dataset(is, lazy, num_vectors, dims); } }; @@ -474,13 +482,14 @@ class SimpleData { static SimpleData load( const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, std::istream& is, const allocator_type& allocator = {} ) requires(!is_view) { return GenericSerializer::load( - table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { + table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); @@ -879,11 +888,15 @@ class SimpleData> { static SimpleData load( const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, std::istream& is, const Blocked& allocator = {} ) { return GenericSerializer::load( - table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { + table, + deserializer, + is, + lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index a3c4bca3..1b1fde9c 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -324,27 +324,36 @@ class IDTranslator { "external_to_internal_translation"; static constexpr lib::Version save_version = lib::Version(0, 0, 0); - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto filename = ctx.generate_name("id_translation", "binary"); - // Save the translations to a file. - auto stream = lib::open_write(filename); - for (auto i = begin(), iend = end(); i != iend; ++i) { - // N.B.: Apparently `std::pair` of integers is not trivially copyable ... - lib::write_binary(stream, i->first); - lib::write_binary(stream, i->second); - } + lib::SaveTable save_table() const { return lib::SaveTable( serialization_schema, save_version, {{"kind", kind}, {"num_points", lib::save(size())}, {"external_id_type", lib::save(datatype_v)}, - {"internal_id_type", lib::save(datatype_v)}, - {"filename", lib::save(filename.filename())}} + {"internal_id_type", lib::save(datatype_v)}} ); } - static IDTranslator load(const lib::LoadTable& table) { + void save(std::ostream& os) const { + for (auto i = begin(), iend = end(); i != iend; ++i) { + // N.B.: Apparently `std::pair` of integers is not trivially copyable ... + lib::write_binary(os, i->first); + lib::write_binary(os, i->second); + } + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto filename = ctx.generate_name("id_translation", "binary"); + // Save the translations to a file. + auto os = lib::open_write(filename); + save(os); + auto table = save_table(); + table.insert("filename", lib::save(filename.filename())); + return table; + } + + static void validate(const lib::ContextFreeLoadTable& table) { if (kind != lib::load_at(table, "kind")) { throw ANNEXCEPTION("Mismatched kind!"); } @@ -357,21 +366,42 @@ class IDTranslator { if (internal_id_name != lib::load_at(table, "internal_id_type")) { throw ANNEXCEPTION("Mismatched internal id types!"); } + } - // Now that we've more-or-less validated the metadata, time to start loading - // the points. + static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) { auto num_points = lib::load_at(table, "num_points"); + auto translator = IDTranslator{}; - auto resolved = table.resolve_at("filename"); - auto stream = lib::open_read(resolved); for (size_t i = 0; i < num_points; ++i) { - auto external_id = lib::read_binary(stream); - auto internal_id = lib::read_binary(stream); + auto external_id = lib::read_binary(is); + auto internal_id = lib::read_binary(is); translator.insert_translation(external_id, internal_id); } return translator; } + static IDTranslator load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is + ) { + IDTranslator::validate(table); + deserializer.read_name(is); + deserializer.read_size(is); + + return IDTranslator::load(table, is); + } + + static IDTranslator load(const lib::LoadTable& table) { + IDTranslator::validate(table); + + // Now that we've more-or-less validated the metadata, time to start loading + // the points. + auto resolved = table.resolve_at("filename"); + auto is = lib::open_read(resolved); + return IDTranslator::load(table, is); + } + private: template void check( diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 868054ba..26946a22 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -36,6 +36,7 @@ #include "svs/lib/invoke.h" #include "svs/lib/misc.h" #include "svs/lib/preprocessor.h" +#include "svs/lib/stream.h" #include "svs/lib/threads.h" namespace svs::index::flat { @@ -403,6 +404,26 @@ template class DynamicFlatIndex { // Save the dataset in the separate data directory lib::save_to_disk(data_, data_directory); } + + void save(std::ostream& os) { + compact(); + + lib::begin_serialization(os); + // Save data structures and translation to config directory + lib::SaveTable save_table = lib::SaveTable( + "dynamic_flat_config", + save_version, + { + {"name", name()}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + lib::save_to_stream(data_, os); + } + constexpr std::string_view name() const { return "dynamic flat index"; } ///// Thread Pool Management @@ -767,4 +788,67 @@ auto auto_dynamic_assemble( ); } +auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) { + auto table = lib::detail::begin_deserialization(deserializer, is); + auto translator = IDTranslator::load( + table.template cast().at("translation").template cast(), + deserializer, + is + ); + return translator; +} + +template +auto auto_dynamic_assemble( + const lib::detail::Deserializer& deserializer, + std::istream& is, + LazyDataLoader&& data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + // Set this to `true` to use the identity map for ID translation. + // This allows us to read files generated by the static index construction routines + // to easily benchmark the static versus dynamic implementation. + // + // This is an internal API and should not be considered officially supported nor stable. + bool SVS_UNUSED(debug_load_from_static) = false, + svs::logging::logger_ptr logger = svs::logging::get() +) { + IDTranslator translator; + // In legacy deserialization the order of directories isn't determined. + auto name = deserializer.read_name_in_advance(is); + + // We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy + // model was serialized + bool translator_before_data = + (name == "config/svs_config.toml") || deserializer.is_native(); + if (translator_before_data) { + translator = load_translator(deserializer, is); + } + + // Load the dataset + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto data = svs::detail::dispatch_load(data_loader(), threadpool); + auto datasize = data.size(); + + if (!translator_before_data) { + translator = load_translator(deserializer, is); + } + + // Validate the translator + auto translator_size = translator.size(); + if (translator_size != datasize) { + throw ANNEXCEPTION( + "Translator has {} IDs but should have {}", translator_size, datasize + ); + } + + return DynamicFlatIndex( + std::move(data), + std::move(translator), + std::move(distance), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::flat diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index d81e7cf5..925fe68a 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -523,7 +523,10 @@ class FlatIndex { lib::save_to_disk(data_, data_directory); } - void save(std::ostream& os) const { lib::save_to_stream(data_, os); } + void save(std::ostream& os) const { + lib::begin_serialization(os); + lib::save_to_stream(data_, os); + } }; /// diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index c848e80c..09e072de 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -833,39 +833,78 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp std::move(table), lib::LoadContext{fullpath.parent_path(), version}}; } -inline ContextFreeSerializedObject begin_deserialization(std::istream& stream) { - lib::StreamArchiver::size_type magic = 0; - lib::StreamArchiver::read_size(stream, magic); - if (magic == lib::DirectoryArchiver::magic_number) { - // Backward compatibility mode for older versions: - // Previously, SVS serialized models using an intermediate file, - // so some dummy information was added to the stream. - lib::StreamArchiver::size_type num_files = 0; - lib::StreamArchiver::read_size(stream, num_files); - - std::string file_name; - lib::StreamArchiver::read_name(stream, file_name); - } else if (magic != lib::StreamArchiver::magic_number) { - throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); +class Deserializer { + enum SerializationScheme { native, legacy }; + SerializationScheme scheme_; + + mutable bool skip_next_name_ = false; + + explicit Deserializer(const SerializationScheme& scheme) + : scheme_(scheme) {} + + public: + static Deserializer build(std::istream& stream) { + lib::StreamArchiver::size_type magic = 0; + lib::StreamArchiver::read_size(stream, magic); + if (magic == lib::StreamArchiver::magic_number) { + return Deserializer(SerializationScheme::native); + } else if (magic == lib::DirectoryArchiver::magic_number) { + // Backward compatibility mode for older versions: + // Previously, SVS serialized models using an intermediate file, + // so some dummy information was added to the stream. + lib::StreamArchiver::size_type num_files = 0; + lib::StreamArchiver::read_size(stream, num_files); + + return Deserializer(SerializationScheme::legacy); + } else { + throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); + } + } + + bool is_native() const { return scheme_ == SerializationScheme::native; } + + std::string read_name_in_advance(std::istream& stream) const { + std::string name; + if (scheme_ == SerializationScheme::legacy) { + lib::StreamArchiver::read_name(stream, name); + skip_next_name_ = true; + } + return name; + } + + void read_name(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + if (!skip_next_name_) { + std::string name; + lib::StreamArchiver::read_name(stream, name); + } + skip_next_name_ = false; + } } + void read_size(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + lib::StreamArchiver::size_type size = 0; + lib::StreamArchiver::read_size(stream, size); + } + } + + template void read_binary(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + lib::read_binary(stream); + } + } +}; + +inline ContextFreeSerializedObject +begin_deserialization(const Deserializer& deserializer, std::istream& stream) { + deserializer.read_name(stream); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); } auto table = lib::StreamArchiver::read_table(stream); - if (magic == lib::DirectoryArchiver::magic_number) { - // Backward compatibility mode for older versions: - // Previously, SVS serialized models using an intermediate file, - // so some dummy information was added to the stream. - std::string file_name; - lib::StreamArchiver::read_name(stream, file_name); - - lib::StreamArchiver::size_type file_size = 0; - lib::StreamArchiver::read_size(stream, file_size); - lib::read_binary(stream); - } return ContextFreeSerializedObject{std::move(table)}; } @@ -920,17 +959,28 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { ///// load_from_stream template -T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) { +T load_from_stream( + const Loader& loader, + const detail::Deserializer& deserializer, + std::istream& stream, + Args&&... args +) { // At this point, we will try the saving/loading framework to load the object. // Here we go! return lib::load( - loader, detail::begin_deserialization(stream), stream, SVS_FWD(args)... + loader, + detail::begin_deserialization(deserializer, stream), + deserializer, + stream, + SVS_FWD(args)... ); } template -T load_from_stream(std::istream& stream, Args&&... args) { - return lib::load_from_stream(Loader(), stream, SVS_FWD(args)...); +T load_from_stream( + const detail::Deserializer& deserializer, std::istream& stream, Args&&... args +) { + return lib::load_from_stream(Loader(), deserializer, stream, SVS_FWD(args)...); } ///// load_from_file diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index fe7694c4..c609151e 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -377,13 +377,20 @@ template void save_to_file(const T& x, const std::filesystem::path& detail::save_node_to_file(lib::save(x), path); } -template void save_to_stream(const T& x, std::ostream& os) { +inline void begin_serialization(std::ostream& os) { lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); +} - auto save_table = x.save_table(); - detail::save_node_to_stream(detail::exit_hook(save_table), os); - - x.save(os); +template void save_to_stream(const T& x, std::ostream& os) { + if constexpr (requires { x.save_table(); }) { + auto save_table = x.save_table(); + detail::save_node_to_stream(detail::exit_hook(save_table), os); + x.save(os); + } else if constexpr (std::is_same_v) { + detail::save_node_to_stream(detail::exit_hook(x), os); + } else { + static_assert(sizeof(T) == 0, "Type not stream-serializable"); + } } } // namespace svs::lib diff --git a/include/svs/orchestrators/dynamic_flat.h b/include/svs/orchestrators/dynamic_flat.h index e06efb45..250edf10 100644 --- a/include/svs/orchestrators/dynamic_flat.h +++ b/include/svs/orchestrators/dynamic_flat.h @@ -123,13 +123,7 @@ class DynamicFlatImpl // Stream-based save implementation void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current DynamicFlat backend doesn't support saving!"); } @@ -282,28 +276,22 @@ class DynamicFlat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_dynflat_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION( - "Invalid Dynamic Flat index archive: missing config directory!" - ); - } - - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Dynamic Flat index archive: missing data directory!" - ); - } - - return assemble( - config_path, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) + auto deserializer = svs::lib::detail::Deserializer::build(stream); + return DynamicFlat( + AssembleTag(), + manager::as_typelist(), + index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ) ); } diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h index 7fa969ea..bf46c134 100644 --- a/include/svs/orchestrators/exhaustive.h +++ b/include/svs/orchestrators/exhaustive.h @@ -195,8 +195,9 @@ class Flat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { + auto deserializer = svs::lib::detail::Deserializer::build(stream); return assemble( - lib::load_from_stream(stream, SVS_FWD(data_args)...), + lib::load_from_stream(deserializer, stream, SVS_FWD(data_args)...), distance, threads::as_threadpool(std::move(threadpool_proto)) ); diff --git a/tests/svs/index/flat/dynamic_flat.cpp b/tests/svs/index/flat/dynamic_flat.cpp index f9f99e70..76c4d4c6 100644 --- a/tests/svs/index/flat/dynamic_flat.cpp +++ b/tests/svs/index/flat/dynamic_flat.cpp @@ -214,3 +214,143 @@ CATCH_TEST_CASE("Testing Flat Index", "[dynamic_flat]") { test_loop(index, reference, queries, div(reference.size(), modify_fraction), 2, 6); } + +CATCH_TEST_CASE("DynamicFlat Index Save and Load", "[dynamic_flat][index][saveload]") { +#if defined(NDEBUG) + const float initial_fraction = 0.25; + const float modify_fraction = 0.05; +#else + const float initial_fraction = 0.05; + const float modify_fraction = 0.005; +#endif + const size_t num_threads = 10; + + // Load the base dataset and queries. + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto num_points = data.size(); + auto queries = test_dataset::queries(); + + auto reference = svs::misc::ReferenceDataset( + std::move(data), + Distance(), + num_threads, + div(num_points, 0.5 * modify_fraction), + NUM_NEIGHBORS, + queries, + 0x12345678 + ); + + auto num_indices_to_add = div(reference.size(), initial_fraction); + + // Construct a blocked dataset consisting of initial fraction of the base dataset. + auto data_mutable = svs::data::BlockedData(num_indices_to_add, N); + std::vector initial_indices{}; + { + auto [vectors, indices] = reference.generate(num_indices_to_add); + // Copy assign ``initial_indices`` + auto num_points_added = indices.size(); + CATCH_REQUIRE(vectors.size() == num_points_added); + CATCH_REQUIRE(num_points_added <= num_indices_to_add); + CATCH_REQUIRE(num_points_added > num_indices_to_add - reference.bucket_size()); + + initial_indices = indices; + if (vectors.size() != num_indices_to_add || indices.size() != num_indices_to_add) { + throw ANNEXCEPTION("Something when horribly wrong!"); + } + + for (size_t i = 0; i < num_indices_to_add; ++i) { + data_mutable.set_datum(i, vectors.get_datum(i)); + } + } + + using Data_t = svs::data::BlockedData; + using Distance_t = svs::distance::DistanceL2; + using Index_t = svs::index::flat::DynamicFlatIndex; + + Distance_t dist; + auto index = Index_t(std::move(data_mutable), initial_indices, dist, num_threads); + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), {}); + + reference.configure_extra_checks(true); + CATCH_REQUIRE(reference.extra_checks_enabled()); + + CATCH_SECTION("Load DynamicFlat being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + dist, + svs::threads::as_threadpool(num_threads) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load DynamicFlat being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + dist, + svs::threads::as_threadpool(num_threads) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index d09532d8..29879b66 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -93,8 +93,11 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { std::stringstream ss; index.save(ss); + auto deserializer = svs::lib::detail::Deserializer::build(ss); Index_t loaded_index = Index_t( - svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + svs::lib::load_from_stream(deserializer, ss), + dist, + svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); @@ -122,8 +125,11 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { index.save(tempdir); svs::lib::DirectoryArchiver::pack(tempdir, ss); + auto deserializer = svs::lib::detail::Deserializer::build(ss); Index_t loaded_index = Index_t( - svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + svs::lib::load_from_stream(deserializer, ss), + dist, + svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size());