Skip to content

Commit 41a5a8a

Browse files
committed
Add option to specify default postgres db name
Not all postgres hosted services use `postgres` Also rename `database` config parameter to `vss_database`
1 parent 10f667c commit 41a5a8a

File tree

4 files changed

+77
-47
lines changed

4 files changed

+77
-47
lines changed

rust/impls/src/postgres_store.rs

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464
/// A postgres backend with TLS connections to the database
6565
pub type PostgresTlsBackend = PostgresBackend<MakeTlsConnector>;
6666

67-
async fn make_postgres_db_connection<T>(postgres_endpoint: &str, tls: T) -> Result<Client, Error>
67+
async fn make_db_connection<T>(
68+
postgres_endpoint: &str, db_name: &str, tls: T,
69+
) -> Result<Client, Error>
6870
where
6971
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
7072
T::Stream: Send + Sync,
7173
T::TlsConnect: Send,
7274
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
7375
{
74-
let dsn = format!("{}/{}", postgres_endpoint, "postgres");
76+
let dsn = format!("{}/{}", postgres_endpoint, db_name);
7577
let (client, connection) = tokio_postgres::connect(&dsn, tls)
7678
.await
7779
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
@@ -84,16 +86,16 @@ where
8486
Ok(client)
8587
}
8688

