Skip to content

Commit 0994939

Browse files
committed
duckdb: add config parity with sqlite
1 parent cda0657 commit 0994939

File tree

3 files changed

+95
-13
lines changed

3 files changed

+95
-13
lines changed

database/duckdb/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
| URL Query | Description |
66
|------------|-------------|
77
| `x-migrations-table` | Name of the migrations table (default: `schema_migrations`) |
8+
| `x-no-tx-wrap` | Disable automatic transaction wrapping for migrations (default: `false`) |
89

910
## Notes
1011

database/duckdb/duckdb.go

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ import (
66
"fmt"
77
"io"
88
nurl "net/url"
9+
"strconv"
910
"strings"
10-
1111
"sync/atomic"
1212

1313
"github.com/golang-migrate/migrate/v4"
1414
"github.com/golang-migrate/migrate/v4/database"
15-
15+
1616
_ "github.com/duckdb/duckdb-go/v2"
1717
)
1818

@@ -22,9 +22,19 @@ func init() {
2222

2323
var DefaultMigrationsTable = "schema_migrations"
2424

25+
var (
26+
ErrNilConfig = errors.New("no config")
27+
)
28+
29+
type Config struct {
30+
MigrationsTable string
31+
NoTxWrap bool
32+
}
33+
2534
type DuckDB struct {
2635
db *sql.DB
2736
isLocked atomic.Bool
37+
config *Config
2838
}
2939

3040
func (d *DuckDB) Open(url string) (database.Driver, error) {
@@ -38,16 +48,28 @@ func (d *DuckDB) Open(url string) (database.Driver, error) {
3848
return nil, fmt.Errorf("opening '%s': %w", dbfile, err)
3949
}
4050

41-
if err := db.Ping(); err != nil {
42-
return nil, fmt.Errorf("pinging: %w", err)
51+
qv := purl.Query()
52+
migrationsTable := qv.Get("x-migrations-table")
53+
if len(migrationsTable) == 0 {
54+
migrationsTable = DefaultMigrationsTable
4355
}
44-
d.db = db
4556

46-
if err := d.ensureVersionTable(); err != nil {
47-
return nil, fmt.Errorf("ensuring version table: %w", err)
57+
noTxWrap := false
58+
if v := qv.Get("x-no-tx-wrap"); v != "" {
59+
noTxWrap, err = strconv.ParseBool(v)
60+
if err != nil {
61+
return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
62+
}
4863
}
4964

50-
return d, nil
65+
if err := db.Ping(); err != nil {
66+
return nil, fmt.Errorf("pinging: %w", err)
67+
}
68+
cfg := &Config{
69+
MigrationsTable: migrationsTable,
70+
NoTxWrap: noTxWrap,
71+
}
72+
return WithInstance(db, cfg)
5173
}
5274

5375
func (d *DuckDB) Close() error {
@@ -118,7 +140,7 @@ func (d *DuckDB) SetVersion(version int, dirty bool) error {
118140
return &database.Error{OrigErr: err, Err: "transaction start failed"}
119141
}
120142

121-
query := "DELETE FROM " + DefaultMigrationsTable
143+
query := "DELETE FROM " + d.config.MigrationsTable
122144
if _, err := tx.Exec(query); err != nil {
123145
return &database.Error{OrigErr: err, Query: []byte(query)}
124146
}
@@ -130,7 +152,7 @@ func (d *DuckDB) SetVersion(version int, dirty bool) error {
130152
// NOTE: Copied from sqlite implementation, unsure if this is necessary for
131153
// duckdb
132154
if version >= 0 || (version == database.NilVersion && dirty) {
133-
query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, DefaultMigrationsTable)
155+
query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, d.config.MigrationsTable)
134156
if _, err := tx.Exec(query, version, dirty); err != nil {
135157
if errRollback := tx.Rollback(); errRollback != nil {
136158
err = errors.Join(err, errRollback)
@@ -147,7 +169,7 @@ func (d *DuckDB) SetVersion(version int, dirty bool) error {
147169
}
148170

149171
func (m *DuckDB) Version() (version int, dirty bool, err error) {
150-
query := "SELECT version, dirty FROM " + DefaultMigrationsTable + " LIMIT 1"
172+
query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
151173
err = m.db.QueryRow(query).Scan(&version, &dirty)
152174
if err != nil {
153175
return database.NilVersion, false, nil
@@ -162,6 +184,13 @@ func (d *DuckDB) Run(migration io.Reader) error {
162184
}
163185
query := string(migr[:])
164186

187+
if d.config.NoTxWrap {
188+
if _, err := d.db.Exec(query); err != nil {
189+
return &database.Error{OrigErr: err, Query: []byte(query)}
190+
}
191+
return nil
192+
}
193+
165194
tx, err := d.db.Begin()
166195
if err != nil {
167196
return &database.Error{OrigErr: err, Err: "transaction start failed"}
@@ -196,10 +225,36 @@ func (d *DuckDB) ensureVersionTable() (err error) {
196225
}
197226
}()
198227

199-
query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version BIGINT, dirty BOOLEAN);`, DefaultMigrationsTable)
228+
query := fmt.Sprintf(`
229+
CREATE TABLE IF NOT EXISTS %s (version BIGINT, dirty BOOLEAN);
230+
CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
231+
`, d.config.MigrationsTable, d.config.MigrationsTable)
200232

201233
if _, err := d.db.Exec(query); err != nil {
202234
return fmt.Errorf("creating version table via '%s': %w", query, err)
203235
}
204236
return nil
205237
}
238+
239+
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
240+
if config == nil {
241+
return nil, ErrNilConfig
242+
}
243+
244+
if err := instance.Ping(); err != nil {
245+
return nil, err
246+
}
247+
248+
if len(config.MigrationsTable) == 0 {
249+
config.MigrationsTable = DefaultMigrationsTable
250+
}
251+
252+
mx := &DuckDB{
253+
db: instance,
254+
config: config,
255+
}
256+
if err := mx.ensureVersionTable(); err != nil {
257+
return nil, err
258+
}
259+
return mx, nil
260+
}

database/duckdb/duckdb_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ import (
55
"path/filepath"
66
"testing"
77

8+
_ "github.com/duckdb/duckdb-go/v2"
89
"github.com/golang-migrate/migrate/v4"
910
dt "github.com/golang-migrate/migrate/v4/database/testing"
1011
_ "github.com/golang-migrate/migrate/v4/source/file"
11-
_ "github.com/duckdb/duckdb-go/v2"
1212
)
1313

1414
func Test(t *testing.T) {
@@ -47,3 +47,29 @@ func TestMigrate(t *testing.T) {
4747

4848
dt.TestMigrate(t, m)
4949
}
50+
51+
func TestNoTxWrap(t *testing.T) {
52+
dir := t.TempDir()
53+
dbFile := filepath.Join(dir, "test.duckdb")
54+
addr := fmt.Sprintf("duckdb://%s?x-no-tx-wrap=true", dbFile)
55+
56+
ddb := &DuckDB{}
57+
d, err := ddb.Open(addr)
58+
if err != nil {
59+
t.Fatalf("calling Open() on addr %s: %s", addr, err)
60+
}
61+
62+
dt.Test(t, d, []byte("BEGIN TRANSACTION; CREATE TABLE t (Qty int, Name string); COMMIT;"))
63+
}
64+
65+
func TestNoTxWrapInvalidValue(t *testing.T) {
66+
dir := t.TempDir()
67+
dbFile := filepath.Join(dir, "test.duckdb")
68+
addr := fmt.Sprintf("duckdb://%s?x-no-tx-wrap=definitely", dbFile)
69+
70+
ddb := &DuckDB{}
71+
_, err := ddb.Open(addr)
72+
if err == nil {
73+
t.Fatal("expected error for invalid x-no-tx-wrap value")
74+
}
75+
}

0 commit comments

Comments
 (0)