Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from trino.client import _retry_with
from trino.client import _RetryWithExponentialBackoff
from trino.client import ClientSession
from trino.client import CompressedQueryDataDecoderFactory
from trino.client import TrinoQuery
from trino.client import TrinoRequest
from trino.client import TrinoResult
Expand Down Expand Up @@ -1450,3 +1451,51 @@ def delete_password(self, servicename, username):
return None

os.remove(file_path)


def test_trino_request_headers_encoding_default_behavior():
    """With no session encoding set, the encoding header lists only codecs that are available."""
    session = ClientSession(user="test", encoding=None)

    # (codecs reported as unavailable, expected header value);
    # an expected value of None means the header must not be sent at all.
    scenarios = [
        ({}, None),
        ({"zstd": "Not installed"}, "json+lz4,json"),
        ({"lz4": "Not installed"}, "json+zstd,json"),
        ({"lz4": "Not installed", "zstd": "Not installed"}, "json"),
    ]

    for unavailable, expected in scenarios:
        # The headers are computed while the patch is active.
        with mock.patch("trino.client.CODECS_UNAVAILABLE", unavailable):
            headers = TrinoRequest("host", 8080, session).http_headers

        if expected is None:
            assert constants.HEADER_ENCODING not in headers
        else:
            assert headers[constants.HEADER_ENCODING] == expected


def test_decoder_factory_raises_with_message_on_missing_zstd():
    """Requesting json+zstd while zstd is unavailable raises ValueError carrying the import error."""
    reason = "No module named 'zstandard'"
    expected = f"zstd is not installed so json\\+zstd encoding is not supported: {reason}"
    decoder_factory = CompressedQueryDataDecoderFactory(mock.Mock())

    codecs_patch = mock.patch("trino.client.CODECS_UNAVAILABLE", {"zstd": reason})
    with codecs_patch, pytest.raises(ValueError, match=expected):
        decoder_factory.create("json+zstd")


def test_decoder_factory_raises_with_message_on_missing_lz4():
    """Requesting json+lz4 while lz4 is unavailable raises ValueError carrying the import error."""
    reason = "No module named 'lz4.block'"
    expected = f"lz4 is not installed so json\\+lz4 encoding is not supported: {reason}"
    decoder_factory = CompressedQueryDataDecoderFactory(mock.Mock())

    codecs_patch = mock.patch("trino.client.CODECS_UNAVAILABLE", {"lz4": reason})
    with codecs_patch, pytest.raises(ValueError, match=expected):
        decoder_factory.create("json+lz4")
24 changes: 24 additions & 0 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,27 @@ def test_description_is_none_when_cursor_is_not_executed():
def test_setting_http_scheme(host, port, http_scheme_input_argument, http_scheme_set):
connection = Connection(host, port, http_scheme=http_scheme_input_argument)
assert connection.http_scheme == http_scheme_set


@patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed", "zstd": "Not installed"})
def test_default_encoding_no_compression():
    """With both lz4 and zstd unavailable, the default encoding list is plain JSON only."""
    expected = ["json"]
    session = Connection("host", 8080, user="test")._client_session
    assert session.encoding == expected


@patch("trino.client.CODECS_UNAVAILABLE", {"zstd": "Not installed"})
def test_default_encoding_lz4():
    """With zstd unavailable, the default encoding list is json+lz4 followed by json."""
    expected = ["json+lz4", "json"]
    session = Connection("host", 8080, user="test")._client_session
    assert session.encoding == expected


@patch("trino.client.CODECS_UNAVAILABLE", {"lz4": "Not installed"})
def test_default_encoding_zstd():
    """With lz4 unavailable, the default encoding list is json+zstd followed by json."""
    expected = ["json+zstd", "json"]
    session = Connection("host", 8080, user="test")._client_session
    assert session.encoding == expected


@patch("trino.client.CODECS_UNAVAILABLE", {})
def test_default_encoding_all():
    """With no codec reported unavailable, the default encoding list contains every encoding."""
    expected = ["json+zstd", "json+lz4", "json"]
    session = Connection("host", 8080, user="test")._client_session
    assert session.encoding == expected
