Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/atlas/internal/cmdapi/cmdapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ const (
flagGitDir = "git-dir"
flagLatest = "latest"
flagLockTimeout = "lock-timeout"
flagLockName = "lock-name"
flagLog = "log"
flagPlan = "plan"
flagRevisionSchema = "revisions-schema"
Expand Down Expand Up @@ -315,6 +316,10 @@ func addFlagLockTimeout(set *pflag.FlagSet, target *time.Duration) {
set.DurationVar(target, flagLockTimeout, 10*time.Second, "set how long to wait for the database lock")
}

func addFlagLockName(set *pflag.FlagSet, target *string, defaultVal string) {
set.StringVar(target, flagLockName, defaultVal, "set lock name for database lock")
}

// addFlagURL adds a URL flag. If given, args[0] override the name, args[1] the shorthand, args[2] the default value.
func addFlagDirURL(set *pflag.FlagSet, target *string, args ...string) {
name, short, val := flagDirURL, "", "file://migrations"
Expand Down
6 changes: 5 additions & 1 deletion cmd/atlas/internal/cmdapi/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type migrateApplyFlags struct {
dryRun bool
logFormat string
lockTimeout time.Duration
lockName string
allowDirty bool // allow working on a database that already has resources
baselineVersion string // apply with this version as baseline
txMode string // (none, file, all)
Expand Down Expand Up @@ -149,6 +150,7 @@ If run with the "--dry-run" flag, atlas will not execute any SQL.`,
addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
addFlagDryRun(cmd.Flags(), &flags.dryRun)
addFlagLockTimeout(cmd.Flags(), &flags.lockTimeout)
addFlagLockName(cmd.Flags(), &flags.lockName, applyLockValue)
cmd.Flags().StringVarP(&flags.baselineVersion, flagBaseline, "", "", "start the first migration after the given baseline version")
cmd.Flags().StringVarP(&flags.txMode, flagTxMode, "", txModeFile, "set transaction mode [none, file, all]")
cmd.Flags().StringVarP(&flags.execOrder, flagExecOrder, "", execOrderLinear, "set file execution order [linear, linear-skip, non-linear]")
Expand Down Expand Up @@ -886,6 +888,7 @@ type migrateSetFlags struct {
url string
dirURL, dirFormat string
revisionSchema string
lockName string
}

// migrateSetCmd represents the 'atlas migrate set' subcommand.
Expand Down Expand Up @@ -919,6 +922,7 @@ to be applied. This command is usually used after manually making changes to the
addFlagDirURL(cmd.Flags(), &flags.dirURL)
addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
addFlagLockName(cmd.Flags(), &flags.lockName, applyLockValue)
return cmd
}

Expand All @@ -934,7 +938,7 @@ func migrateSetRun(cmd *cobra.Command, args []string, flags migrateSetFlags) (re
}
defer client.Close()
// Acquire a lock.
unlock, err := client.Driver.Lock(ctx, applyLockValue, 0)
unlock, err := client.Driver.Lock(ctx, flags.lockName, 0)
if err != nil {
return fmt.Errorf("acquiring database lock: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/atlas/internal/cmdapi/migrate_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func migrateApplyRun(cmd *cobra.Command, args []string, flags migrateApplyFlags,
// Prevent usage printing after input validation.
cmd.SilenceUsage = true
// Acquire a lock.
unlock, err := client.Driver.Lock(ctx, applyLockValue, flags.lockTimeout)
unlock, err := client.Driver.Lock(ctx, flags.lockName, flags.lockTimeout)
if err != nil {
return fmt.Errorf("acquiring database lock: %w", err)
}
Expand Down
27 changes: 21 additions & 6 deletions cmd/atlas/internal/cmdapi/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -137,22 +138,24 @@ func TestMigrate_Apply(t *testing.T) {
))

// A lock will prevent execution.
var lockDriver *sqliteLockerDriver
sqlclient.Register(
"sqlitelockapply",
sqlclient.OpenerFunc(func(ctx context.Context, u *url.URL) (*sqlclient.Client, error) {
client, err := sqlclient.Open(ctx, strings.Replace(u.String(), u.Scheme, "sqlite", 1))
if err != nil {
return nil, err
}
client.Driver = &sqliteLockerDriver{client.Driver}
lockDriver = newTestSqlLockerDriver(client.Driver, errLock)
client.Driver = lockDriver
return client, nil
}),
sqlclient.RegisterDriverOpener(func(db schema.ExecQuerier) (migrate.Driver, error) {
drv, err := sqlite.Open(db)
if err != nil {
return nil, err
}
return &sqliteLockerDriver{drv}, nil
return newTestSqlLockerDriver(drv, errLock), nil
}),
)
f, err := os.Create(filepath.Join(p, "test.db"))
Expand All @@ -163,9 +166,11 @@ func TestMigrate_Apply(t *testing.T) {
migrateApplyCmd(),
"--dir", "file://testdata/sqlite",
"--url", fmt.Sprintf("sqlitelockapply://file:%s?cache=shared&_fk=1", filepath.Join(p, "test.db")),
"--lock-name", "testLock",
)
require.ErrorIs(t, err, errLock)
require.True(t, strings.HasPrefix(s, "Error: acquiring database lock: "+errLock.Error()))
require.True(t, slices.Index(lockDriver.recordedLockNames, "testLock") >= 0)

// Apply zero throws error.
for _, n := range []string{"-1", "0"} {
Expand All @@ -182,6 +187,7 @@ func TestMigrate_Apply(t *testing.T) {
s, err = runCmd(
migrateApplyCmd(),
"--dir", "file://testdata/sqlite",
"--lock-name", "testMigrateLock",
"--url", fmt.Sprintf("sqlite://file:%s?cache=shared&_fk=1", filepath.Join(p, "test.db")),
"1",
)
Expand Down Expand Up @@ -923,7 +929,7 @@ func TestMigrate_Diff(t *testing.T) {
if err != nil {
return nil, err
}
client.Driver = &sqliteLockerDriver{Driver: client.Driver}
client.Driver = newTestSqlLockerDriver(client.Driver, errLock)
return client, nil
}))
f, err := os.Create(filepath.Join(p, "test.db"))
Expand Down Expand Up @@ -1771,12 +1777,21 @@ func copyFile(src, dst string) error {
return err
}

type sqliteLockerDriver struct{ migrate.Driver }
type sqliteLockerDriver struct {
migrate.Driver
lockingError error
recordedLockNames []string
}

func newTestSqlLockerDriver(d migrate.Driver, err error) *sqliteLockerDriver {
return &sqliteLockerDriver{d, err, nil}
}

var errLock = errors.New("lockErr")

func (d *sqliteLockerDriver) Lock(context.Context, string, time.Duration) (schema.UnlockFunc, error) {
return func() error { return nil }, errLock
func (d *sqliteLockerDriver) Lock(ctx context.Context, lockName string, lockWait time.Duration) (schema.UnlockFunc, error) {
d.recordedLockNames = append(d.recordedLockNames, lockName)
return func() error { return nil }, d.lockingError
}

func countFiles(t *testing.T, p string) int {
Expand Down
1 change: 1 addition & 0 deletions cmd/atlas/internal/cmdapi/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type (
Baseline string `spec:"baseline"`
ExecOrder string `spec:"exec_order"`
LockTimeout string `spec:"lock_timeout"`
LockName string `spec:"lock_name"`
RevisionsSchema string `spec:"revisions_schema"`
Repo *Repo `spec:"repo"`
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/atlas/internal/cmdapi/project_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ env "local" {
dir = "file://migrations"
format = atlas
lock_timeout = "1s"
lock_name = "migrate_lock"
revisions_schema = "revisions"
exec_order = LINEAR_SKIP
}
Expand Down Expand Up @@ -158,6 +159,7 @@ env "multi" {
Dir: "file://migrations",
Format: cmdmigrate.FormatAtlas,
LockTimeout: "1s",
LockName: "migrate_lock",
RevisionsSchema: "revisions",
ExecOrder: "LINEAR_SKIP",
},
Expand Down