3131
3232_flask_prepend = 'db.'
3333
34+ _dataclass = False
35+
3436
3537class _DummyInflectEngine (object ):
3638 def singular_noun (self , noun ):
@@ -312,7 +314,7 @@ def render(self):
312314class ModelClass (Model ):
313315 parent_name = 'Base'
314316
315- def __init__ (self , table , association_tables , inflect_engine , detect_joined ):
317+ def __init__ (self , table , association_tables , inflect_engine , detect_joined , collector ):
316318 super (ModelClass , self ).__init__ (table )
317319 self .name = self ._tablename_to_classname (table .name , inflect_engine )
318320 self .children = []
@@ -321,6 +323,10 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined):
321323 # Assign attribute names for columns
322324 for column in table .columns :
323325 self ._add_attribute (column .name , column )
326+ if _dataclass :
327+ if column .type .python_type .__module__ != 'builtins' :
328+ collector .add_literal_import (column .type .python_type .__module__ , column .type .python_type .__name__ )
329+
324330
325331 # Add many-to-one relationships
326332 pk_column_names = set (col .name for col in table .primary_key .columns )
@@ -367,7 +373,13 @@ def add_imports(self, collector):
367373 child .add_imports (collector )
368374
369375 def render (self ):
376+ global _dataclass
377+
370378 text = 'class {0}({1}):\n ' .format (self .name , self .parent_name )
379+
380+ if _dataclass :
381+ text = '@dataclass\n ' + text
382+
371383 text += ' __tablename__ = {0!r}\n ' .format (self .table .name )
372384
373385 # Render constraints and indexes as __table_args__
@@ -402,6 +414,9 @@ def render(self):
402414 for attr , column in self .attributes .items ():
403415 if isinstance (column , Column ):
404416 show_name = attr != column .name
417+ if _dataclass :
418+ text += ' ' + attr + ' : ' + column .type .python_type .__name__ + '\n '
419+
405420 text += ' {0} = {1}\n ' .format (attr , _render_column (column , show_name ))
406421
407422 # Render relationships
@@ -534,7 +549,7 @@ class CodeGenerator(object):
534549
535550 def __init__ (self , metadata , noindexes = False , noconstraints = False ,
536551 nojoined = False , noinflect = False , nobackrefs = False ,
537- flask = False , ignore_cols = None , noclasses = False , nocomments = False , notables = False ):
552+ flask = False , ignore_cols = None , noclasses = False , nocomments = False , notables = False , dataclass = False ):
538553 super (CodeGenerator , self ).__init__ ()
539554
540555 if noinflect :
@@ -552,6 +567,11 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
552567 _flask_prepend = ''
553568
554569 self .nocomments = nocomments
570+
571+ self .dataclass = dataclass
572+ if self .dataclass :
573+ global _dataclass
574+ _dataclass = True
555575
556576 # Pick association tables from the metadata into their own set, don't process them normally
557577 links = defaultdict (lambda : [])
@@ -610,13 +630,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
610630
611631 # Only generate classes when notables is set to True
612632 if notables :
613- model = ModelClass (table , links [table .name ], inflect_engine , not nojoined )
633+ model = ModelClass (table , links [table .name ], inflect_engine , not nojoined , self . collector )
614634 classes [model .name ] = model
615635 elif not table .primary_key or table .name in association_tables or noclasses :
616636 # Only form model classes for tables that have a primary key and are not association tables
617637 model = ModelTable (table )
618638 elif not noclasses :
619- model = ModelClass (table , links [table .name ], inflect_engine , not nojoined )
639+ model = ModelClass (table , links [table .name ], inflect_engine , not nojoined , self . collector )
620640 classes [model .name ] = model
621641
622642 self .models .append (model )
@@ -652,8 +672,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
652672 else :
653673 self .collector .add_literal_import ('sqlalchemy.ext.declarative' , 'declarative_base' )
654674 self .collector .add_literal_import ('sqlalchemy' , 'MetaData' )
675+
676+
677+ if self .dataclass :
678+ self .collector .add_literal_import ('dataclasses' , 'dataclass' )
655679
656680 def render (self , outfile = sys .stdout ):
681+
657682 print (self .header , file = outfile )
658683
659684 # Render the collected imports
0 commit comments