From ba2c727e264bba150cc4be16159bb3fd3019f94a Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Tue, 16 Dec 2025 11:12:54 -0800 Subject: [PATCH] Gracefully handle missing/broken lz4/zstd codecs to improve portability --- tests/unit/test_client.py | 49 ++++++++++++++++++++++++++++++++++++ tests/unit/test_dbapi.py | 24 ++++++++++++++++++ trino/client.py | 53 ++++++++++++++++++++++++++++++++++----- trino/dbapi.py | 6 ++--- 4 files changed, 123 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ba5f7f28..f011a54d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -57,6 +57,7 @@ from trino.client import _retry_with from trino.client import _RetryWithExponentialBackoff from trino.client import ClientSession +from trino.client import CompressedQueryDataDecoderFactory from trino.client import TrinoQuery from trino.client import TrinoRequest from trino.client import TrinoResult @@ -1450,3 +1451,51 @@ def delete_password(self, servicename, username): return None os.remove(file_path) + + +def test_trino_request_headers_encoding_default_behavior(): + session = ClientSession(user="test", encoding=None) + + # Case 1: Both available -> No header + with mock.patch("trino.client.CODECS_UNAVAILABLE", {}): + req = TrinoRequest("host", 8080, session) + assert constants.HEADER_ENCODING not in req.http_headers + + # Case 2: Zstd missing -> Header set with json+lz4,json + with mock.patch("trino.client.CODECS_UNAVAILABLE", {"zstd": "Not installed"}): + req = TrinoRequest("host", 8080, session) + assert req.http_headers[constants.HEADER_ENCODING] == "json+lz4,json" + + # Case 3: Lz4 missing -> Header set with json+zstd,json + with mock.patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed"}): + req = TrinoRequest("host", 8080, session) + assert req.http_headers[constants.HEADER_ENCODING] == "json+zstd,json" + + # Case 4: Both missing -> Header set with json + with mock.patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed", "zstd": "Not installed"}): + req = TrinoRequest("host", 8080, session) + assert req.http_headers[constants.HEADER_ENCODING] == "json" + + +def test_decoder_factory_raises_with_message_on_missing_zstd(): + mapper = mock.Mock() + factory = CompressedQueryDataDecoderFactory(mapper) + error_message = "No module named 'zstandard'" + with mock.patch("trino.client.CODECS_UNAVAILABLE", {"zstd": error_message}): + with pytest.raises( + ValueError, + match=f"zstd is not installed so json\\+zstd encoding is not supported: {error_message}" + ): + factory.create("json+zstd") + + +def test_decoder_factory_raises_with_message_on_missing_lz4(): + mapper = mock.Mock() + factory = CompressedQueryDataDecoderFactory(mapper) + error_message = "No module named 'lz4.block'" + with mock.patch("trino.client.CODECS_UNAVAILABLE", {"lz4": error_message}): + with pytest.raises( + ValueError, + match=f"lz4 is not installed so json\\+lz4 encoding is not supported: {error_message}" + ): + factory.create("json+lz4") diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 080a3904..e3821bba 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -338,3 +338,27 @@ def test_description_is_none_when_cursor_is_not_executed(): def test_setting_http_scheme(host, port, http_scheme_input_argument, http_scheme_set): connection = Connection(host, port, http_scheme=http_scheme_input_argument) assert connection.http_scheme == http_scheme_set + + +@patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed", "zstd": "Not installed"}) +def test_default_encoding_no_compression(): + connection = Connection("host", 8080, user="test") + assert connection._client_session.encoding == ["json"] + + +@patch("trino.client.CODECS_UNAVAILABLE", {"zstd": "Not installed"}) +def test_default_encoding_lz4(): + connection = Connection("host", 8080, user="test") + assert connection._client_session.encoding == ["json+lz4", "json"] + + +@patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed"}) +def test_default_encoding_zstd(): + connection = Connection("host", 8080, user="test") + assert connection._client_session.encoding == ["json+zstd", "json"] + + +@patch("trino.client.CODECS_UNAVAILABLE", {}) +def test_default_encoding_all(): + connection = Connection("host", 8080, user="test") + assert connection._client_session.encoding == ["json+zstd", "json+lz4", "json"] diff --git a/trino/client.py b/trino/client.py index 3ab27e33..b5cc62ba 100644 --- a/trino/client.py +++ b/trino/client.py @@ -64,18 +64,31 @@ from typing import Union from zoneinfo import ZoneInfo -import lz4.block +try: + import lz4.block +except ImportError as err: + _LZ4_ERROR = str(err) +else: + _LZ4_ERROR = None + try: import orjson as json except ImportError: import json import requests -import zstandard from requests import Response from requests import Session from requests.structures import CaseInsensitiveDict +try: + import zstandard +except ImportError as err: + _ZSTD_ERROR = str(err) +else: + _ZSTD_ERROR = None + + import trino.logging from trino import constants from trino import exceptions @@ -87,6 +100,7 @@ from trino.mapper import RowMapper from trino.mapper import RowMapperFactory + __all__ = [ "ClientSession", "TrinoQuery", @@ -117,6 +131,13 @@ def close_executor(): _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +ENCODINGS = ["json+zstd", "json+lz4", "json"] +CODECS_UNAVAILABLE = {} +if _LZ4_ERROR: + CODECS_UNAVAILABLE["lz4"] = _LZ4_ERROR +if _ZSTD_ERROR: + CODECS_UNAVAILABLE["zstd"] = _ZSTD_ERROR + ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$") @@ -290,7 +311,7 @@ def timezone(self) -> str: return self._timezone @property - def encoding(self) -> Union[str, List[str]]: + def encoding(self) -> Optional[Union[str, List[str]]]: with self._object_lock: return self._encoding @@ -524,7 +545,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]: headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone if self._client_session.encoding is None: - pass + if not CODECS_UNAVAILABLE: + pass + else: + encoding = [ + enc + for enc in ENCODINGS + if (enc.split("+")[1] if "+" in enc else None) not in CODECS_UNAVAILABLE + ] + headers[constants.HEADER_ENCODING] = ",".join(encoding) elif isinstance(self._client_session.encoding, list): headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding) elif isinstance(self._client_session.encoding, str): @@ -1271,8 +1300,16 @@ def __init__(self, mapper: RowMapper) -> None: def create(self, encoding: str) -> QueryDataDecoder: if encoding == "json+zstd": + if "zstd" in CODECS_UNAVAILABLE: + raise ValueError( + f"zstd is not installed so json+zstd encoding is not supported: {CODECS_UNAVAILABLE['zstd']}" + ) return ZStdQueryDataDecoder(JsonQueryDataDecoder(self._mapper)) elif encoding == "json+lz4": + if "lz4" in CODECS_UNAVAILABLE: + raise ValueError( + f"lz4 is not installed so json+lz4 encoding is not supported: {CODECS_UNAVAILABLE['lz4']}" + ) return Lz4QueryDataDecoder(JsonQueryDataDecoder(self._mapper)) elif encoding == "json": return JsonQueryDataDecoder(self._mapper) @@ -1322,10 +1359,14 @@ def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]: class ZStdQueryDataDecoder(CompressedQueryDataDecoder): - zstd_decompressor = zstandard.ZstdDecompressor() + def __init__(self, delegate: QueryDataDecoder) -> None: + super().__init__(delegate) + self._decompressor = None def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes: - return ZStdQueryDataDecoder.zstd_decompressor.decompress(data) + if self._decompressor is None: + self._decompressor = zstandard.ZstdDecompressor() + return self._decompressor.decompress(data) class Lz4QueryDataDecoder(CompressedQueryDataDecoder): diff --git a/trino/dbapi.py b/trino/dbapi.py index fb73cc38..42eeb547 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -170,9 +170,9 @@ def __init__( if encoding is _USE_DEFAULT_ENCODING: encoding = [ - "json+zstd", - "json+lz4", - "json", + enc + for enc in trino.client.ENCODINGS + if (enc.split("+")[1] if "+" in enc else None) not in trino.client.CODECS_UNAVAILABLE ] self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path