Skip to content

Commit 55b7f54

Browse files
committed
Add option to specify default postgres db name
Not all postgres hosted services use `postgres` Also rename `database` config parameter to `vss_database`
1 parent d127bd1 commit 55b7f54

File tree

4 files changed

+68
-39
lines changed

4 files changed

+68
-39
lines changed

rust/impls/src/postgres_store.rs

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464
/// A postgres backend with TLS connections to the database
6565
pub type PostgresTlsBackend = PostgresBackend<MakeTlsConnector>;
6666

67-
async fn make_postgres_db_connection<T>(postgres_endpoint: &str, tls: T) -> Result<Client, Error>
67+
async fn make_db_connection<T>(
68+
postgres_endpoint: &str, db_name: &str, tls: T,
69+
) -> Result<Client, Error>
6870
where
6971
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
7072
T::Stream: Send + Sync,
7173
T::TlsConnect: Send,
7274
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
7375
{
74-
let dsn = format!("{}/{}", postgres_endpoint, "postgres");
76+
let dsn = format!("{}/{}", postgres_endpoint, db_name);
7577
let (client, connection) = tokio_postgres::connect(&dsn, tls)
7678
.await
7779
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
@@ -84,16 +86,16 @@ where
8486
Ok(client)
8587
}
8688

87-
async fn initialize_vss_database<T>(
88-
postgres_endpoint: &str, db_name: &str, tls: T,
89+
async fn create_database<T>(
90+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
8991
) -> Result<(), Error>
9092
where
9193
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
9294
T::Stream: Send + Sync,
9395
T::TlsConnect: Send,
9496
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
9597
{
96-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
98+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
9799

98100
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
99101
Error::new(
@@ -113,14 +115,16 @@ where
113115
}
114116

115117
#[cfg(test)]
116-
async fn drop_database<T>(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<(), Error>
118+
async fn drop_database<T>(
119+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
120+
) -> Result<(), Error>
117121
where
118122
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
119123
T::Stream: Send + Sync,
120124
T::TlsConnect: Send,
121125
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
122126
{
123-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
127+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
124128

125129
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
126130
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
@@ -133,15 +137,18 @@ where
133137

134138
impl PostgresPlaintextBackend {
135139
/// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136-
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
137-
PostgresBackend::new_internal(postgres_endpoint, db_name, NoTls).await
140+
pub async fn new(
141+
postgres_endpoint: &str, default_db: &str, vss_db: &str,
142+
) -> Result<Self, Error> {
143+
PostgresBackend::new_internal(postgres_endpoint, default_db, vss_db, NoTls).await
138144
}
139145
}
140146

141147
impl PostgresTlsBackend {
142148
/// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149
pub async fn new(
144-
postgres_endpoint: &str, db_name: &str, additional_certificate: Option<Certificate>,
150+
postgres_endpoint: &str, default_db: &str, vss_db: &str,
151+
additional_certificate: Option<Certificate>,
145152
) -> Result<Self, Error> {
146153
let mut builder = TlsConnector::builder();
147154
if let Some(cert) = additional_certificate {
@@ -150,8 +157,13 @@ impl PostgresTlsBackend {
150157
let connector = builder.build().map_err(|e| {
151158
Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e))
152159
})?;
153-
PostgresBackend::new_internal(postgres_endpoint, db_name, MakeTlsConnector::new(connector))
154-
.await
160+
PostgresBackend::new_internal(
161+
postgres_endpoint,
162+
default_db,
163+
vss_db,
164+
MakeTlsConnector::new(connector),
165+
)
166+
.await
155167
}
156168
}
157169

@@ -162,9 +174,11 @@ where
162174
T::TlsConnect: Send,
163175
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
164176
{
165-
async fn new_internal(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<Self, Error> {
166-
initialize_vss_database(postgres_endpoint, db_name, tls.clone()).await?;
167-
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
177+
async fn new_internal(
178+
postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T,
179+
) -> Result<Self, Error> {
180+
create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?;
181+
let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db);
168182
let manager =
169183
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
170184
Error::new(
@@ -649,24 +663,27 @@ mod tests {
649663
use tokio_postgres::NoTls;
650664

651665
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
666+
const DEFAULT_DB: &str = "postgres";
652667
const MIGRATIONS_START: usize = 0;
653668
const MIGRATIONS_END: usize = MIGRATIONS.len();
654669

655670
static START: OnceCell<()> = OnceCell::const_new();
656671

657672
define_kv_store_tests!(PostgresKvStoreTest, PostgresPlaintextBackend, {
658-
let db_name = "postgres_kv_store_tests";
673+
let vss_db = "postgres_kv_store_tests";
659674
START
660675
.get_or_init(|| async {
661-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
662-
let store =
663-
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
676+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
677+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db)
678+
.await
679+
.unwrap();
664680
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
665681
assert_eq!(start, MIGRATIONS_START);
666682
assert_eq!(end, MIGRATIONS_END);
667683
})
668684
.await;
669-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
685+
let store =
686+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
670687
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
671688
assert_eq!(start, MIGRATIONS_END);
672689
assert_eq!(end, MIGRATIONS_END);
@@ -678,36 +695,40 @@ mod tests {
678695
#[tokio::test]
679696
#[should_panic(expected = "We do not allow downgrades")]
680697
async fn panic_on_downgrade() {
681-
let db_name = "panic_on_downgrade_test";
682-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
698+
let vss_db = "panic_on_downgrade_test";
699+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
683700
{
684701
let mut migrations = MIGRATIONS.to_vec();
685702
migrations.push(DUMMY_MIGRATION);
686-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
703+
let store =
704+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
687705
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
688706
assert_eq!(start, MIGRATIONS_START);
689707
assert_eq!(end, MIGRATIONS_END + 1);
690708
};
691709
{
692-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
710+
let store =
711+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
693712
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
694713
};
695714
}
696715

697716
#[tokio::test]
698717
async fn new_migrations_increments_upgrades() {
699-
let db_name = "new_migrations_increments_upgrades_test";
700-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
718+
let vss_db = "new_migrations_increments_upgrades_test";
719+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
701720
{
702-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
721+
let store =
722+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
703723
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
704724
assert_eq!(start, MIGRATIONS_START);
705725
assert_eq!(end, MIGRATIONS_END);
706726
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
707727
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
708728
};
709729
{
710-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
730+
let store =
731+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
711732
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
712733
assert_eq!(start, MIGRATIONS_END);
713734
assert_eq!(end, MIGRATIONS_END);
@@ -718,7 +739,8 @@ mod tests {
718739
let mut migrations = MIGRATIONS.to_vec();
719740
migrations.push(DUMMY_MIGRATION);
720741
{
721-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
742+
let store =
743+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
722744
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
723745
assert_eq!(start, MIGRATIONS_END);
724746
assert_eq!(end, MIGRATIONS_END + 1);
@@ -729,7 +751,8 @@ mod tests {
729751
migrations.push(DUMMY_MIGRATION);
730752
migrations.push(DUMMY_MIGRATION);
731753
{
732-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
754+
let store =
755+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
733756
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
734757
assert_eq!(start, MIGRATIONS_END + 1);
735758
assert_eq!(end, MIGRATIONS_END + 3);
@@ -741,13 +764,14 @@ mod tests {
741764
};
742765

743766
{
744-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
767+
let store =
768+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
745769
let list = store.get_upgrades_list().await;
746770
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
747771
let version = store.get_schema_version().await;
748772
assert_eq!(version, MIGRATIONS_END + 3);
749773
}
750774

751-
drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await.unwrap();
775+
drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap();
752776
}
753777
}

rust/server/src/main.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ fn main() {
9090
};
9191

9292
let endpoint = postgresql_config.to_postgresql_endpoint();
93-
let db_name = postgresql_config.database;
93+
let default_db = postgresql_config.default_database;
94+
let vss_db = postgresql_config.vss_database;
9495
let store: Arc<dyn KvStore> = if let Some(tls_config) = postgresql_config.tls {
95-
let additional_certificate = tls_config.ca_file.map(|file| {
96+
let addl_certificate = tls_config.ca_file.map(|file| {
9697
let certificate = match std::fs::read(&file) {
9798
Ok(cert) => cert,
9899
Err(e) => {
@@ -109,7 +110,9 @@ fn main() {
109110
}
110111
});
111112
let postgres_tls_backend =
112-
match PostgresTlsBackend::new(&endpoint, &db_name, additional_certificate).await {
113+
match PostgresTlsBackend::new(&endpoint, &default_db, &vss_db, addl_certificate)
114+
.await
115+
{
113116
Ok(backend) => backend,
114117
Err(e) => {
115118
println!("Failed to start postgres tls backend: {}", e);
@@ -119,7 +122,7 @@ fn main() {
119122
Arc::new(postgres_tls_backend)
120123
} else {
121124
let postgres_plaintext_backend =
122-
match PostgresPlaintextBackend::new(&endpoint, &db_name).await {
125+
match PostgresPlaintextBackend::new(&endpoint, &default_db, &vss_db).await {
123126
Ok(backend) => backend,
124127
Err(e) => {
125128
println!("Failed to start postgres plaintext backend: {}", e);
@@ -128,7 +131,7 @@ fn main() {
128131
};
129132
Arc::new(postgres_plaintext_backend)
130133
};
131-
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
134+
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, vss_db);
132135

133136
let rest_svc_listener =
134137
TcpListener::bind(&addr).await.expect("Failed to bind listening port");

rust/server/src/util/config.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ pub(crate) struct PostgreSQLConfig {
1919
pub(crate) password: Option<String>, // Optional in TOML, can be overridden by env
2020
pub(crate) host: String,
2121
pub(crate) port: u16,
22-
pub(crate) database: String,
22+
pub(crate) default_database: String,
23+
pub(crate) vss_database: String,
2324
pub(crate) tls: Option<TlsConfig>,
2425
}
2526

rust/server/vss-server-config.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POS
88
password = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_PASSWORD`
99
host = "localhost"
1010
port = 5432
11-
database = "postgres"
11+
default_database = "postgres"
12+
vss_database = "vss"
1213
# tls = { } # Uncomment to make TLS connections to the postgres database using your machine's PKI
1314
# tls = { ca_file = "ca.pem" } # Uncomment to make TLS connections to the postgres database with an additional root certificate

0 commit comments

Comments
 (0)