Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 63dce8f

Browse files
authored
Merge pull request #827 from datafold/type-annotate-everything-2
Type annotate some things ("no-brainers")
2 parents d5a4d12 + ff76f94 commit 63dce8f

35 files changed, with 180 additions (+180) and 162 deletions (-162).

data_diff/__main__.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
7777
return db.query_table_schema(table_path)
7878

7979

80-
def diff_schemas(table1, table2, schema1, schema2, columns):
80+
def diff_schemas(table1, table2, schema1, schema2, columns) -> None:
8181
logging.info("Diffing schemas...")
8282
attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale"
8383
for c in columns:
@@ -103,7 +103,7 @@ def diff_schemas(table1, table2, schema1, schema2, columns):
103103

104104

105105
class MyHelpFormatter(click.HelpFormatter):
106-
def __init__(self, **kwargs):
106+
def __init__(self, **kwargs) -> None:
107107
super().__init__(self, **kwargs)
108108
self.indent_increment = 6
109109

@@ -281,7 +281,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
281281
default=None,
282282
help="Override the dbt production schema configuration within dbt_project.yml",
283283
)
284-
def main(conf, run, **kw):
284+
def main(conf, run, **kw) -> None:
285285
log_handlers = _get_log_handlers(kw["dbt"])
286286
if kw["table2"] is None and kw["database2"]:
287287
# Use the "database table table" form
@@ -341,9 +341,7 @@ def main(conf, run, **kw):
341341
production_schema_flag=kw["prod_schema"],
342342
)
343343
else:
344-
return _data_diff(
345-
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
346-
)
344+
_data_diff(dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw)
347345
except Exception as e:
348346
logging.error(e)
349347
raise
@@ -389,7 +387,7 @@ def _data_diff(
389387
threads1=None,
390388
threads2=None,
391389
__conf__=None,
392-
):
390+
) -> None:
393391
if limit and stats:
394392
logging.error("Cannot specify a limit when using the -s/--stats switch")
395393
return

data_diff/abcs/database_types.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -290,7 +290,7 @@ class Integer(NumericType, IKey):
290290
precision: int = 0
291291
python_type: type = int
292292

293-
def __attrs_post_init__(self):
293+
def __attrs_post_init__(self) -> None:
294294
assert self.precision == 0
295295

296296

data_diff/cloud/data_source.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def process_response(self, value: str) -> str:
4646
return value
4747

4848

49-
def _validate_temp_schema(temp_schema: str):
49+
def _validate_temp_schema(temp_schema: str) -> None:
5050
if len(temp_schema.split(".")) != 2:
5151
raise ValueError("Temporary schema should have a format <database>.<schema>")
5252

data_diff/cloud/datafold_api.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ class DatafoldAPI:
185185
host: str = "https://app.datafold.com"
186186
timeout: int = 30
187187

188-
def __attrs_post_init__(self):
188+
def __attrs_post_init__(self) -> None:
189189
self.host = self.host.rstrip("/")
190190
self.headers = {
191191
"Authorization": f"Key {self.api_key}",

data_diff/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]):
9999
_ENV_VAR_PATTERN = r"\$\{([A-Za-z0-9_]+)\}"
100100

101101

