From 4c877dc71e6245ec30bcc3807f1a4f34d1bc2bb1 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:19:16 +0800 Subject: [PATCH 1/4] fix #8298 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/merger.py | 21 +++++++++++++++++---- tests/test_zarr_avg_merger.py | 10 +++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index d01d334142..583cf68a70 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -15,12 +15,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import nullcontext +from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any import numpy as np import torch -from monai.utils import ensure_tuple_size, optional_import, require_pkg +from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq if TYPE_CHECKING: import zarr @@ -233,7 +234,7 @@ def __init__( store: zarr.storage.Store | str = "merged.zarr", value_store: zarr.storage.Store | str | None = None, count_store: zarr.storage.Store | str | None = None, - compressor: str = "default", + compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, chunks: Sequence[int] | bool = True, @@ -246,8 +247,20 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store - self.value_store = zarr.storage.TempStore() if value_store is None else value_store - self.count_store = zarr.storage.TempStore() if count_store is None else count_store + if version_geq(get_package_version("zarr"), "3.0.0"): + if value_store is None: + with TemporaryDirectory() as tmpdir: + self.value_store = zarr.storage.LocalStore(tmpdir) + else: + self.value_store = value_store + if count_store is None: + with TemporaryDirectory() as tmpdir: + self.count_store = zarr.storage.LocalStore(tmpdir) + else: + self.count_store = count_store + else: + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor self.value_compressor = value_compressor diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index a52dbceb4c..dba046f9ed 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -289,13 +289,17 @@ class ZarrAvgMergerTests(unittest.TestCase): def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): if "compressor" in arguments: if arguments["compressor"] != "default": - arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + arguments["compressor"] = numcodecs.registry.codec_registry[arguments["compressor"].lower()]() if "value_compressor" in arguments: if arguments["value_compressor"] != "default": - arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + arguments["value_compressor"] = numcodecs.registry.codec_registry[ + arguments["value_compressor"].lower() + ]() if "count_compressor" in arguments: if arguments["count_compressor"] != "default": - arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() + arguments["count_compressor"] = numcodecs.registry.codec_registry[ + arguments["count_compressor"].lower() + ]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1]) From 0355bfbc0f8eee01bba7816c41652dc968f4b4cc Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:19:57 +0800 Subject: [PATCH 2/4] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/merger.py | 8 ++++---- tests/test_zarr_avg_merger.py | 11 ++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 583cf68a70..e10ad4d406 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -249,13 +249,13 @@ def __init__( self.store = store if version_geq(get_package_version("zarr"), "3.0.0"): if value_store is None: - with TemporaryDirectory() as tmpdir: - self.value_store = zarr.storage.LocalStore(tmpdir) + tmpdir = TemporaryDirectory() + self.value_store = zarr.storage.LocalStore(tmpdir.name) else: self.value_store = value_store if count_store is None: - with TemporaryDirectory() as tmpdir: - self.count_store = zarr.storage.LocalStore(tmpdir) + tmpdir = TemporaryDirectory() + self.count_store = zarr.storage.LocalStore(tmpdir.name) else: self.count_store = count_store else: diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index dba046f9ed..3c89e4fb03 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -287,19 +287,16 @@ class ZarrAvgMergerTests(unittest.TestCase): ] ) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + codec_reg = numcodecs.registry.codec_registry if "compressor" in arguments: if arguments["compressor"] != "default": - arguments["compressor"] = numcodecs.registry.codec_registry[arguments["compressor"].lower()]() + arguments["compressor"] = codec_reg[arguments["compressor"].lower()]() if "value_compressor" in arguments: if arguments["value_compressor"] != "default": - arguments["value_compressor"] = numcodecs.registry.codec_registry[ - arguments["value_compressor"].lower() - ]() + arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]() if "count_compressor" in arguments: if arguments["count_compressor"] != "default": - arguments["count_compressor"] = numcodecs.registry.codec_registry[ - arguments["count_compressor"].lower() - ]() + arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1]) From de054da46cff46e18124aa9fc7f3e95d181dc008 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:48:30 +0800 Subject: [PATCH 3/4] Update monai/inferers/merger.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/merger.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index e10ad4d406..70b14fb0d3 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -249,16 +249,17 @@ def __init__( self.store = store if version_geq(get_package_version("zarr"), "3.0.0"): if value_store is None: - tmpdir = TemporaryDirectory() - self.value_store = zarr.storage.LocalStore(tmpdir.name) + self.tmpdir = TemporaryDirectory() + self.value_store = zarr.storage.LocalStore(self.tmpdir.name) else: self.value_store = value_store if count_store is None: - tmpdir = TemporaryDirectory() - self.count_store = zarr.storage.LocalStore(tmpdir.name) + self.tmpdir = TemporaryDirectory() + self.count_store = zarr.storage.LocalStore(self.tmpdir.name) else: self.count_store = count_store else: + self.tmpdir = None self.value_store = zarr.storage.TempStore() if value_store is None else value_store self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks From 17eead372bf26ee3f786f533e578a4220241e5b3 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 22:05:45 +0800 Subject: [PATCH 4/4] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/merger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 70b14fb0d3..1344207e18 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -247,6 +247,7 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store + self.tmpdir: TemporaryDirectory | None if version_geq(get_package_version("zarr"), "3.0.0"): if value_store is None: self.tmpdir = TemporaryDirectory()