Skip to content

Commit 6736447

Browse files
Added support for JSON serialization via dataclass module
1 parent 68fd69a commit 6736447

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

sqlacodegen/codegen.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
_flask_prepend = 'db.'
3333

34+
_dataclass = False
35+
3436

3537
class _DummyInflectEngine(object):
3638
def singular_noun(self, noun):
@@ -312,7 +314,7 @@ def render(self):
312314
class ModelClass(Model):
313315
parent_name = 'Base'
314316

315-
def __init__(self, table, association_tables, inflect_engine, detect_joined):
317+
def __init__(self, table, association_tables, inflect_engine, detect_joined, collector):
316318
super(ModelClass, self).__init__(table)
317319
self.name = self._tablename_to_classname(table.name, inflect_engine)
318320
self.children = []
@@ -321,6 +323,10 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined):
321323
# Assign attribute names for columns
322324
for column in table.columns:
323325
self._add_attribute(column.name, column)
326+
if _dataclass:
327+
if column.type.python_type.__module__ != 'builtins':
328+
collector.add_literal_import(column.type.python_type.__module__, column.type.python_type.__name__)
329+
324330

325331
# Add many-to-one relationships
326332
pk_column_names = set(col.name for col in table.primary_key.columns)
@@ -367,7 +373,13 @@ def add_imports(self, collector):
367373
child.add_imports(collector)
368374

369375
def render(self):
376+
global _dataclass
377+
370378
text = 'class {0}({1}):\n'.format(self.name, self.parent_name)
379+
380+
if _dataclass:
381+
text = '@dataclass\n' + text
382+
371383
text += ' __tablename__ = {0!r}\n'.format(self.table.name)
372384

373385
# Render constraints and indexes as __table_args__
@@ -402,6 +414,9 @@ def render(self):
402414
for attr, column in self.attributes.items():
403415
if isinstance(column, Column):
404416
show_name = attr != column.name
417+
if _dataclass:
418+
text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n'
419+
405420
text += ' {0} = {1}\n'.format(attr, _render_column(column, show_name))
406421

407422
# Render relationships
@@ -534,7 +549,7 @@ class CodeGenerator(object):
534549

535550
def __init__(self, metadata, noindexes=False, noconstraints=False,
536551
nojoined=False, noinflect=False, nobackrefs=False,
537-
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False):
552+
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False):
538553
super(CodeGenerator, self).__init__()
539554

540555
if noinflect:
@@ -552,6 +567,11 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
552567
_flask_prepend = ''
553568

554569
self.nocomments = nocomments
570+
571+
self.dataclass = dataclass
572+
if self.dataclass:
573+
global _dataclass
574+
_dataclass = True
555575

556576
# Pick association tables from the metadata into their own set, don't process them normally
557577
links = defaultdict(lambda: [])
@@ -610,13 +630,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
610630

611631
# Only generate classes when notables is set to True
612632
if notables:
613-
model = ModelClass(table, links[table.name], inflect_engine, not nojoined)
633+
model = ModelClass(table, links[table.name], inflect_engine, not nojoined, self.collector)
614634
classes[model.name] = model
615635
elif not table.primary_key or table.name in association_tables or noclasses:
616636
# Only form model classes for tables that have a primary key and are not association tables
617637
model = ModelTable(table)
618638
elif not noclasses:
619-
model = ModelClass(table, links[table.name], inflect_engine, not nojoined)
639+
model = ModelClass(table, links[table.name], inflect_engine, not nojoined, self.collector)
620640
classes[model.name] = model
621641

622642
self.models.append(model)
@@ -652,8 +672,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
652672
else:
653673
self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')
654674
self.collector.add_literal_import('sqlalchemy', 'MetaData')
675+
676+
677+
if self.dataclass:
678+
self.collector.add_literal_import('dataclasses', 'dataclass')
655679

656680
def render(self, outfile=sys.stdout):
681+
657682
print(self.header, file=outfile)
658683

659684
# Render the collected imports

sqlacodegen/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def main():
3939
parser.add_argument('--flask', action='store_true', help="use Flask-SQLAlchemy columns")
4040
parser.add_argument('--ignore-cols', help="Don't check foreign key constraints on specified columns (comma-separated)")
4141
parser.add_argument('--nocomments', action='store_true', help="don't render column comments")
42+
parser.add_argument('--dataclass', action='store_true', help="add dataclass decorators for JSON serialization")
4243
args = parser.parse_args()
4344

4445
if args.version:
@@ -58,7 +59,7 @@ def main():
5859
outfile = codecs.open(args.outfile, 'w', encoding='utf-8') if args.outfile else sys.stdout
5960
generator = CodeGenerator(metadata, args.noindexes, args.noconstraints,
6061
args.nojoined, args.noinflect, args.nobackrefs,
61-
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables)
62+
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass)
6263
generator.render(outfile)
6364

6465

0 commit comments

Comments
 (0)