87-
async fn initialize_vss_database<T>(
88-
postgres_endpoint: &str, db_name: &str, tls: T,
89+
async fn create_database<T>(
90+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
8991
) -> Result<(), Error>
9092
where
9193
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
9294
T::Stream: Send + Sync,
9395
T::TlsConnect: Send,
9496
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
9597
{
96-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
98+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
9799

98100
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
99101
Error::new(
@@ -113,14 +115,16 @@ where
113115
}
114116

115117
#[cfg(test)]
116-
async fn drop_database<T>(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<(), Error>
118+
async fn drop_database<T>(
119+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
120+
) -> Result<(), Error>
117121
where
118122
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
119123
T::Stream: Send + Sync,
120124
T::TlsConnect: Send,
121125
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
122126
{
123-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
127+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
124128

125129
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
126130
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
@@ -133,15 +137,17 @@ where
133137

134138
impl PostgresPlaintextBackend {
135139
/// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136-
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
137-
PostgresBackend::new_internal(postgres_endpoint, db_name, NoTls).await
140+
pub async fn new(
141+
postgres_endpoint: &str, default_db: &str, vss_db: &str,
142+
) -> Result<Self, Error> {
143+
PostgresBackend::new_internal(postgres_endpoint, default_db, vss_db, NoTls).await
138144
}
139145
}
140146

141147
impl PostgresTlsBackend {
142148
/// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149
pub async fn new(
144-
postgres_endpoint: &str, db_name: &str, crt_pem: Option<&str>,
150+
postgres_endpoint: &str, default_db: &str, vss_db: &str, crt_pem: Option<&str>,
145151
) -> Result<Self, Error> {
146152
let mut builder = TlsConnector::builder();
147153
if let Some(pem) = crt_pem {
@@ -156,8 +162,13 @@ impl PostgresTlsBackend {
156162
let connector = builder.build().map_err(|e| {
157163
Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e))
158164
})?;
159-
PostgresBackend::new_internal(postgres_endpoint, db_name, MakeTlsConnector::new(connector))
160-
.await
165+
PostgresBackend::new_internal(
166+
postgres_endpoint,
167+
default_db,
168+
vss_db,
169+
MakeTlsConnector::new(connector),
170+
)
171+
.await
161172
}
162173
}
163174

@@ -168,9 +179,11 @@ where
168179
T::TlsConnect: Send,
169180
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
170181
{
171-
async fn new_internal(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<Self, Error> {
172-
initialize_vss_database(postgres_endpoint, db_name, tls.clone()).await?;
173-
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
182+
async fn new_internal(
183+
postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T,
184+
) -> Result<Self, Error> {
185+
create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?;
186+
let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db);
174187
let manager =
175188
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
176189
Error::new(
@@ -655,24 +668,27 @@ mod tests {
655668
use tokio_postgres::NoTls;
656669

657670
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
671+
const DEFAULT_DB: &str = "postgres";
658672
const MIGRATIONS_START: usize = 0;
659673
const MIGRATIONS_END: usize = MIGRATIONS.len();
660674

661675
static START: OnceCell<()> = OnceCell::const_new();
662676

663677
define_kv_store_tests!(PostgresKvStoreTest, PostgresPlaintextBackend, {
664-
let db_name = "postgres_kv_store_tests";
678+
let vss_db = "postgres_kv_store_tests";
665679
START
666680
.get_or_init(|| async {
667-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
668-
let store =
669-
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
681+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
682+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db)
683+
.await
684+
.unwrap();
670685
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
671686
assert_eq!(start, MIGRATIONS_START);
672687
assert_eq!(end, MIGRATIONS_END);
673688
})
674689
.await;
675-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
690+
let store =
691+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
676692
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
677693
assert_eq!(start, MIGRATIONS_END);
678694
assert_eq!(end, MIGRATIONS_END);
@@ -684,36 +700,40 @@ mod tests {
684700
#[tokio::test]
685701
#[should_panic(expected = "We do not allow downgrades")]
686702
async fn panic_on_downgrade() {
687-
let db_name = "panic_on_downgrade_test";
688-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
703+
let vss_db = "panic_on_downgrade_test";
704+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
689705
{
690706
let mut migrations = MIGRATIONS.to_vec();
691707
migrations.push(DUMMY_MIGRATION);
692-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
708+
let store =
709+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
693710
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
694711
assert_eq!(start, MIGRATIONS_START);
695712
assert_eq!(end, MIGRATIONS_END + 1);
696713
};
697714
{
698-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
715+
let store =
716+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
699717
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
700718
};
701719
}
702720

703721
#[tokio::test]
704722
async fn new_migrations_increments_upgrades() {
705-
let db_name = "new_migrations_increments_upgrades_test";
706-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
723+
let vss_db = "new_migrations_increments_upgrades_test";
724+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
707725
{
708-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
726+
let store =
727+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
709728
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
710729
assert_eq!(start, MIGRATIONS_START);
711730
assert_eq!(end, MIGRATIONS_END);
712731
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
713732
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
714733
};
715734
{
716-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
735+
let store =
736+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
717737
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
718738
assert_eq!(start, MIGRATIONS_END);
719739
assert_eq!(end, MIGRATIONS_END);
@@ -724,7 +744,8 @@ mod tests {
724744
let mut migrations = MIGRATIONS.to_vec();
725745
migrations.push(DUMMY_MIGRATION);
726746
{
727-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
747+
let store =
748+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
728749
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
729750
assert_eq!(start, MIGRATIONS_END);
730751
assert_eq!(end, MIGRATIONS_END + 1);
@@ -735,7 +756,8 @@ mod tests {
735756
migrations.push(DUMMY_MIGRATION);
736757
migrations.push(DUMMY_MIGRATION);
737758
{
738-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
759+
let store =
760+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
739761
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
740762
assert_eq!(start, MIGRATIONS_END + 1);
741763
assert_eq!(end, MIGRATIONS_END + 3);
@@ -747,13 +769,14 @@ mod tests {
747769
};
748770

749771
{
750-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
772+
let store =
773+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
751774
let list = store.get_upgrades_list().await;
752775
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
753776
let version = store.get_schema_version().await;
754777
assert_eq!(version, MIGRATIONS_END + 3);
755778
}
756779

757-
drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await.unwrap();
780+
drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap();
758781
}
759782
}

rust/server/src/main.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,27 @@ fn main() {
9595
let postgresql_config =
9696
postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
9797
let endpoint = postgresql_config.to_postgresql_endpoint();
98-
let db_name = postgresql_config.database;
98+
let default_db = postgresql_config.default_database;
99+
let vss_db = postgresql_config.vss_database;
99100
let store: Arc<dyn KvStore> = if let Some(tls_config) = postgresql_config.tls {
100-
let postgres_tls_backend =
101-
match PostgresTlsBackend::new(&endpoint, &db_name, tls_config.crt_pem.as_deref())
102-
.await
103-
{
104-
Ok(backend) => backend,
105-
Err(e) => {
106-
println!("Failed to start postgres tls backend: {}", e);
107-
std::process::exit(-1);
108-
},
109-
};
101+
let postgres_tls_backend = match PostgresTlsBackend::new(
102+
&endpoint,
103+
&default_db,
104+
&vss_db,
105+
tls_config.crt_pem.as_deref(),
106+
)
107+
.await
108+
{
109+
Ok(backend) => backend,
110+
Err(e) => {
111+
println!("Failed to start postgres tls backend: {}", e);
112+
std::process::exit(-1);
113+
},
114+
};
110115
Arc::new(postgres_tls_backend)
111116
} else {
112117
let postgres_plaintext_backend =
113-
match PostgresPlaintextBackend::new(&endpoint, &db_name).await {
118+
match PostgresPlaintextBackend::new(&endpoint, &default_db, &vss_db).await {
114119
Ok(backend) => backend,
115120
Err(e) => {
116121
println!("Failed to start postgres plaintext backend: {}", e);
@@ -119,7 +124,7 @@ fn main() {
119124
};
120125
Arc::new(postgres_plaintext_backend)
121126
};
122-
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
127+
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, vss_db);
123128

124129
let rest_svc_listener =
125130
TcpListener::bind(&addr).await.expect("Failed to bind listening port");

rust/server/src/util/config.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ pub(crate) struct PostgreSQLConfig {
2424
pub(crate) password: Option<String>, // Optional in TOML, can be overridden by env
2525
pub(crate) host: String,
2626
pub(crate) port: u16,
27-
pub(crate) database: String,
27+
pub(crate) default_database: String,
28+
pub(crate) vss_database: String,
2829
pub(crate) tls: Option<TlsConfig>,
2930
}
3031

rust/server/vss-server-config.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POS
1515
password = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_PASSWORD`
1616
host = "localhost"
1717
port = 5432
18-
database = "postgres"
18+
default_database = "postgres"
19+
vss_database = "vss"
1920

2021
# [postgresql_config.tls] # Uncomment to make TLS connections to the postgres database
2122
#

0 commit comments

Comments
 (0)