53 changes: 47 additions & 6 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,31 @@
from typing import Union
from zoneinfo import ZoneInfo

import lz4.block
try:
import lz4.block
except ImportError as err:
_LZ4_ERROR = str(err)
else:
_LZ4_ERROR = None

try:
import orjson as json
except ImportError:
import json

import requests
import zstandard
from requests import Response
from requests import Session
from requests.structures import CaseInsensitiveDict

try:
import zstandard
except ImportError as err:
_ZSTD_ERROR = str(err)
else:
_ZSTD_ERROR = None


import trino.logging
from trino import constants
from trino import exceptions
Expand All @@ -87,6 +100,7 @@
from trino.mapper import RowMapper
from trino.mapper import RowMapperFactory


__all__ = [
"ClientSession",
"TrinoQuery",
Expand Down Expand Up @@ -117,6 +131,13 @@ def close_executor():

_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')

# Encodings known to the client; compressed variants are spelled "json+<codec>".
ENCODINGS = ["json+zstd", "json+lz4", "json"]
# Maps the name of each codec whose import failed to the text of its ImportError.
CODECS_UNAVAILABLE = {
    codec: import_error
    for codec, import_error in (("lz4", _LZ4_ERROR), ("zstd", _ZSTD_ERROR))
    if import_error
}

ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$")


Expand Down Expand Up @@ -290,7 +311,7 @@ def timezone(self) -> str:
return self._timezone

@property
def encoding(self) -> Optional[Union[str, List[str]]]:
    """Return the session's encoding setting, read while holding the object lock.

    The value may be ``None``, a single encoding name, or a list of
    encoding names.
    """
    with self._object_lock:
        return self._encoding

Expand Down Expand Up @@ -524,7 +545,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]:
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
if self._client_session.encoding is None:
pass
if not CODECS_UNAVAILABLE:
pass
else:
encoding = [
enc
for enc in ENCODINGS
if (enc.split("+")[1] if "+" in enc else None) not in CODECS_UNAVAILABLE
]
headers[constants.HEADER_ENCODING] = ",".join(encoding)
elif isinstance(self._client_session.encoding, list):
headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding)
elif isinstance(self._client_session.encoding, str):
Expand Down Expand Up @@ -1271,8 +1300,16 @@ def __init__(self, mapper: RowMapper) -> None:

def create(self, encoding: str) -> QueryDataDecoder:
if encoding == "json+zstd":
if "zstd" in CODECS_UNAVAILABLE:
raise ValueError(
f"zstd is not installed so json+zstd encoding is not supported: {CODECS_UNAVAILABLE['zstd']}"
)
return ZStdQueryDataDecoder(JsonQueryDataDecoder(self._mapper))
elif encoding == "json+lz4":
if "lz4" in CODECS_UNAVAILABLE:
raise ValueError(
f"lz4 is not installed so json+lz4 encoding is not supported: {CODECS_UNAVAILABLE['lz4']}"
)
return Lz4QueryDataDecoder(JsonQueryDataDecoder(self._mapper))
elif encoding == "json":
return JsonQueryDataDecoder(self._mapper)
Expand Down Expand Up @@ -1322,10 +1359,14 @@ def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:


class ZStdQueryDataDecoder(CompressedQueryDataDecoder):
    """Decoder for zstd-compressed data that builds its decompressor on first use."""

    def __init__(self, delegate: QueryDataDecoder) -> None:
        super().__init__(delegate)
        # Created lazily so that nothing touches the optional ``zstandard``
        # module until data actually has to be decompressed.
        self._decompressor = None

    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        """Return ``data`` decompressed with zstd; ``metadata`` is not used here."""
        if self._decompressor is None:
            self._decompressor = zstandard.ZstdDecompressor()
        return self._decompressor.decompress(data)


class Lz4QueryDataDecoder(CompressedQueryDataDecoder):
Expand Down
6 changes: 3 additions & 3 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def __init__(

if encoding is _USE_DEFAULT_ENCODING:
encoding = [
"json+zstd",
"json+lz4",
"json",
enc
for enc in trino.client.ENCODINGS
if (enc.split("+")[1] if "+" in enc else None) not in trino.client.CODECS_UNAVAILABLE
]

self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path
Expand Down