Skip to content

Commit d7daf05

Browse files
committed
refactor(style): style fixes per flake8
- cli: clean up imports and improve code readability - db: enhance SQL identifier validation and sanitization - test: reorganize imports in SQL injection tests
1 parent 4226874 commit d7daf05

File tree

4 files changed

+23
-46
lines changed

4 files changed

+23
-46
lines changed

stat_log_db/src/stat_log_db/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
2-
import sys
2+
# import sys
33

4-
from .parser import create_parser
5-
from .db import Database, MemDB, FileDB, BaseConnection
4+
# from .parser import create_parser
5+
from .db import MemDB # , FileDB, Database, BaseConnection
66

77

88
def main():

stat_log_db/src/stat_log_db/db.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -321,27 +321,21 @@ def fetchall(self):
321321
def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str:
322322
"""
323323
Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection.
324-
325324
Args:
326325
identifier: The identifier to validate
327326
identifier_type: Type of identifier for error messages (e.g., "table name", "column name")
328-
329327
Returns:
330328
The validated identifier
331-
332329
Raises:
333330
ValueError: If the identifier is invalid or potentially dangerous
334331
"""
335332
if not isinstance(identifier, str):
336333
raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}")
337-
338334
if len(identifier) == 0:
339335
raise ValueError(f"SQL {identifier_type} cannot be empty")
340-
341336
# Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore
342337
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier):
343338
raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.")
344-
345339
# Check against SQLite reserved words (common ones that could cause issues)
346340
reserved_words = {
347341
'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
@@ -361,10 +355,8 @@ def _validate_sql_identifier(self, identifier: str, identifier_type: str = "iden
361355
'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values',
362356
'view', 'virtual', 'when', 'where', 'window', 'with', 'without'
363357
}
364-
365358
if identifier.lower() in reserved_words:
366359
raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used")
367-
368360
return identifier
369361

370362
def _escape_sql_identifier(self, identifier: str) -> str:
@@ -382,39 +374,32 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab
382374
raise_auto_arg_type_error("table_name")
383375
if len(table_name) == 0:
384376
raise ValueError("'table_name' argument of create_table cannot be an empty string!")
385-
386-
# Validate and sanitize table name
387-
validated_table_name = self._validate_sql_identifier(table_name, "table name")
388-
escaped_table_name = self._escape_sql_identifier(validated_table_name)
389-
390-
if not isinstance(raise_if_exists, bool):
391-
raise_auto_arg_type_error("raise_if_exists")
392-
393-
# Check if table already exists using parameterized query
394-
if raise_if_exists:
395-
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
396-
if self.cursor.fetchone() is not None:
397-
raise ValueError(f"Table '{validated_table_name}' already exists.")
398-
399377
# Validate temp_table argument
400378
if not isinstance(temp_table, bool):
401379
raise_auto_arg_type_error("temp_table")
402-
380+
if not isinstance(raise_if_exists, bool):
381+
raise_auto_arg_type_error("raise_if_exists")
403382
# Validate columns argument
404383
if (not isinstance(columns, list)) or (not all(
405384
isinstance(col, tuple) and len(col) == 2
406385
and isinstance(col[0], str)
407386
and isinstance(col[1], str)
408387
for col in columns)):
409388
raise_auto_arg_type_error("columns")
410-
389+
# Validate and sanitize table name
390+
validated_table_name = self._validate_sql_identifier(table_name, "table name")
391+
escaped_table_name = self._escape_sql_identifier(validated_table_name)
392+
# Check if table already exists using parameterized query
393+
if raise_if_exists:
394+
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
395+
if self.cursor.fetchone() is not None:
396+
raise ValueError(f"Table '{validated_table_name}' already exists.")
411397
# Validate and construct columns portion of query
412398
validated_columns = []
413399
for col_name, col_type in columns:
414400
# Validate column name
415401
validated_col_name = self._validate_sql_identifier(col_name, "column name")
416402
escaped_col_name = self._escape_sql_identifier(validated_col_name)
417-
418403
# Validate column type - allow only safe, known SQLite types
419404
allowed_types = {
420405
'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC',
@@ -423,27 +408,21 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab
423408
'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT',
424409
'INT', 'BIGINT', 'SMALLINT', 'TINYINT'
425410
}
426-
427411
# Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2))
428412
base_type = re.match(r'^([A-Z]+)', col_type.upper())
429413
if not base_type or base_type.group(1) not in allowed_types:
430414
raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}")
431-
432415
# Basic validation for type specification format
433416
if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()):
434417
raise ValueError(f"Invalid column type format: '{col_type}'")
435-
436418
validated_columns.append(f"{escaped_col_name} {col_type.upper()}")
437-
438419
columns_qstr = ",\n ".join(validated_columns)
439-
440420
# Assemble full query with escaped identifiers
441421
temp_keyword = " TEMPORARY" if temp_table else ""
442422
query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (
443423
id INTEGER PRIMARY KEY AUTOINCREMENT,
444424
{columns_qstr}
445425
);"""
446-
447426
self.execute(query)
448427

449428
def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
@@ -452,34 +431,30 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
452431
raise_auto_arg_type_error("table_name")
453432
if len(table_name) == 0:
454433
raise ValueError("'table_name' argument of drop_table cannot be an empty string!")
455-
434+
if not isinstance(raise_if_not_exists, bool):
435+
raise_auto_arg_type_error("raise_if_not_exists")
456436
# Validate and sanitize table name
457437
validated_table_name = self._validate_sql_identifier(table_name, "table name")
458438
escaped_table_name = self._escape_sql_identifier(validated_table_name)
459-
460-
if not isinstance(raise_if_not_exists, bool):
461-
raise_auto_arg_type_error("raise_if_not_exists")
462-
463439
# Check if table exists using parameterized query
464440
if raise_if_not_exists:
465441
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
466442
if self.cursor.fetchone() is None:
467443
raise ValueError(f"Table '{validated_table_name}' does not exist.")
468-
469444
# Execute DROP statement with escaped identifier
470445
self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};")
471446

472447
# def read(self):
473-
448+
# pass
474449

475450
# def write(self):
476-
451+
# pass
477452

478453
# def create(self):
479-
454+
# pass
480455

481456
# def unlink(self):
482-
457+
# pass
483458

484459

485460
class Connection(BaseConnection):

stat_log_db/tests/test_sql_injection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import sys
77
from pathlib import Path
88

9+
from stat_log_db.db import MemDB
10+
11+
912
# Add the src directory to the path to import the module
1013
ROOT = Path(__file__).resolve().parent.parent
1114
sys.path.insert(0, str(ROOT / "stat_log_db" / "src"))
1215

13-
from stat_log_db.db import MemDB
14-
1516

1617
@pytest.fixture
1718
def mem_db():

tests/test_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def test_help():
121121
except AssertionError:
122122
assert out.strip() == readme_content.strip(), "Help output does not match README content (leading & trailing whitespace stripped)"
123123

124+
124125
@pytest.mark.skipif(GITHUB_ACTIONS, reason="Skipping test on GitHub Actions")
125126
def test_install_dev(test_venv):
126127
code, out = run_tools(["-id"], use_test_venv=True)

0 commit comments

Comments
 (0)