102-
def _resolve_env(config: Dict[str, Any]):
102+
def _resolve_env(config: Dict[str, Any]) -> None:
103103
"""
104104
Resolve environment variables referenced as ${ENV_VAR_NAME}.
105105
Missing environment variables are replaced with an empty string.

data_diff/databases/_connect.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ class Connect:
100100
database_by_scheme: Dict[str, Database]
101101
conn_cache: MutableMapping[Hashable, Database]
102102

103-
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
103+
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME) -> None:
104104
super().__init__()
105105
self.database_by_scheme = database_by_scheme
106106
self.conn_cache = weakref.WeakValueDictionary()

data_diff/databases/base.py

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,22 @@
55
import math
66
import sys
77
import logging
8-
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
8+
from typing import (
9+
Any,
10+
Callable,
11+
ClassVar,
12+
Dict,
13+
Generator,
14+
Iterator,
15+
NewType,
16+
Tuple,
17+
Optional,
18+
Sequence,
19+
Type,
20+
List,
21+
Union,
22+
TypeVar,
23+
)
924
from functools import partial, wraps
1025
from concurrent.futures import ThreadPoolExecutor
1126
import threading
@@ -116,7 +131,7 @@ def dialect(self) -> "BaseDialect":
116131
def compile(self, elem, params=None) -> str:
117132
return self.dialect.compile(self, elem, params)
118133

119-
def new_unique_name(self, prefix="tmp"):
134+
def new_unique_name(self, prefix="tmp") -> str:
120135
self._counter[0] += 1
121136
return f"{prefix}{self._counter[0]}"
122137

@@ -173,7 +188,7 @@ class ThreadLocalInterpreter:
173188
compiler: Compiler
174189
gen: Generator
175190

176-
def apply_queries(self, callback: Callable[[str], Any]):
191+
def apply_queries(self, callback: Callable[[str], Any]) -> None:
177192
q: Expr = next(self.gen)
178193
while True:
179194
sql = self.compiler.database.dialect.compile(self.compiler, q)
@@ -885,20 +900,21 @@ def optimizer_hints(self, hints: str) -> str:
885900

886901

887902
T = TypeVar("T", bound=BaseDialect)
903+
Row = Sequence[Any]
888904

889905

890906
@attrs.define(frozen=True)
891907
class QueryResult:
892-
rows: list
908+
rows: List[Row]
893909
columns: Optional[list] = None
894910

895-
def __iter__(self):
911+
def __iter__(self) -> Iterator[Row]:
896912
return iter(self.rows)
897913

898-
def __len__(self):
914+
def __len__(self) -> int:
899915
return len(self.rows)
900916

901-
def __getitem__(self, i):
917+
def __getitem__(self, i) -> Row:
902918
return self.rows[i]
903919

904920

@@ -1209,7 +1225,7 @@ class ThreadedDatabase(Database):
12091225
_queue: Optional[ThreadPoolExecutor] = None
12101226
thread_local: threading.local = attrs.field(factory=threading.local)
12111227

1212-
def __attrs_post_init__(self):
1228+
def __attrs_post_init__(self) -> None:
12131229
self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
12141230
logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")
12151231

data_diff/databases/bigquery.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -85,10 +85,10 @@ class Dialect(BaseDialect):
8585
def random(self) -> str:
8686
return "RAND()"
8787

88-
def quote(self, s: str):
88+
def quote(self, s: str) -> str:
8989
return f"`{s}`"
9090

91-
def to_string(self, s: str):
91+
def to_string(self, s: str) -> str:
9292
return f"cast({s} as string)"
9393

9494
def type_repr(self, t) -> str:
@@ -212,7 +212,7 @@ class BigQuery(Database):
212212
dataset: str
213213
_client: Any
214214

215-
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
215+
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw) -> None:
216216
super().__init__()
217217
credentials = bigquery_credentials
218218
bigquery = import_bigquery()

data_diff/databases/clickhouse.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ class Clickhouse(ThreadedDatabase):
175175

176176
_args: Dict[str, Any]
177177

178-
def __init__(self, *, thread_count: int, **kw):
178+
def __init__(self, *, thread_count: int, **kw) -> None:
179179
super().__init__(thread_count=thread_count)
180180

181181
self._args = kw

data_diff/databases/databricks.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ def type_repr(self, t) -> str:
6565
except KeyError:
6666
return super().type_repr(t)
6767

68-
def quote(self, s: str):
68+
def quote(self, s: str) -> str:
6969
return f"`{s}`"
7070

7171
def to_string(self, s: str) -> str:
@@ -118,7 +118,7 @@ class Databricks(ThreadedDatabase):
118118
catalog: str
119119
_args: Dict[str, Any]
120120

121-
def __init__(self, *, thread_count, **kw):
121+
def __init__(self, *, thread_count, **kw) -> None:
122122
super().__init__(thread_count=thread_count)
123123
logging.getLogger("databricks.sql").setLevel(logging.WARNING)
124124

0 commit comments

Comments
 (0)