diff --git a/docs/api_reference.md b/docs/api_reference.md index f819ea3eb..7e7473180 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -40,3 +40,10 @@ and .. automodule:: mdio.core.dimension :members: ``` + +## Optimization + +```{eval-rst} +.. automodule:: mdio.optimize.access_pattern + :members: +``` diff --git a/docs/tutorials/rechunking.ipynb b/docs/tutorials/rechunking.ipynb index f9d4b12da..aeaae43d9 100644 --- a/docs/tutorials/rechunking.ipynb +++ b/docs/tutorials/rechunking.ipynb @@ -17,257 +17,1325 @@ "## Introduction\n", "\n", "In this page we will be showing how we can take an existing MDIO and add\n", - "fast access, lossy, versions of the data in X/Y/Z cross-sections (slices).\n", + "fast access, lossy, versions of the data in IL/XL/TWT cross-sections (slices).\n", "\n", "We can re-use the MDIO dataset we created in the [Quickstart](#quickstart) page.\n", "Please run it first.\n", "\n", - "We will define our compression levels first. We will use this to adjust the quality\n", - "of the lossy compression." + "Let's open the original MDIO first." ] }, { "cell_type": "code", - "execution_count": 1, - "id": "initial_id", - "metadata": { - "ExecuteTime": { - "end_time": "2025-04-16T18:38:02.462276Z", - "start_time": "2025-04-16T18:38:02.459882Z" - } - }, - "outputs": [], - "source": [ - "from enum import Enum\n", - "\n", - "\n", - "class MdioZfpQuality(float, Enum):\n", - " \"\"\"Config options for ZFP compression.\"\"\"\n", - "\n", - " VERY_LOW = 6\n", - " LOW = 3\n", - " MEDIUM = 1\n", - " HIGH = 0.1\n", - " VERY_HIGH = 0.01\n", - " ULTRA = 0.001" - ] - }, - { - "cell_type": "markdown", - "id": "c2a09a89-b453-4c3e-b879-14caaedd29de", - "metadata": {}, - "source": [ - "We will use the lower level `MDIOAccessor` to open the existing file in write mode that\n", - "allows us to modify its raw metadata. This can be dangerous, we recommend using only provided\n", - "tools to avoid data corruption.\n", - "\n", - "We specify the original access pattern of the source data `\"012\"` with some parameters like\n", - "caching. For the rechunking, we recommend using the single threaded `\"zarr\"` backend to avoid\n", - "race conditions.\n", - "\n", - "We also define a `dict` for common arguments in rechunking." - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "45558306-ab9c-46aa-a299-8758a911b373", - "metadata": { - "ExecuteTime": { - "end_time": "2025-04-16T18:38:04.107696Z", - "start_time": "2025-04-16T18:38:04.101239Z" + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 403MB\n",
+       "Dimensions:           (inline: 345, crossline: 188, time: 1501)\n",
+       "Coordinates:\n",
+       "  * inline            (inline) int32 1kB 1 2 3 4 5 6 ... 340 341 342 343 344 345\n",
+       "  * crossline         (crossline) int32 752B 1 2 3 4 5 6 ... 184 185 186 187 188\n",
+       "  * time              (time) int32 6kB 0 2 4 6 8 10 ... 2992 2994 2996 2998 3000\n",
+       "    cdp_y             (inline, crossline) float64 519kB ...\n",
+       "    cdp_x             (inline, crossline) float64 519kB ...\n",
+       "Data variables:\n",
+       "    amplitude         (inline, crossline, time) float32 389MB ...\n",
+       "    headers           (inline, crossline) [('trace_seq_num_line', '<i4'), ('trace_seq_num_reel', '<i4'), ('orig_field_record_num', '<i4'), ('trace_num_orig_record', '<i4'), ('energy_source_point_num', '<i4'), ('ensemble_num', '<i4'), ('trace_num_ensemble', '<i4'), ('trace_id_code', '<i2'), ('vertically_summed_traces', '<i2'), ('horizontally_stacked_traces', '<i2'), ('data_use', '<i2'), ('source_to_receiver_distance', '<i4'), ('receiver_group_elevation', '<i4'), ('source_surface_elevation', '<i4'), ('source_depth_below_surface', '<i4'), ('receiver_datum_elevation', '<i4'), ('source_datum_elevation', '<i4'), ('source_water_depth', '<i4'), ('receiver_water_depth', '<i4'), ('elevation_depth_scalar', '<i2'), ('coordinate_scalar', '<i2'), ('source_coord_x', '<i4'), ('source_coord_y', '<i4'), ('group_coord_x', '<i4'), ('group_coord_y', '<i4'), ('coordinate_unit', '<i2'), ('weathering_velocity', '<i2'), ('subweathering_velocity', '<i2'), ('source_uphole_time', '<i2'), ('group_uphole_time', '<i2'), ('source_static_correction', '<i2'), ('receiver_static_correction', '<i2'), ('total_static_applied', '<i2'), ('lag_time_a', '<i2'), ('lag_time_b', '<i2'), ('delay_recording_time', '<i2'), ('mute_time_start', '<i2'), ('mute_time_end', '<i2'), ('samples_per_trace', '<i2'), ('sample_interval', '<i2'), ('instrument_gain_type', '<i2'), ('instrument_gain_const', '<i2'), ('instrument_gain_initial', '<i2'), ('correlated_data', '<i2'), ('sweep_freq_start', '<i2'), ('sweep_freq_end', '<i2'), ('sweep_length', '<i2'), ('sweep_type', '<i2'), ('sweep_taper_start', '<i2'), ('sweep_taper_end', '<i2'), ('taper_type', '<i2'), ('alias_filter_freq', '<i2'), ('alias_filter_slope', '<i2'), ('notch_filter_freq', '<i2'), ('notch_filter_slope', '<i2'), ('low_cut_freq', '<i2'), ('high_cut_freq', '<i2'), ('low_cut_slope', '<i2'), ('high_cut_slope', '<i2'), ('year_recorded', '<i2'), ('day_of_year', '<i2'), ('hour_of_day', '<i2'), ('minute_of_hour', '<i2'), ('second_of_minute', '<i2'), ('time_basis_code', '<i2'), ('trace_weighting_factor', '<i2'), ('group_num_roll_switch', '<i2'), ('group_num_first_trace', '<i2'), ('group_num_last_trace', '<i2'), ('gap_size', '<i2'), ('taper_overtravel', '<i2'), ('inline', '<i4'), ('crossline', '<i4'), ('cdp_x', '<i4'), ('cdp_y', '<i4')] 13MB ...\n",
+       "    segy_file_header  <U1 4B ...\n",
+       "    trace_mask        (inline, crossline) bool 65kB ...\n",
+       "Attributes:\n",
+       "    apiVersion:  1.1.1\n",
+       "    createdOn:   2025-12-19 16:05:58.230520+00:00\n",
+       "    name:        PostStack3DTime\n",
+       "    attributes:  {'surveyType': '3D', 'gatherType': 'stacked', 'defaultVariab...
" + ], + "text/plain": [ + " Size: 403MB\n", + "Dimensions: (inline: 345, crossline: 188, time: 1501)\n", + "Coordinates:\n", + " * inline (inline) int32 1kB 1 2 3 4 5 6 ... 340 341 342 343 344 345\n", + " * crossline (crossline) int32 752B 1 2 3 4 5 6 ... 184 185 186 187 188\n", + " * time (time) int32 6kB 0 2 4 6 8 10 ... 2992 2994 2996 2998 3000\n", + " cdp_y (inline, crossline) float64 519kB ...\n", + " cdp_x (inline, crossline) float64 519kB ...\n", + "Data variables:\n", + " amplitude (inline, crossline, time) float32 389MB ...\n", + " headers (inline, crossline) [('trace_seq_num_line', '\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 2GB\n",
+       "Dimensions:           (inline: 345, crossline: 188, time: 1501)\n",
+       "Coordinates:\n",
+       "  * inline            (inline) int32 1kB 1 2 3 4 5 6 ... 340 341 342 343 344 345\n",
+       "  * crossline         (crossline) int32 752B 1 2 3 4 5 6 ... 184 185 186 187 188\n",
+       "  * time              (time) int32 6kB 0 2 4 6 8 10 ... 2992 2994 2996 2998 3000\n",
+       "    cdp_x             (inline, crossline) float64 519kB ...\n",
+       "    cdp_y             (inline, crossline) float64 519kB ...\n",
+       "Data variables:\n",
+       "    segy_file_header  <U1 4B ...\n",
+       "    trace_mask        (inline, crossline) bool 65kB ...\n",
+       "    amplitude         (inline, crossline, time) float32 389MB ...\n",
+       "    headers           (inline, crossline) [('trace_seq_num_line', '<i4'), ('trace_seq_num_reel', '<i4'), ('orig_field_record_num', '<i4'), ('trace_num_orig_record', '<i4'), ('energy_source_point_num', '<i4'), ('ensemble_num', '<i4'), ('trace_num_ensemble', '<i4'), ('trace_id_code', '<i2'), ('vertically_summed_traces', '<i2'), ('horizontally_stacked_traces', '<i2'), ('data_use', '<i2'), ('source_to_receiver_distance', '<i4'), ('receiver_group_elevation', '<i4'), ('source_surface_elevation', '<i4'), ('source_depth_below_surface', '<i4'), ('receiver_datum_elevation', '<i4'), ('source_datum_elevation', '<i4'), ('source_water_depth', '<i4'), ('receiver_water_depth', '<i4'), ('elevation_depth_scalar', '<i2'), ('coordinate_scalar', '<i2'), ('source_coord_x', '<i4'), ('source_coord_y', '<i4'), ('group_coord_x', '<i4'), ('group_coord_y', '<i4'), ('coordinate_unit', '<i2'), ('weathering_velocity', '<i2'), ('subweathering_velocity', '<i2'), ('source_uphole_time', '<i2'), ('group_uphole_time', '<i2'), ('source_static_correction', '<i2'), ('receiver_static_correction', '<i2'), ('total_static_applied', '<i2'), ('lag_time_a', '<i2'), ('lag_time_b', '<i2'), ('delay_recording_time', '<i2'), ('mute_time_start', '<i2'), ('mute_time_end', '<i2'), ('samples_per_trace', '<i2'), ('sample_interval', '<i2'), ('instrument_gain_type', '<i2'), ('instrument_gain_const', '<i2'), ('instrument_gain_initial', '<i2'), ('correlated_data', '<i2'), ('sweep_freq_start', '<i2'), ('sweep_freq_end', '<i2'), ('sweep_length', '<i2'), ('sweep_type', '<i2'), ('sweep_taper_start', '<i2'), ('sweep_taper_end', '<i2'), ('taper_type', '<i2'), ('alias_filter_freq', '<i2'), ('alias_filter_slope', '<i2'), ('notch_filter_freq', '<i2'), ('notch_filter_slope', '<i2'), ('low_cut_freq', '<i2'), ('high_cut_freq', '<i2'), ('low_cut_slope', '<i2'), ('high_cut_slope', '<i2'), ('year_recorded', '<i2'), ('day_of_year', '<i2'), ('hour_of_day', '<i2'), ('minute_of_hour', '<i2'), ('second_of_minute', '<i2'), ('time_basis_code', '<i2'), ('trace_weighting_factor', '<i2'), ('group_num_roll_switch', '<i2'), ('group_num_first_trace', '<i2'), ('group_num_last_trace', '<i2'), ('gap_size', '<i2'), ('taper_overtravel', '<i2'), ('inline', '<i4'), ('crossline', '<i4'), ('cdp_x', '<i4'), ('cdp_y', '<i4')] 13MB ...\n",
+       "    fast_crossline    (inline, crossline, time) float32 389MB ...\n",
+       "    fast_inline       (inline, crossline, time) float32 389MB ...\n",
+       "    fast_time         (inline, crossline, time) float32 389MB ...\n",
+       "Attributes:\n",
+       "    apiVersion:  1.1.1\n",
+       "    createdOn:   2025-12-19 16:05:58.230520+00:00\n",
+       "    name:        PostStack3DTime\n",
+       "    attributes:  {'surveyType': '3D', 'gatherType': 'stacked', 'defaultVariab...
" + ], + "text/plain": [ + " Size: 2GB\n", + "Dimensions: (inline: 345, crossline: 188, time: 1501)\n", + "Coordinates:\n", + " * inline (inline) int32 1kB 1 2 3 4 5 6 ... 340 341 342 343 344 345\n", + " * crossline (crossline) int32 752B 1 2 3 4 5 6 ... 184 185 186 187 188\n", + " * time (time) int32 6kB 0 2 4 6 8 10 ... 2992 2994 2996 2998 3000\n", + " cdp_x (inline, crossline) float64 519kB ...\n", + " cdp_y (inline, crossline) float64 519kB ...\n", + "Data variables:\n", + " segy_file_header " + "
" ] }, "metadata": {}, @@ -402,159 +1450,116 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "from matplotlib.axes import Axes\n", - "from matplotlib.image import AxesImage\n", - "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", - "from numpy.typing import NDArray\n", "\n", + "from mdio.builder.schemas.v1.stats import SummaryStatistics\n", + "\n", + "stats = SummaryStatistics.model_validate_json(ds.amplitude.attrs[\"statsV1\"])\n", "imshow_kw = {\n", - " \"vmin\": -3 * std,\n", - " \"vmax\": 3 * std,\n", + " \"vmin\": -3 * stats.std,\n", + " \"vmax\": 3 * stats.std,\n", " \"cmap\": \"gray_r\",\n", " \"interpolation\": \"bilinear\",\n", - " \"aspect\": \"auto\",\n", + " \"yincrease\": False,\n", + " \"add_colorbar\": False,\n", "}\n", "\n", + "fig, ax = plt.subplots(1, 4, sharex=\"all\", sharey=\"all\", figsize=(8, 5))\n", "\n", - "def attach_colorbar(image: AxesImage, axis: Axes) -> None:\n", - " \"\"\"Attach a colorbar to an axis.\"\"\"\n", - " divider = make_axes_locatable(axis)\n", - " cax = divider.append_axes(\"top\", size=\"2%\", pad=0.05)\n", - " plt.colorbar(image, cax=cax, orientation=\"horizontal\")\n", - " cax.xaxis.set_ticks_position(\"top\")\n", - " cax.tick_params(labelsize=8)\n", - "\n", - "\n", - "def plot_image_and_cbar(data: NDArray, axis: Axes, title: str) -> None:\n", - " \"\"\"Plot an image with a colorbar.\"\"\"\n", - " image = axis.imshow(data.T, **imshow_kw)\n", - " attach_colorbar(image, axis)\n", - " axis.set_title(title, y=-0.15)\n", + "ds_inline = ds.sel(inline=200)\n", "\n", + "ds_inline.amplitude.T.plot.imshow(ax=ax[0], **imshow_kw)\n", + "ds_inline.fast_inline.T.plot.imshow(ax=ax[1], **imshow_kw)\n", "\n", - "def plot_inlines_with_diff(orig: NDArray, compressed: NDArray, title: str) -> None:\n", - " \"\"\"Plot lossless and lossy inline with their differences.\"\"\"\n", - " fig, ax = plt.subplots(1, 4, sharey=\"all\", sharex=\"all\", figsize=(8, 5))\n", + "diff = ds_inline.amplitude - ds_inline.fast_inline\n", + "diff.T.plot.imshow(ax=ax[2], **imshow_kw)\n", + "(1000 * diff).T.plot.imshow(ax=ax[3], **imshow_kw)\n", "\n", - " diff = orig[200] - compressed[200]\n", + "for axis, title in zip(ax.ravel(), [\"original\", \"lossy\", \"difference\", \"1,000xdifference\"], strict=False):\n", + " if title != \"original\":\n", + " axis.set_ylabel(\"\")\n", + " axis.set_title(title)\n", "\n", - " plot_image_and_cbar(orig[200], ax[0], \"original\")\n", - " plot_image_and_cbar(compressed[200], ax[1], \"lossy\")\n", - " plot_image_and_cbar(diff, ax[2], \"difference\")\n", - " plot_image_and_cbar(diff * 1_000, ax[3], \"1,000x difference\")\n", - "\n", - " plt.suptitle(f\"{title} ({std=})\")\n", - " fig.tight_layout()\n", - "\n", - " plt.show()\n", - "\n", - "\n", - "plot_inlines_with_diff(orig_mdio, il_mdio, \"\")" - ] - }, - { - "cell_type": "markdown", - "id": "2900c40b-c332-4334-a4cc-f0e5571c7387", - "metadata": {}, - "source": [ - "In conclusion, we show that by generating optimized, lossy compressed copies of the data\n", - "for certain access patterns yield big performance benefits when reading the data.\n", - "\n", - "The differences are orders of magnitude larger on big datasets and remote stores, given available\n", - "network bandwidth." 
+    "fig.tight_layout();"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "38f7d35e-d743-4bb3-a7a6-fa77aab08f00",
+   "id": "220399c2-d0a3-48cc-89a3-2594af073f73",
    "metadata": {},
    "source": [
-    "## Optimizing in Batch\n",
-    "\n",
-    "Now that we understand how rechunking and lossy compression works, we will demonstrate how\n",
-    "to do this in batches.\n",
+    "## Adjusting the Compressor\n",
     "\n",
-    "The benefit of doing the batched processing is that the dataset gets read once. This is\n",
-    "especially important if the original MDIO resides in a remote store like AWS S3, or Google\n",
-    "Cloud's GCS.\n",
+    "The compressor can be modified for the fast access patterns, but the default setting usually works quite well.\n",
+    "Given a 1:10 compression ratio, the fidelity is quite high with the default `ZfpQuality.LOW` setting.\n",
     "\n",
-    "Note that we not are overwriting the old optimized chunks, just creating new ones with the\n",
-    "suffix 2 to demonstrate we can create as many version of the original data as we want."
+    "If you still want to use ZFP compression but change the quality settings, follow the instructions below.\n",
+    "We can also use the `Blosc` compressor available in MDIO, but we will not demonstrate that here."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
-   "id": "14c58331-0ce3-42fa-9cd0-a43574ce24bf",
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2025-04-16T18:39:19.725778Z",
-     "start_time": "2025-04-16T18:39:15.546003Z"
-    }
-   },
+   "execution_count": null,
+   "id": "877160c9-9bf3-47b9-92ca-1e3dd87584e2",
+   "metadata": {},
    "outputs": [
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Rechunking to fast_il2,fast_xl2,fast_z2: 100%|██████████| 3/3 [00:03<00:00, 1.29s/chunk]\n"
-     ]
+     "data": {
+      "text/plain": [
+       "ZFP(name='zfp', mode=, tolerance=0.09305394453239418, rate=None, precision=None)"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
-    "from mdio.api.convenience import rechunk_batch\n",
+    "from mdio.optimize import ZfpQuality\n",
+    "from mdio.optimize import get_default_zfp\n",
     "\n",
-    "rechunk_batch(\n",
-    "    writer,\n",
-    "    chunks_list=[(4, 512, 512), (512, 4, 512), (512, 512, 4)],\n",
-    "    suffix_list=[\"fast_il2\", \"fast_xl2\", \"fast_z2\"],\n",
-    "    **common_kwargs,\n",
-    ")"
+    "get_default_zfp(stats, ZfpQuality.HIGH)"
    ]
  },
  {
   "cell_type": "markdown",
   "id": "48a0ece3-2ff4-41f6-9867-6296c733e7e9",
   "metadata": {},
   "source": [
+    "Here is a `MEDIUM` example. Note that the tolerance changes because it is based on the dataset statistics and the compression quality setting."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
-   "id": "a3f47a17-7537-4fc9-b3ab-b57badae18d1",
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2025-04-16T18:39:33.576164Z",
-     "start_time": "2025-04-16T18:39:33.559671Z"
-    }
-   },
+   "execution_count": null,
+   "id": "255713b5-988a-431f-a171-846bba87b228",
+   "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "(64, 64, 64) (Blosc(cname='zstd', clevel=5, shuffle=SHUFFLE, blocksize=0),)\n",
-      "(4, 188, 512) (ZFPY(mode=4, tolerance=2.791618335971825, rate=-1, precision=-1),)\n",
-      "(345, 4, 512) (ZFPY(mode=4, tolerance=2.791618335971825, rate=-1, precision=-1),)\n",
-      "(345, 188, 4) (ZFPY(mode=4, tolerance=2.791618335971825, rate=-1, precision=-1),)\n"
-     ]
+     "data": {
+      "text/plain": [
+       "ZFP(name='zfp', mode=, tolerance=0.9305394453239417, rate=None, precision=None)"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
-    "from mdio import MDIOReader\n",
-    "\n",
-    "orig_mdio = MDIOReader(mdio_path)\n",
-    "il2_mdio = MDIOReader(mdio_path, access_pattern=\"fast_il2\")\n",
-    "xl2_mdio = MDIOReader(mdio_path, access_pattern=\"fast_xl2\")\n",
-    "z2_mdio = MDIOReader(mdio_path, access_pattern=\"fast_z2\")\n",
-    "\n",
-    "print(orig_mdio.chunks, orig_mdio._traces.compressors)\n",
-    "print(il_mdio.chunks, il2_mdio._traces.compressors)\n",
-    "print(xl_mdio.chunks, xl2_mdio._traces.compressors)\n",
-    "print(z_mdio.chunks, z2_mdio._traces.compressors)"
+    "get_default_zfp(stats, ZfpQuality.MEDIUM)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "008950d1-b142-4ca8-9879-f926007c97ca",
+   "cell_type": "markdown",
+   "id": "2900c40b-c332-4334-a4cc-f0e5571c7387",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "In conclusion, we show that generating optimized, lossy compressed copies of the data\n",
+    "for certain access patterns yields big performance benefits when reading the data.\n",
+    "\n",
+    "The differences are orders of magnitude larger on big datasets and remote stores, given available\n",
+    "network bandwidth."
+ ] } ], "metadata": { @@ -572,8 +1577,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.1" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/noxfile.py b/noxfile.py index 64e715e4b..cbb0e96cc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -169,7 +169,9 @@ def mypy(session: Session) -> None: def tests(session: Session) -> None: """Run the test suite.""" session_install_uv(session) - session_install_uv_package(session, ["coverage[toml]", "pytest", "pygments", "pytest-dependency", "s3fs"]) + session_install_uv_package( + session, ["coverage[toml]", "pytest", "pygments", "pytest-dependency", "s3fs", "distributed", "zfpy"] + ) try: session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs) diff --git a/src/mdio/__init__.py b/src/mdio/__init__.py index 5fed389c8..857fb8064 100644 --- a/src/mdio/__init__.py +++ b/src/mdio/__init__.py @@ -1,11 +1,15 @@ """MDIO library.""" +from __future__ import annotations + from importlib import metadata from mdio.api.io import open_mdio from mdio.api.io import to_mdio from mdio.converters import mdio_to_segy from mdio.converters import segy_to_mdio +from mdio.optimize.access_pattern import OptimizedAccessPatternConfig +from mdio.optimize.access_pattern import optimize_access_patterns try: __version__ = metadata.version("multidimio") @@ -19,4 +23,6 @@ "to_mdio", "mdio_to_segy", "segy_to_mdio", + "OptimizedAccessPatternConfig", + "optimize_access_patterns", ] diff --git a/src/mdio/api/io.py b/src/mdio/api/io.py index 2654be315..ea5b22dc3 100644 --- a/src/mdio/api/io.py +++ b/src/mdio/api/io.py @@ -13,6 +13,7 @@ from xarray.backends.writers import to_zarr as xr_to_zarr from mdio.constants import ZarrFormat +from mdio.core.zarr_io import zarr_warnings_suppress_unstable_numcodecs_v3 from mdio.core.zarr_io import zarr_warnings_suppress_unstable_structs_v3 if TYPE_CHECKING: @@ -53,13 +54,14 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dat storage_options = _normalize_storage_options(input_path) zarr_format = zarr.config.get("default_zarr_format") - return xr_open_zarr( - input_path.as_posix(), - chunks=chunks, - storage_options=storage_options, - mask_and_scale=zarr_format == ZarrFormat.V3, # off for v2, on for v3 - consolidated=zarr_format == ZarrFormat.V2, # on for v2, off for v3 - ) + with zarr_warnings_suppress_unstable_numcodecs_v3(): + return xr_open_zarr( + input_path.as_posix(), + chunks=chunks, + storage_options=storage_options, + mask_and_scale=zarr_format == ZarrFormat.V3, # off for v2, on for v3 + consolidated=zarr_format == ZarrFormat.V2, # on for v2, off for v3 + ) def to_mdio( # noqa: PLR0913 @@ -90,7 +92,7 @@ def to_mdio( # noqa: PLR0913 storage_options = _normalize_storage_options(output_path) zarr_format = zarr.config.get("default_zarr_format") - with zarr_warnings_suppress_unstable_structs_v3(): + with zarr_warnings_suppress_unstable_structs_v3(), zarr_warnings_suppress_unstable_numcodecs_v3(): xr_to_zarr( dataset, store=output_path.as_posix(), # xarray doesn't like URI when file:// is protocol diff --git a/src/mdio/builder/schemas/v1/stats.py b/src/mdio/builder/schemas/v1/stats.py index ccd12fc79..f4ec7d0b8 100644 --- a/src/mdio/builder/schemas/v1/stats.py +++ b/src/mdio/builder/schemas/v1/stats.py @@ -56,3 +56,18 @@ class SummaryStatistics(CamelCaseStrictModel): min: float = Field(..., description="The smallest value in the variable.") max: float = Field(..., description="The largest value in the 
variable.") histogram: Histogram = Field(..., description="Binned frequency distribution.") + + @property + def mean(self) -> float: + """Returns the mean of the data.""" + return self.sum / self.count + + @property + def variance(self) -> float: + """Returns the variance of the data.""" + return (self.sum_squares / self.count) - (self.mean**2) + + @property + def std(self) -> float: + """Returns the standard deviation of the data.""" + return self.variance**0.5 diff --git a/src/mdio/builder/xarray_builder.py b/src/mdio/builder/xarray_builder.py index 58501cba9..58e1f5834 100644 --- a/src/mdio/builder/xarray_builder.py +++ b/src/mdio/builder/xarray_builder.py @@ -1,24 +1,15 @@ """Convert MDIO v1 schema Dataset to Xarray DataSet and write it in Zarr.""" +import numcodecs import numpy as np import zarr from dask import array as dask_array from dask.array.core import normalize_chunks -from numcodecs import Blosc +from numcodecs import Blosc as numcodecs_Blosc from xarray import DataArray as xr_DataArray from xarray import Dataset as xr_Dataset -from zarr.codecs import BloscCodec - -from mdio.converters.type_converter import to_numpy_dtype - -try: - # zfpy is an optional dependency for ZFP compression - # It is not installed by default, so we check for its presence and import it only if available. - from numcodecs import ZFPY as zfpy_ZFPY # noqa: N811 - from zarr.codecs.numcodecs import ZFPY as zarr_ZFPY # noqa: N811 -except ImportError: - zfpy_ZFPY = None # noqa: N816 - zarr_ZFPY = None # noqa: N816 +from zarr.codecs import BloscCodec as zarr_BloscCodec +from zarr.codecs.numcodecs import ZFPY as zarr_ZFPY # noqa: N811 from mdio.builder.schemas.compressors import ZFP as mdio_ZFP # noqa: N811 from mdio.builder.schemas.compressors import Blosc as mdio_Blosc @@ -30,6 +21,15 @@ from mdio.builder.schemas.v1.variable import Variable from mdio.constants import ZarrFormat from mdio.constants import fill_value_map +from mdio.converters.type_converter import to_numpy_dtype +from mdio.core.zarr_io import zarr_warnings_suppress_unstable_numcodecs_v3 + + +def _import_numcodecs_zfpy() -> "type[numcodecs.ZFPY]": + """Helper to import the optional dependency at runtime.""" + from numcodecs import ZFPY as numcodecs_ZFPY # noqa: PLC0415, N811 + + return numcodecs_ZFPY def _get_all_named_dimensions(dataset: Dataset) -> dict[str, NamedDimension]: @@ -125,7 +125,7 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) - def _compressor_to_encoding( compressor: mdio_Blosc | mdio_ZFP | None, -) -> dict[str, BloscCodec | Blosc | zfpy_ZFPY | zarr_ZFPY] | None: +) -> dict[str, "zarr.codecs.Blosc | numcodecs.Blosc | numcodecs.ZFPY | zarr.codecs.ZFPY | None"] | None: """Convert a compressor to a numcodecs compatible format.""" if compressor is None: return None @@ -140,17 +140,22 @@ def _compressor_to_encoding( if isinstance(compressor, mdio_Blosc): if is_v2 and kwargs["shuffle"] is None: kwargs["shuffle"] = -1 - codec_cls = Blosc if is_v2 else BloscCodec + codec_cls = numcodecs_Blosc if is_v2 else zarr_BloscCodec return {"compressors": codec_cls(**kwargs)} # must be ZFP beyond here - if zfpy_ZFPY is None: - msg = "zfpy and numcodecs are required to use ZFP compression" - raise ImportError(msg) + try: + numcodecs_ZFPY = _import_numcodecs_zfpy() # noqa: N806 + except ImportError as e: + msg = "The 'zfpy' package is required for lossy compression. Install via 'pip install multidimio[lossy]'." 
+ raise ImportError(msg) from e + kwargs["mode"] = compressor.mode.int_code if is_v2: - return {"compressors": zfpy_ZFPY(**kwargs)} - return {"serializer": zarr_ZFPY(**kwargs), "compressors": None} + return {"compressors": numcodecs_ZFPY(**kwargs)} + with zarr_warnings_suppress_unstable_numcodecs_v3(): + serializer = zarr_ZFPY(**kwargs) + return {"serializer": serializer, "compressors": None} def _get_fill_value(data_type: ScalarType | StructuredType | str) -> any: diff --git a/src/mdio/core/zarr_io.py b/src/mdio/core/zarr_io.py index 844ce761f..c3ecfb556 100644 --- a/src/mdio/core/zarr_io.py +++ b/src/mdio/core/zarr_io.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from zarr.errors import UnstableSpecificationWarning +from zarr.errors import ZarrUserWarning if TYPE_CHECKING: from collections.abc import Generator @@ -21,3 +22,14 @@ def zarr_warnings_suppress_unstable_structs_v3() -> Generator[None, None, None]: yield finally: pass + + +@contextmanager +def zarr_warnings_suppress_unstable_numcodecs_v3() -> Generator[None, None, None]: + """Context manager to suppress Zarr V3 unstable numcodecs warning.""" + warn = r"Numcodecs codecs are not in the Zarr version 3 specification" + warnings.filterwarnings("ignore", message=warn, category=ZarrUserWarning) + try: + yield + finally: + pass diff --git a/src/mdio/optimize/__init__.py b/src/mdio/optimize/__init__.py new file mode 100644 index 000000000..175266215 --- /dev/null +++ b/src/mdio/optimize/__init__.py @@ -0,0 +1,6 @@ +"""Module for optimizing datasets for various access patterns / LOD etc.""" + +from mdio.optimize.common import ZfpQuality +from mdio.optimize.common import get_default_zfp + +__all__ = ["ZfpQuality", "get_default_zfp"] diff --git a/src/mdio/optimize/access_pattern.py b/src/mdio/optimize/access_pattern.py new file mode 100644 index 000000000..2a21c74b4 --- /dev/null +++ b/src/mdio/optimize/access_pattern.py @@ -0,0 +1,118 @@ +"""Optimize MDIO seismic datasets for fast access patterns using ZFP compression and Dask. + +This module provides tools to create compressed, rechunked transpose views of seismic data for efficient +access along dataset dimensions. It uses configurable ZFP compression based on data statistics and +supports parallel processing with Dask Distributed. +""" + +import logging + +from pydantic import BaseModel +from pydantic import Field +from xarray import Dataset as xr_Dataset + +from mdio import to_mdio +from mdio.builder.schemas.compressors import ZFP +from mdio.builder.schemas.compressors import Blosc +from mdio.builder.schemas.v1.stats import SummaryStatistics +from mdio.builder.xarray_builder import _compressor_to_encoding +from mdio.optimize.common import apply_compressor_encoding +from mdio.optimize.common import get_default_zfp +from mdio.optimize.common import get_or_create_client + +logger = logging.getLogger(__name__) + + +class OptimizedAccessPatternConfig(BaseModel): + """Configuration for fast access pattern optimization.""" + + optimize_dimensions: dict[str, tuple[int, ...]] = Field(..., description="Optimize dims and desired chunks.") + processing_chunks: dict[str, int] = Field(..., description="Chunk sizes for processing the original variable.") + compressor: Blosc | ZFP | None = Field(default=None, description="Compressor to use for access patterns.") + + +def optimize_access_patterns( + dataset: xr_Dataset, + config: OptimizedAccessPatternConfig, + n_workers: int = 1, + threads_per_worker: int = 1, +) -> None: + """Optimize MDIO dataset for fast access along dimensions. 
+ + Optimize an MDIO dataset by creating compressed, rechunked views for fast access along + configurable dimensions, then append them to the existing MDIO file. + + This uses ZFP compression with tolerance based on data standard deviation and the provided quality level. + Requires Dask Distributed for parallel execution. It will try to grab the existing distributed.Client + or create its own. Existing Client will be kept running after optimization. + + Args: + dataset: MDIO Dataset containing the seismic data. + config: Configuration object with quality, access patterns, and processing chunks. + n_workers: Number of Dask workers. Default is 1. + threads_per_worker: Threads per Dask worker. Default is 1. + + Raises: + ValueError: If required attrs/stats are missing or the dataset is invalid. + + Examples: + For Post-Stack 3D seismic data, we can optimize the inline, crossline, and depth dimensions. + + >>> from mdio import optimize_access_patterns, OptimizedAccessPatternConfig + >>> from mdio import open_mdio + >>> + >>> conf = OptimizedAccessPatternConfig( + >>> optimize_dimensions={ + >>> "inline": (4, 512, 512), + >>> "crossline": (512, 4, 512), + >>> "time": (512, 512, 4), + >>> }, + >>> processing_chunks= {"inline": 512, "crossline": 512, "time": 512} + >>> ) + >>> + >>> ds = open_mdio("/path/to/seismic.mdio") + >>> optimize_access_patterns(ds, conf, n_workers=4) + """ + # Extract and validate key attrs + attrs = dataset.attrs.get("attributes", {}) + var_name = attrs.get("defaultVariableName") + if not var_name: + msg = "Default variable name is missing from dataset attributes." + raise ValueError(msg) + + variable = dataset[var_name] + chunked_var = variable.chunk(**config.processing_chunks, inline_array=True) + + if config.compressor is None: + if "statsV1" not in variable.attrs: + msg = "Statistics are missing from data. Std. dev. is required for compression." + raise ValueError(msg) + + logger.info("No compressor provided, using default ZFP compression with MEDIUM quality.") + stats = SummaryStatistics.model_validate_json(variable.attrs["statsV1"]) + default_zfp = get_default_zfp(stats) + config.compressor = default_zfp + + compressor_encoding = _compressor_to_encoding(config.compressor) + + optimized_variables = {} + for dim_name, dim_new_chunks in config.optimize_dimensions.items(): + if dim_name not in chunked_var.dims: + msg = f"Dimension to optimize '{dim_name}' not found in original dataset dims: {chunked_var.dims}." + raise ValueError(msg) + optimized_var = apply_compressor_encoding(chunked_var, dim_new_chunks, compressor_encoding) + optimized_var.name = f"fast_{dim_name}" + optimized_variables[optimized_var.name] = optimized_var + + optimized_dataset = xr_Dataset(optimized_variables, attrs=dataset.attrs) + source_path = dataset.encoding["source"] + + with get_or_create_client(n_workers=n_workers, threads_per_worker=threads_per_worker) as client: + # The context manager ensures distributed is installed so we can try to register the plugin + # safely. 
The plugin is conditionally created based on the installation status of distributed + from mdio.optimize.patch import MonkeyPatchZfpDaskPlugin # noqa: PLC0415 + + client.register_plugin(MonkeyPatchZfpDaskPlugin()) + logger.info("Starting optimization with compressor %s.", compressor_encoding) + to_mdio(optimized_dataset, source_path, mode="a") + logger.info("Optimization completed successfully.") diff --git a/src/mdio/optimize/common.py b/src/mdio/optimize/common.py new file mode 100644 index 000000000..00cf85a36 --- /dev/null +++ b/src/mdio/optimize/common.py @@ -0,0 +1,86 @@ +"""Common optimization utilities.""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from enum import Enum +from typing import TYPE_CHECKING +from typing import Any + +from mdio.builder.schemas.compressors import ZFP +from mdio.builder.schemas.compressors import ZFPMode + +if TYPE_CHECKING: + from collections.abc import Generator + + from xarray import DataArray + + from mdio.builder.schemas.v1.stats import SummaryStatistics + + +try: + import distributed +except ImportError: + distributed = None + + +logger = logging.getLogger(__name__) + + +class ZfpQuality(float, Enum): + """Config options for ZFP compression.""" + + VERY_LOW = 6 + LOW = 3 + MEDIUM = 1 + HIGH = 0.1 + VERY_HIGH = 0.01 + ULTRA = 0.001 + + +def get_default_zfp( + stats: SummaryStatistics, + quality: ZfpQuality = ZfpQuality.LOW, +) -> ZFP: + """Compute ZFP encoding based on data statistics and quality level.""" + if stats.std is None or stats.std <= 0: + msg = "Standard deviation must be positive for tolerance calculation." + raise ValueError(msg) + + tolerance = quality.value * stats.std + logger.info("Computed ZFP tolerance: %s (quality: %s, std: %s)", tolerance, quality.name, stats.std) + return ZFP(mode=ZFPMode.FIXED_ACCURACY, tolerance=tolerance) + + +def apply_compressor_encoding( + data_array: DataArray, chunks: tuple[int, ...], zfp_encoding: dict[str, Any] +) -> DataArray: + """Apply ZFP encoding and custom chunks to a DataArray copy.""" + # Drop coordinates to avoid re-writing them and avoid rechunking issues in views + data_array = data_array.copy().reset_coords(drop=True) + data_array.encoding.update(zfp_encoding) + data_array.encoding["chunks"] = chunks + return data_array + + +@contextmanager +def get_or_create_client(n_workers: int, threads_per_worker: int) -> Generator[distributed.Client, None, None]: + """Get or create a Dask Distributed Client.""" + if distributed is None: + msg = "The 'distributed' package is required for processing. Install: 'pip install multidimio[distributed]'." + raise ImportError(msg) + + created = False + try: + client = distributed.Client.current() + logger.info("Using existing Dask client: %s", client) + except ValueError: + logger.info("No active Dask client found. Creating a new one.") + client = distributed.Client(n_workers=n_workers, threads_per_worker=threads_per_worker) + created = True + try: + yield client + finally: + if created: + client.close() diff --git a/src/mdio/optimize/patch.py b/src/mdio/optimize/patch.py new file mode 100644 index 000000000..b3781f700 --- /dev/null +++ b/src/mdio/optimize/patch.py @@ -0,0 +1,52 @@ +"""Dask worker plugins for monkey patching ZFP due to bug. 
+ +We can remove this once the fix is upstreamed: +https://github.com/zarr-developers/numcodecs/issues/812 +https://github.com/zarr-developers/numcodecs/pull/811 +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import numpy as np +from numcodecs import blosc +from zarr.codecs import numcodecs + +if TYPE_CHECKING: + from zarr.core.array_spec import ArraySpec + from zarr.core.buffer import Buffer + from zarr.core.buffer import NDBuffer + + +try: + import distributed +except ModuleNotFoundError: + distributed = None + + +class ZFPY(numcodecs.ZFPY, codec_name="zfpy"): + """Monkey patch ZFP codec to make input array contiguous before encoding.""" + + async def _encode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> Buffer: + chunk_ndarray = chunk_data.as_ndarray_like() + if not chunk_ndarray.flags.c_contiguous: + chunk_ndarray = np.ascontiguousarray(chunk_ndarray) + out = await asyncio.to_thread(self._codec.encode, chunk_ndarray) + return chunk_spec.prototype.buffer.from_bytes(out) + + +if distributed is not None: + + class MonkeyPatchZfpDaskPlugin(distributed.WorkerPlugin): + """Monkey patch ZFP codec and disable Blosc threading for Dask workers. + + Note that this is class is only importable if distributed is installed. However, in the caller + function we have a context manager that checks if distributed is installed, so it is safe (for now). + """ + + def setup(self, worker: distributed.Worker) -> None: # noqa: ARG002 + """Monkey patch ZFP codec and disable Blosc threading.""" + numcodecs._codecs.ZFPY = ZFPY + blosc.set_nthreads(1) diff --git a/tests/unit/test_optimize_access_pattern.py b/tests/unit/test_optimize_access_pattern.py new file mode 100644 index 000000000..fa119cfd1 --- /dev/null +++ b/tests/unit/test_optimize_access_pattern.py @@ -0,0 +1,166 @@ +"""Unit tests for optimize_access_pattern module.""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING +from unittest.mock import patch + +import numpy as np +import pytest +from distributed import Client +from segy import SegyFactory +from segy.standards import get_segy_standard +from zarr.codecs import ZFPY as zarr_ZFPY # noqa: N811 +from zarr.codecs import BloscCodec as zarr_BloscCodec + +from mdio import open_mdio +from mdio import segy_to_mdio +from mdio.builder.schemas.compressors import Blosc as mdio_Blosc +from mdio.builder.schemas.compressors import BloscCname +from mdio.builder.template_registry import get_template +from mdio.optimize.access_pattern import OptimizedAccessPatternConfig +from mdio.optimize.access_pattern import optimize_access_patterns + +if TYPE_CHECKING: + from pathlib import Path + + +INLINES = np.arange(1, 9) +CROSSLINES = np.arange(1, 17) +NUM_SAMPLES = 64 + +SPEC = get_segy_standard(1) + + +@pytest.fixture(scope="module") +def test_segy_path(fake_segy_tmp: Path) -> Path: + """Create a small synthetic 3D SEG-Y file.""" + segy_path = fake_segy_tmp / "optimize_ap_test_3d.sgy" + + num_traces = len(INLINES) * len(CROSSLINES) + + factory = SegyFactory(spec=SPEC, sample_interval=4000, samples_per_trace=NUM_SAMPLES) + headers = factory.create_trace_header_template(num_traces) + samples = factory.create_trace_sample_template(num_traces) + + headers["inline"] = INLINES.repeat(len(CROSSLINES)) + headers["crossline"] = np.tile(CROSSLINES, len(INLINES)) + headers["coordinate_scalar"] = 1 + + samples[:] = np.arange(num_traces)[..., None] + + with segy_path.open(mode="wb") as fp: + fp.write(factory.create_textual_header()) + 
fp.write(factory.create_binary_header()) + fp.write(factory.create_traces(headers, samples)) + + return segy_path + + +@pytest.fixture(scope="module") +def mdio_dataset_path(test_segy_path: Path, zarr_tmp: Path) -> Path: + """Convert synthetic SEG-Y to MDIO.""" + test_mdio_path = zarr_tmp / "optimize_ap_test_3d.mdio" + + env = { + "MDIO__IMPORT__CPU_COUNT": "true", + "MDIO__IMPORT__CLOUD_NATIVE": "true", + } + patch.dict(os.environ, env) + segy_to_mdio( + segy_spec=SPEC, + mdio_template=get_template("PostStack3DTime"), + input_path=test_segy_path, + output_path=test_mdio_path, + overwrite=True, + ) + return test_mdio_path + + +class TestOptimizeAccessPattern: + """Tests for optimize_access_pattern module.""" + + def test_optimize_access_patterns(self, mdio_dataset_path: str) -> None: + """Test optimization of access patterns.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"time": (128, 128, 4), "inline": (2, 64, 64)}, + processing_chunks={"inline": 128, "crossline": 128, "time": 128}, + ) + ds = open_mdio(mdio_dataset_path) + optimize_access_patterns(ds, conf) + + ds = open_mdio(mdio_dataset_path) + + assert "fast_time" in ds.variables + assert ds["fast_time"].encoding["chunks"] == (128, 128, 4) + assert isinstance(ds["fast_time"].encoding["serializer"], zarr_ZFPY) + + assert "inline" in ds.variables + assert ds["fast_inline"].encoding["chunks"] == (2, 64, 64) + assert isinstance(ds["fast_inline"].encoding["serializer"], zarr_ZFPY) + + def test_optimize_access_patterns_custom_compressor(self, mdio_dataset_path: str) -> None: + """Test optimization of access patterns with custom compressor.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"crossline": (32, 8, 32)}, + processing_chunks={"inline": 512, "crossline": 512, "time": 512}, + compressor=mdio_Blosc(cname=BloscCname.blosclz, clevel=1), + ) + ds = open_mdio(mdio_dataset_path) + optimize_access_patterns(ds, conf) + + ds = open_mdio(mdio_dataset_path) + + actual_compressor = ds["fast_crossline"].encoding["compressors"][0] + assert "fast_crossline" in ds.variables + assert ds["fast_crossline"].encoding["chunks"] == (32, 8, 32) + assert isinstance(actual_compressor, zarr_BloscCodec) + assert actual_compressor.cname == BloscCname.blosclz + assert actual_compressor.clevel == 1 + + def test_user_provided_client(self, mdio_dataset_path: str) -> None: + """Test when user provides a dask client is present.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"time": (128, 128, 4)}, + processing_chunks={"inline": 128, "crossline": 128, "time": 128}, + ) + ds = open_mdio(mdio_dataset_path) + + with Client(processes=False): + optimize_access_patterns(ds, conf) + + def test_missing_default_variable_name(self, mdio_dataset_path: str) -> None: + """Test case where default variable name is missing from dataset attributes.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"time": (128, 128, 4)}, + processing_chunks={"inline": 128, "crossline": 128, "time": 128}, + ) + ds = open_mdio(mdio_dataset_path) + del ds.attrs["attributes"] + + with pytest.raises(ValueError, match="Default variable name is missing from dataset attributes"): + optimize_access_patterns(ds, conf) + + def test_missing_stats(self, mdio_dataset_path: str) -> None: + """Test case where statistics are missing from default variable.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"time": (128, 128, 4)}, + processing_chunks={"inline": 128, "crossline": 128, "time": 128}, + ) + ds = open_mdio(mdio_dataset_path) + 
del ds["amplitude"].attrs["statsV1"] + + with pytest.raises(ValueError, match="Statistics are missing from data"): + optimize_access_patterns(ds, conf) + + def test_invalid_optimize_access_patterns(self, mdio_dataset_path: str) -> None: + """Test when optimize_dimensions contains invalid dimensions.""" + conf = OptimizedAccessPatternConfig( + optimize_dimensions={"time": (128, 128, 4), "invalid": (4, 2, 44)}, + processing_chunks={"inline": 128, "crossline": 128, "time": 128}, + ) + ds = open_mdio(mdio_dataset_path) + + with pytest.raises(ValueError, match="Dimension to optimize 'invalid' not found"): + optimize_access_patterns(ds, conf) diff --git a/tests/unit/v1/test_dataset_serializer.py b/tests/unit/v1/test_dataset_serializer.py index 81cc5e834..5db176104 100644 --- a/tests/unit/v1/test_dataset_serializer.py +++ b/tests/unit/v1/test_dataset_serializer.py @@ -1,15 +1,24 @@ """Tests the schema v1 dataset_serializer public API.""" -from pathlib import Path +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import patch import numpy as np import pytest +from zarr.codecs import ZFPY as zarr_ZFPY # noqa: N811 from zarr.codecs import BloscCodec from mdio import to_mdio from mdio.builder.dataset_builder import MDIODatasetBuilder from mdio.builder.schemas.chunk_grid import RegularChunkGrid from mdio.builder.schemas.chunk_grid import RegularChunkShape +from mdio.builder.schemas.compressors import ZFP as MDIO_ZFP +from mdio.builder.schemas.compressors import Blosc as mdio_Blosc +from mdio.builder.schemas.compressors import BloscCname +from mdio.builder.schemas.compressors import BloscShuffle +from mdio.builder.schemas.compressors import ZFPMode as mdio_ZFPMode from mdio.builder.schemas.dimension import NamedDimension from mdio.builder.schemas.dtype import ScalarType from mdio.builder.schemas.dtype import StructuredField @@ -31,20 +40,13 @@ from .helpers import make_seismic_poststack_3d_acceptance_dataset -try: # pragma: no cover - from zfpy import ZFPY +if TYPE_CHECKING: + from pathlib import Path - HAS_ZFPY = True +try: # pragma: no cover + import zfpy except ImportError: - ZFPY = None - HAS_ZFPY = False - - -from mdio.builder.schemas.compressors import ZFP as MDIO_ZFP -from mdio.builder.schemas.compressors import Blosc as mdio_Blosc -from mdio.builder.schemas.compressors import BloscCname -from mdio.builder.schemas.compressors import BloscShuffle -from mdio.builder.schemas.compressors import ZFPMode as mdio_ZFPMode + zfpy = None def test_get_all_named_dimensions() -> None: @@ -226,52 +228,52 @@ def test_get_fill_value() -> None: assert result_none_input is None -def test_compressor_to_encoding() -> None: - """Simple test for _compressor_to_encoding function covering basic scenarios.""" - # Test 1: None input - should return None - result_none = _compressor_to_encoding(None) - assert result_none is None +class TestCompressorToEncoding: + """Test _compressor_to_encoding function for various configurations.""" - # Test 2: mdio_Blosc compressor - should return nc_Blosc - mdio_compressor = mdio_Blosc(cname=BloscCname.lz4, clevel=5, shuffle=BloscShuffle.bitshuffle, blocksize=1024) - result_blosc = _compressor_to_encoding(mdio_compressor) + def test_compressor_encoding_blosc(self) -> None: + """Blosc Compressor - should return zarr codec BloscCodec.""" + mdio_compressor = mdio_Blosc(cname=BloscCname.lz4, clevel=5, shuffle=BloscShuffle.bitshuffle, blocksize=1024) + result = _compressor_to_encoding(mdio_compressor) - assert isinstance(result_blosc, dict) - assert 
"compressors" in result_blosc - assert isinstance(result_blosc["compressors"], BloscCodec) - assert result_blosc["compressors"].cname == BloscCname.lz4 - assert result_blosc["compressors"].clevel == 5 - assert result_blosc["compressors"].shuffle == BloscShuffle.bitshuffle - assert result_blosc["compressors"].blocksize == 1024 + assert isinstance(result["compressors"], BloscCodec) + assert result["compressors"].cname == BloscCname.lz4 + assert result["compressors"].clevel == 5 + assert result["compressors"].shuffle == BloscShuffle.bitshuffle + assert result["compressors"].blocksize == 1024 - # Test 3: mdio_ZFP compressor - should return zfpy_ZFPY if available - zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16) + def test_compressor_encoding_zfp(self) -> None: + """ZFP Compressor - should return zarr codec ZFPY.""" + zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16) - # TODO(BrianMichell): Update to also test zfp compression. - # https://github.com/TGSAI/mdio-python/issues/747 - if HAS_ZFPY: # pragma: no cover result_zfp = _compressor_to_encoding(zfp_compressor) - assert isinstance(result_zfp, dict) - assert "compressors" not in result_zfp - assert isinstance(result_zfp["serializer"], ZFPY) - assert result_zfp["serializer"].mode == 1 # ZFPMode.FIXED_RATE.value = "fixed_rate" - assert result_zfp["serializer"].tolerance == 0.01 - assert result_zfp["serializer"].rate == 8.0 - assert result_zfp["serializer"].precision == 16 - else: - # Test 5: mdio_ZFP without zfpy installed - should raise ImportError - with pytest.raises(ImportError) as exc_info: - _compressor_to_encoding(zfp_compressor) - error_message = str(exc_info.value) - assert "zfpy and numcodecs are required to use ZFP compression" in error_message - - # Test 6: Unsupported compressor type - should raise TypeError - unsupported_compressor = "invalid_compressor" - with pytest.raises(TypeError) as exc_info: - _compressor_to_encoding(unsupported_compressor) - error_message = str(exc_info.value) - assert "Unsupported compressor model" in error_message - assert "" in error_message + assert result_zfp["compressors"] is None + assert isinstance(result_zfp["serializer"], zarr_ZFPY) + assert result_zfp["serializer"].codec_config["mode"] == 2 # fixed rate + assert result_zfp["serializer"].codec_config["tolerance"] == 0.01 + assert result_zfp["serializer"].codec_config["rate"] == 8.0 + assert result_zfp["serializer"].codec_config["precision"] == 16 + + def test_compressor_encoding_zfp_missing(self) -> None: + """ZFP Compressor - should raise ImportError if zfpy is not installed.""" + zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16) + + with patch("mdio.builder.xarray_builder._import_numcodecs_zfpy") as mock_import: + mock_import.side_effect = ImportError # Simulate import failure + + with pytest.raises(ImportError, match="The 'zfpy' package is required for lossy compression."): + _compressor_to_encoding(zfp_compressor) + + def test_compressor_encoding_none(self) -> None: + """Test None encoding. Should return None.""" + result_none = _compressor_to_encoding(None) + assert result_none is None + + def test_compressor_encoding_unsupported(self) -> None: + """Test unsupported compressor type. 
Should raise TypeError.""" + unsupported_compressor = "invalid_compressor" + with pytest.raises(TypeError, match="Unsupported compressor model"): + _compressor_to_encoding(unsupported_compressor) def test_to_xarray_dataset(tmp_path: Path) -> None: