diff --git a/cmd/atlas/internal/cmdapi/cmdapi.go b/cmd/atlas/internal/cmdapi/cmdapi.go index 3d272236947..747eccabc95 100644 --- a/cmd/atlas/internal/cmdapi/cmdapi.go +++ b/cmd/atlas/internal/cmdapi/cmdapi.go @@ -283,6 +283,7 @@ const ( flagGitDir = "git-dir" flagLatest = "latest" flagLockTimeout = "lock-timeout" + flagLockName = "lock-name" flagLog = "log" flagPlan = "plan" flagRevisionSchema = "revisions-schema" @@ -315,6 +316,10 @@ func addFlagLockTimeout(set *pflag.FlagSet, target *time.Duration) { set.DurationVar(target, flagLockTimeout, 10*time.Second, "set how long to wait for the database lock") } +func addFlagLockName(set *pflag.FlagSet, target *string, defaultVal string) { + set.StringVar(target, flagLockName, defaultVal, "set lock name for database lock") +} + // addFlagURL adds a URL flag. If given, args[0] override the name, args[1] the shorthand, args[2] the default value. func addFlagDirURL(set *pflag.FlagSet, target *string, args ...string) { name, short, val := flagDirURL, "", "file://migrations" diff --git a/cmd/atlas/internal/cmdapi/migrate.go b/cmd/atlas/internal/cmdapi/migrate.go index 9a08b4038d6..1057abee39d 100644 --- a/cmd/atlas/internal/cmdapi/migrate.go +++ b/cmd/atlas/internal/cmdapi/migrate.go @@ -55,6 +55,7 @@ type migrateApplyFlags struct { dryRun bool logFormat string lockTimeout time.Duration + lockName string allowDirty bool // allow working on a database that already has resources baselineVersion string // apply with this version as baseline txMode string // (none, file, all) @@ -149,6 +150,7 @@ If run with the "--dry-run" flag, atlas will not execute any SQL.`, addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema) addFlagDryRun(cmd.Flags(), &flags.dryRun) addFlagLockTimeout(cmd.Flags(), &flags.lockTimeout) + addFlagLockName(cmd.Flags(), &flags.lockName, applyLockValue) cmd.Flags().StringVarP(&flags.baselineVersion, flagBaseline, "", "", "start the first migration after the given baseline version") cmd.Flags().StringVarP(&flags.txMode, flagTxMode, "", txModeFile, "set transaction mode [none, file, all]") cmd.Flags().StringVarP(&flags.execOrder, flagExecOrder, "", execOrderLinear, "set file execution order [linear, linear-skip, non-linear]") @@ -886,6 +888,7 @@ type migrateSetFlags struct { url string dirURL, dirFormat string revisionSchema string + lockName string } // migrateSetCmd represents the 'atlas migrate set' subcommand. @@ -919,6 +922,7 @@ to be applied. This command is usually used after manually making changes to the addFlagDirURL(cmd.Flags(), &flags.dirURL) addFlagDirFormat(cmd.Flags(), &flags.dirFormat) addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema) + addFlagLockName(cmd.Flags(), &flags.lockName, applyLockValue) return cmd } @@ -934,7 +938,7 @@ func migrateSetRun(cmd *cobra.Command, args []string, flags migrateSetFlags) (re } defer client.Close() // Acquire a lock. - unlock, err := client.Driver.Lock(ctx, applyLockValue, 0) + unlock, err := client.Driver.Lock(ctx, flags.lockName, 0) if err != nil { return fmt.Errorf("acquiring database lock: %w", err) } diff --git a/cmd/atlas/internal/cmdapi/migrate_oss.go b/cmd/atlas/internal/cmdapi/migrate_oss.go index e11dbbd11ad..45ffdf30892 100644 --- a/cmd/atlas/internal/cmdapi/migrate_oss.go +++ b/cmd/atlas/internal/cmdapi/migrate_oss.go @@ -60,7 +60,7 @@ func migrateApplyRun(cmd *cobra.Command, args []string, flags migrateApplyFlags, // Prevent usage printing after input validation. cmd.SilenceUsage = true // Acquire a lock. - unlock, err := client.Driver.Lock(ctx, applyLockValue, flags.lockTimeout) + unlock, err := client.Driver.Lock(ctx, flags.lockName, flags.lockTimeout) if err != nil { return fmt.Errorf("acquiring database lock: %w", err) } diff --git a/cmd/atlas/internal/cmdapi/migrate_test.go b/cmd/atlas/internal/cmdapi/migrate_test.go index 7414081bfaa..6aa6b9388f8 100644 --- a/cmd/atlas/internal/cmdapi/migrate_test.go +++ b/cmd/atlas/internal/cmdapi/migrate_test.go @@ -16,6 +16,7 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "strings" "testing" "time" @@ -137,6 +138,7 @@ func TestMigrate_Apply(t *testing.T) { )) // A lock will prevent execution. + var lockDriver *sqliteLockerDriver sqlclient.Register( "sqlitelockapply", sqlclient.OpenerFunc(func(ctx context.Context, u *url.URL) (*sqlclient.Client, error) { @@ -144,7 +146,8 @@ func TestMigrate_Apply(t *testing.T) { if err != nil { return nil, err } - client.Driver = &sqliteLockerDriver{client.Driver} + lockDriver = newTestSqlLockerDriver(client.Driver, errLock) + client.Driver = lockDriver return client, nil }), sqlclient.RegisterDriverOpener(func(db schema.ExecQuerier) (migrate.Driver, error) { @@ -152,7 +155,7 @@ func TestMigrate_Apply(t *testing.T) { if err != nil { return nil, err } - return &sqliteLockerDriver{drv}, nil + return newTestSqlLockerDriver(drv, errLock), nil }), ) f, err := os.Create(filepath.Join(p, "test.db")) @@ -163,9 +166,11 @@ func TestMigrate_Apply(t *testing.T) { migrateApplyCmd(), "--dir", "file://testdata/sqlite", "--url", fmt.Sprintf("sqlitelockapply://file:%s?cache=shared&_fk=1", filepath.Join(p, "test.db")), + "--lock-name", "testLock", ) require.ErrorIs(t, err, errLock) require.True(t, strings.HasPrefix(s, "Error: acquiring database lock: "+errLock.Error())) + require.True(t, slices.Index(lockDriver.recordedLockNames, "testLock") >= 0) // Apply zero throws error. for _, n := range []string{"-1", "0"} { @@ -182,6 +187,7 @@ func TestMigrate_Apply(t *testing.T) { s, err = runCmd( migrateApplyCmd(), "--dir", "file://testdata/sqlite", + "--lock-name", "testMigrateLock", "--url", fmt.Sprintf("sqlite://file:%s?cache=shared&_fk=1", filepath.Join(p, "test.db")), "1", ) @@ -923,7 +929,7 @@ func TestMigrate_Diff(t *testing.T) { if err != nil { return nil, err } - client.Driver = &sqliteLockerDriver{Driver: client.Driver} + client.Driver = newTestSqlLockerDriver(client.Driver, errLock) return client, nil })) f, err := os.Create(filepath.Join(p, "test.db")) @@ -1771,12 +1777,21 @@ func copyFile(src, dst string) error { return err } -type sqliteLockerDriver struct{ migrate.Driver } +type sqliteLockerDriver struct { + migrate.Driver + lockingError error + recordedLockNames []string +} + +func newTestSqlLockerDriver(d migrate.Driver, err error) *sqliteLockerDriver { + return &sqliteLockerDriver{d, err, nil} +} var errLock = errors.New("lockErr") -func (d *sqliteLockerDriver) Lock(context.Context, string, time.Duration) (schema.UnlockFunc, error) { - return func() error { return nil }, errLock +func (d *sqliteLockerDriver) Lock(ctx context.Context, lockName string, lockWait time.Duration) (schema.UnlockFunc, error) { + d.recordedLockNames = append(d.recordedLockNames, lockName) + return func() error { return nil }, d.lockingError } func countFiles(t *testing.T, p string) int { diff --git a/cmd/atlas/internal/cmdapi/project.go b/cmd/atlas/internal/cmdapi/project.go index 3e4394eaba3..50e06c6b09c 100644 --- a/cmd/atlas/internal/cmdapi/project.go +++ b/cmd/atlas/internal/cmdapi/project.go @@ -79,6 +79,7 @@ type ( Baseline string `spec:"baseline"` ExecOrder string `spec:"exec_order"` LockTimeout string `spec:"lock_timeout"` + LockName string `spec:"lock_name"` RevisionsSchema string `spec:"revisions_schema"` Repo *Repo `spec:"repo"` } diff --git a/cmd/atlas/internal/cmdapi/project_test.go b/cmd/atlas/internal/cmdapi/project_test.go index 5633e353a98..c779580dba6 100644 --- a/cmd/atlas/internal/cmdapi/project_test.go +++ b/cmd/atlas/internal/cmdapi/project_test.go @@ -81,6 +81,7 @@ env "local" { dir = "file://migrations" format = atlas lock_timeout = "1s" + lock_name = "migrate_lock" revisions_schema = "revisions" exec_order = LINEAR_SKIP } @@ -158,6 +159,7 @@ env "multi" { Dir: "file://migrations", Format: cmdmigrate.FormatAtlas, LockTimeout: "1s", + LockName: "migrate_lock", RevisionsSchema: "revisions", ExecOrder: "LINEAR_SKIP", },