@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464/// A postgres backend with TLS connections to the database
6565pub type PostgresTlsBackend = PostgresBackend < MakeTlsConnector > ;
6666
67- async fn make_postgres_db_connection < T > ( postgres_endpoint : & str , tls : T ) -> Result < Client , Error >
67+ async fn make_db_connection < T > (
68+ postgres_endpoint : & str , db_name : & str , tls : T ,
69+ ) -> Result < Client , Error >
6870where
6971 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
7072 T :: Stream : Send + Sync ,
7173 T :: TlsConnect : Send ,
7274 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
7375{
74- let dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
76+ let dsn = format ! ( "{}/{}" , postgres_endpoint, db_name ) ;
7577 let ( client, connection) = tokio_postgres:: connect ( & dsn, tls)
7678 . await
7779 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
@@ -84,16 +86,16 @@ where
8486 Ok ( client)
8587}
8688
87- async fn initialize_vss_database < T > (
88- postgres_endpoint : & str , db_name : & str , tls : T ,
89+ async fn create_database < T > (
90+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
8991) -> Result < ( ) , Error >
9092where
9193 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
9294 T :: Stream : Send + Sync ,
9395 T :: TlsConnect : Send ,
9496 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
9597{
96- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
98+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
9799
98100 let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
99101 Error :: new (
@@ -113,14 +115,16 @@ where
113115}
114116
115117#[ cfg( test) ]
116- async fn drop_database < T > ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < ( ) , Error >
118+ async fn drop_database < T > (
119+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
120+ ) -> Result < ( ) , Error >
117121where
118122 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
119123 T :: Stream : Send + Sync ,
120124 T :: TlsConnect : Send ,
121125 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
122126{
123- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
127+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
124128
125129 let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
126130 let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
@@ -133,15 +137,18 @@ where
133137
134138impl PostgresPlaintextBackend {
135139 /// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136- pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
137- PostgresBackend :: new_internal ( postgres_endpoint, db_name, NoTls ) . await
140+ pub async fn new (
141+ postgres_endpoint : & str , default_db : & str , vss_db : & str ,
142+ ) -> Result < Self , Error > {
143+ PostgresBackend :: new_internal ( postgres_endpoint, default_db, vss_db, NoTls ) . await
138144 }
139145}
140146
141147impl PostgresTlsBackend {
142148 /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149 pub async fn new (
144- postgres_endpoint : & str , db_name : & str , additional_certificate : Option < Certificate > ,
150+ postgres_endpoint : & str , default_db : & str , vss_db : & str ,
151+ additional_certificate : Option < Certificate > ,
145152 ) -> Result < Self , Error > {
146153 let mut builder = TlsConnector :: builder ( ) ;
147154 if let Some ( cert) = additional_certificate {
@@ -150,8 +157,13 @@ impl PostgresTlsBackend {
150157 let connector = builder. build ( ) . map_err ( |e| {
151158 Error :: new ( ErrorKind :: Other , format ! ( "Error building tls connector: {}" , e) )
152159 } ) ?;
153- PostgresBackend :: new_internal ( postgres_endpoint, db_name, MakeTlsConnector :: new ( connector) )
154- . await
160+ PostgresBackend :: new_internal (
161+ postgres_endpoint,
162+ default_db,
163+ vss_db,
164+ MakeTlsConnector :: new ( connector) ,
165+ )
166+ . await
155167 }
156168}
157169
@@ -162,9 +174,11 @@ where
162174 T :: TlsConnect : Send ,
163175 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
164176{
165- async fn new_internal ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < Self , Error > {
166- initialize_vss_database ( postgres_endpoint, db_name, tls. clone ( ) ) . await ?;
167- let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
177+ async fn new_internal (
178+ postgres_endpoint : & str , default_db : & str , vss_db : & str , tls : T ,
179+ ) -> Result < Self , Error > {
180+ create_database ( postgres_endpoint, default_db, vss_db, tls. clone ( ) ) . await ?;
181+ let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, vss_db) ;
168182 let manager =
169183 PostgresConnectionManager :: new_from_stringlike ( vss_dsn, tls) . map_err ( |e| {
170184 Error :: new (
@@ -649,24 +663,27 @@ mod tests {
649663 use tokio_postgres:: NoTls ;
650664
651665 const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
666+ const DEFAULT_DB : & str = "postgres" ;
652667 const MIGRATIONS_START : usize = 0 ;
653668 const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
654669
655670 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
656671
657672 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresPlaintextBackend , {
658- let db_name = "postgres_kv_store_tests" ;
673+ let vss_db = "postgres_kv_store_tests" ;
659674 START
660675 . get_or_init( || async {
661- let _ = drop_database( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
662- let store =
663- PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
676+ let _ = drop_database( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db, NoTls ) . await ;
677+ let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db)
678+ . await
679+ . unwrap( ) ;
664680 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
665681 assert_eq!( start, MIGRATIONS_START ) ;
666682 assert_eq!( end, MIGRATIONS_END ) ;
667683 } )
668684 . await ;
669- let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
685+ let store =
686+ PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap( ) ;
670687 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
671688 assert_eq!( start, MIGRATIONS_END ) ;
672689 assert_eq!( end, MIGRATIONS_END ) ;
@@ -678,36 +695,40 @@ mod tests {
678695 #[ tokio:: test]
679696 #[ should_panic( expected = "We do not allow downgrades" ) ]
680697 async fn panic_on_downgrade ( ) {
681- let db_name = "panic_on_downgrade_test" ;
682- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
698+ let vss_db = "panic_on_downgrade_test" ;
699+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
683700 {
684701 let mut migrations = MIGRATIONS . to_vec ( ) ;
685702 migrations. push ( DUMMY_MIGRATION ) ;
686- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
703+ let store =
704+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
687705 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
688706 assert_eq ! ( start, MIGRATIONS_START ) ;
689707 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
690708 } ;
691709 {
692- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
710+ let store =
711+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
693712 let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
694713 } ;
695714 }
696715
697716 #[ tokio:: test]
698717 async fn new_migrations_increments_upgrades ( ) {
699- let db_name = "new_migrations_increments_upgrades_test" ;
700- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
718+ let vss_db = "new_migrations_increments_upgrades_test" ;
719+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
701720 {
702- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
721+ let store =
722+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
703723 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
704724 assert_eq ! ( start, MIGRATIONS_START ) ;
705725 assert_eq ! ( end, MIGRATIONS_END ) ;
706726 assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
707727 assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
708728 } ;
709729 {
710- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
730+ let store =
731+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
711732 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
712733 assert_eq ! ( start, MIGRATIONS_END ) ;
713734 assert_eq ! ( end, MIGRATIONS_END ) ;
@@ -718,7 +739,8 @@ mod tests {
718739 let mut migrations = MIGRATIONS . to_vec ( ) ;
719740 migrations. push ( DUMMY_MIGRATION ) ;
720741 {
721- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
742+ let store =
743+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
722744 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
723745 assert_eq ! ( start, MIGRATIONS_END ) ;
724746 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
@@ -729,7 +751,8 @@ mod tests {
729751 migrations. push ( DUMMY_MIGRATION ) ;
730752 migrations. push ( DUMMY_MIGRATION ) ;
731753 {
732- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
754+ let store =
755+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
733756 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
734757 assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
735758 assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
@@ -741,13 +764,14 @@ mod tests {
741764 } ;
742765
743766 {
744- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
767+ let store =
768+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
745769 let list = store. get_upgrades_list ( ) . await ;
746770 assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
747771 let version = store. get_schema_version ( ) . await ;
748772 assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
749773 }
750774
751- drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await . unwrap ( ) ;
775+ drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await . unwrap ( ) ;
752776 }
753777}
0 commit comments