Skip to content

Commit e10bc53

Browse files
mdesmethashhar
authored andcommitted
Make names of ROW datatype available in result set
1 parent 73b0a58 commit e10bc53

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,49 @@ def test_named_row(trino_connection):
745745
rows = cur.fetchall()
746746

747747
assert rows[0][0] == (1, 2.0)
748+
assert rows[0][0][0] == 1
749+
assert rows[0][0][1] == 2.0
750+
assert rows[0][0].x == 1
751+
assert rows[0][0].y == 2.0
752+
753+
assert rows[0][0].__annotations__["names"] == ['x', 'y']
754+
assert rows[0][0].__annotations__["types"] == ['bigint', 'double']
755+
756+
757+
def test_named_row_duplicate_names(trino_connection):
758+
cur = trino_connection.cursor()
759+
cur.execute("SELECT CAST(ROW(1, 2e0) AS ROW(x BIGINT, x DOUBLE))")
760+
rows = cur.fetchall()
761+
762+
assert rows[0][0] == (1, 2.0)
763+
with pytest.raises(ValueError, match="Ambiguous row field reference: x"):
764+
rows[0][0].x
765+
766+
assert rows[0][0].__annotations__["names"] == ['x', 'x']
767+
assert rows[0][0].__annotations__["types"] == ['bigint', 'double']
768+
assert str(rows[0][0]) == "(1, 2.0)"
769+
770+
771+
def test_nested_named_row(trino_connection):
772+
cur = trino_connection.cursor()
773+
cur.execute("SELECT CAST(ROW(DECIMAL '2.3', ROW(1, 'test')) AS ROW(x DECIMAL(3,2), y ROW(x BIGINT, y VARCHAR)))")
774+
rows = cur.fetchall()
775+
776+
assert rows[0][0] == (Decimal('2.3'), (1, 'test'))
777+
assert rows[0][0][0] == Decimal('2.3')
778+
assert rows[0][0][1] == (1, 'test')
779+
assert rows[0][0][1][0] == 1
780+
assert rows[0][0][1][1] == 'test'
781+
assert rows[0][0].x == Decimal('2.3')
782+
assert rows[0][0].y.x == 1
783+
assert rows[0][0].y.y == 'test'
784+
785+
assert rows[0][0].__annotations__["names"] == ['x', 'y']
786+
assert rows[0][0].__annotations__["types"] == ['decimal', 'row']
787+
788+
assert rows[0][0].y.__annotations__["names"] == ['x', 'y']
789+
assert rows[0][0].y.__annotations__["types"] == ['bigint', 'varchar']
790+
assert str(rows[0][0]) == "(x: Decimal('2.30'), y: (x: 1, y: 'test'))"
748791

749792

750793
def test_float_query_param(trino_connection):

trino/client.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,14 +1128,48 @@ def map(self, values: List[Any]) -> Optional[List[Any]]:
11281128
return [self.mapper.map(value) for value in values]
11291129

11301130

1131+
class NamedRowTuple(tuple):
1132+
"""Custom tuple class as namedtuple doesn't support missing or duplicate names"""
1133+
def __new__(cls, values, names: List[str], types: List[str]):
1134+
return super().__new__(cls, values)
1135+
1136+
def __init__(self, values, names: List[str], types: List[str]):
1137+
self._names = names
1138+
# With names and types users can retrieve the name and Trino data type of a row
1139+
self.__annotations__ = dict()
1140+
self.__annotations__["names"] = names
1141+
self.__annotations__["types"] = types
1142+
elements: List[Any] = []
1143+
for name, value in zip(names, values):
1144+
if names.count(name) == 1:
1145+
setattr(self, name, value)
1146+
elements.append(f"{name}: {repr(value)}")
1147+
else:
1148+
elements.append(repr(value))
1149+
self._repr = "(" + ", ".join(elements) + ")"
1150+
1151+
def __getattr__(self, name):
1152+
if self._names.count(name):
1153+
raise ValueError("Ambiguous row field reference: " + name)
1154+
1155+
def __repr__(self):
1156+
return self._repr
1157+
1158+
11311159
class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
1132-
def __init__(self, mappers: List[ValueMapper[Any]]):
1160+
def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]):
11331161
self.mappers = mappers
1162+
self.names = names
1163+
self.types = types
11341164

11351165
def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
11361166
if values is None:
11371167
return None
1138-
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))
1168+
return NamedRowTuple(
1169+
list(self.mappers[index].map(value) for index, value in enumerate(values)),
1170+
self.names,
1171+
self.types
1172+
)
11391173

11401174

11411175
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
@@ -1183,8 +1217,14 @@ def _create_value_mapper(self, column) -> ValueMapper:
11831217
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
11841218
return ArrayValueMapper(value_mapper)
11851219
elif col_type == 'row':
1186-
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
1187-
return RowValueMapper(mappers)
1220+
mappers = []
1221+
names = []
1222+
types = []
1223+
for arg in column['arguments']:
1224+
mappers.append(self._create_value_mapper(arg['value']['typeSignature']))
1225+
names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None)
1226+
types.append(arg['value']['typeSignature']['rawType'])
1227+
return RowValueMapper(mappers, names, types)
11881228
elif col_type == 'map':
11891229
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
11901230
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])

0 commit comments

Comments
 (0)