@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464/// A postgres backend with TLS connections to the database
6565pub type PostgresTlsBackend = PostgresBackend < MakeTlsConnector > ;
6666
67- async fn make_postgres_db_connection < T > ( postgres_endpoint : & str , tls : T ) -> Result < Client , Error >
67+ async fn make_db_connection < T > (
68+ postgres_endpoint : & str , db_name : & str , tls : T ,
69+ ) -> Result < Client , Error >
6870where
6971 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
7072 T :: Stream : Send + Sync ,
7173 T :: TlsConnect : Send ,
7274 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
7375{
74- let dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
76+ let dsn = format ! ( "{}/{}" , postgres_endpoint, db_name ) ;
7577 let ( client, connection) = tokio_postgres:: connect ( & dsn, tls)
7678 . await
7779 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
@@ -84,16 +86,16 @@ where
8486 Ok ( client)
8587}
8688
87- async fn initialize_vss_database < T > (
88- postgres_endpoint : & str , db_name : & str , tls : T ,
89+ async fn create_database < T > (
90+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
8991) -> Result < ( ) , Error >
9092where
9193 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
9294 T :: Stream : Send + Sync ,
9395 T :: TlsConnect : Send ,
9496 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
9597{
96- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
98+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
9799
98100 let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
99101 Error :: new (
@@ -113,14 +115,16 @@ where
113115}
114116
115117#[ cfg( test) ]
116- async fn drop_database < T > ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < ( ) , Error >
118+ async fn drop_database < T > (
119+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
120+ ) -> Result < ( ) , Error >
117121where
118122 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
119123 T :: Stream : Send + Sync ,
120124 T :: TlsConnect : Send ,
121125 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
122126{
123- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
127+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
124128
125129 let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
126130 let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
@@ -133,15 +137,17 @@ where
133137
134138impl PostgresPlaintextBackend {
135139 /// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136- pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
137- PostgresBackend :: new_internal ( postgres_endpoint, db_name, NoTls ) . await
140+ pub async fn new (
141+ postgres_endpoint : & str , default_db : & str , vss_db : & str ,
142+ ) -> Result < Self , Error > {
143+ PostgresBackend :: new_internal ( postgres_endpoint, default_db, vss_db, NoTls ) . await
138144 }
139145}
140146
141147impl PostgresTlsBackend {
142148 /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149 pub async fn new (
144- postgres_endpoint : & str , db_name : & str , crt_pem : Option < & str > ,
150+ postgres_endpoint : & str , default_db : & str , vss_db : & str , crt_pem : Option < & str > ,
145151 ) -> Result < Self , Error > {
146152 let mut builder = TlsConnector :: builder ( ) ;
147153 if let Some ( pem) = crt_pem {
@@ -156,8 +162,13 @@ impl PostgresTlsBackend {
156162 let connector = builder. build ( ) . map_err ( |e| {
157163 Error :: new ( ErrorKind :: Other , format ! ( "Error building tls connector: {}" , e) )
158164 } ) ?;
159- PostgresBackend :: new_internal ( postgres_endpoint, db_name, MakeTlsConnector :: new ( connector) )
160- . await
165+ PostgresBackend :: new_internal (
166+ postgres_endpoint,
167+ default_db,
168+ vss_db,
169+ MakeTlsConnector :: new ( connector) ,
170+ )
171+ . await
161172 }
162173}
163174
@@ -168,9 +179,11 @@ where
168179 T :: TlsConnect : Send ,
169180 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
170181{
171- async fn new_internal ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < Self , Error > {
172- initialize_vss_database ( postgres_endpoint, db_name, tls. clone ( ) ) . await ?;
173- let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
182+ async fn new_internal (
183+ postgres_endpoint : & str , default_db : & str , vss_db : & str , tls : T ,
184+ ) -> Result < Self , Error > {
185+ create_database ( postgres_endpoint, default_db, vss_db, tls. clone ( ) ) . await ?;
186+ let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, vss_db) ;
174187 let manager =
175188 PostgresConnectionManager :: new_from_stringlike ( vss_dsn, tls) . map_err ( |e| {
176189 Error :: new (
@@ -655,24 +668,27 @@ mod tests {
655668 use tokio_postgres:: NoTls ;
656669
657670 const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
671+ const DEFAULT_DB : & str = "postgres" ;
658672 const MIGRATIONS_START : usize = 0 ;
659673 const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
660674
661675 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
662676
663677 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresPlaintextBackend , {
664- let db_name = "postgres_kv_store_tests" ;
678+ let vss_db = "postgres_kv_store_tests" ;
665679 START
666680 . get_or_init( || async {
667- let _ = drop_database( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
668- let store =
669- PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
681+ let _ = drop_database( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db, NoTls ) . await ;
682+ let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db)
683+ . await
684+ . unwrap( ) ;
670685 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
671686 assert_eq!( start, MIGRATIONS_START ) ;
672687 assert_eq!( end, MIGRATIONS_END ) ;
673688 } )
674689 . await ;
675- let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
690+ let store =
691+ PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap( ) ;
676692 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
677693 assert_eq!( start, MIGRATIONS_END ) ;
678694 assert_eq!( end, MIGRATIONS_END ) ;
@@ -684,36 +700,40 @@ mod tests {
684700 #[ tokio:: test]
685701 #[ should_panic( expected = "We do not allow downgrades" ) ]
686702 async fn panic_on_downgrade ( ) {
687- let db_name = "panic_on_downgrade_test" ;
688- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
703+ let vss_db = "panic_on_downgrade_test" ;
704+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
689705 {
690706 let mut migrations = MIGRATIONS . to_vec ( ) ;
691707 migrations. push ( DUMMY_MIGRATION ) ;
692- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
708+ let store =
709+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
693710 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
694711 assert_eq ! ( start, MIGRATIONS_START ) ;
695712 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
696713 } ;
697714 {
698- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
715+ let store =
716+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
699717 let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
700718 } ;
701719 }
702720
703721 #[ tokio:: test]
704722 async fn new_migrations_increments_upgrades ( ) {
705- let db_name = "new_migrations_increments_upgrades_test" ;
706- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
723+ let vss_db = "new_migrations_increments_upgrades_test" ;
724+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
707725 {
708- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
726+ let store =
727+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
709728 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
710729 assert_eq ! ( start, MIGRATIONS_START ) ;
711730 assert_eq ! ( end, MIGRATIONS_END ) ;
712731 assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
713732 assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
714733 } ;
715734 {
716- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
735+ let store =
736+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
717737 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
718738 assert_eq ! ( start, MIGRATIONS_END ) ;
719739 assert_eq ! ( end, MIGRATIONS_END ) ;
@@ -724,7 +744,8 @@ mod tests {
724744 let mut migrations = MIGRATIONS . to_vec ( ) ;
725745 migrations. push ( DUMMY_MIGRATION ) ;
726746 {
727- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
747+ let store =
748+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
728749 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
729750 assert_eq ! ( start, MIGRATIONS_END ) ;
730751 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
@@ -735,7 +756,8 @@ mod tests {
735756 migrations. push ( DUMMY_MIGRATION ) ;
736757 migrations. push ( DUMMY_MIGRATION ) ;
737758 {
738- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
759+ let store =
760+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
739761 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
740762 assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
741763 assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
@@ -747,13 +769,14 @@ mod tests {
747769 } ;
748770
749771 {
750- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
772+ let store =
773+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
751774 let list = store. get_upgrades_list ( ) . await ;
752775 assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
753776 let version = store. get_schema_version ( ) . await ;
754777 assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
755778 }
756779
757- drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await . unwrap ( ) ;
780+ drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await . unwrap ( ) ;
758781 }
759782}
0 commit comments