diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorTaskStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorTaskStore.go index dd51d1cdf..b012321fa 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorTaskStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorTaskStore.go @@ -24,6 +24,63 @@ func (_m *MockMirrorTaskStore) EXPECT() *MockMirrorTaskStore_Expecter { return &MockMirrorTaskStore_Expecter{mock: &_m.Mock} } +// CancelMirrorTaskByID provides a mock function with given fields: ctx, taskID +func (_m *MockMirrorTaskStore) CancelMirrorTaskByID(ctx context.Context, taskID int64) (bool, error) { + ret := _m.Called(ctx, taskID) + + if len(ret) == 0 { + panic("no return value specified for CancelMirrorTaskByID") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (bool, error)); ok { + return rf(ctx, taskID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok { + r0 = rf(ctx, taskID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, taskID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorTaskStore_CancelMirrorTaskByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelMirrorTaskByID' +type MockMirrorTaskStore_CancelMirrorTaskByID_Call struct { + *mock.Call +} + +// CancelMirrorTaskByID is a helper method to define mock.On call +// - ctx context.Context +// - taskID int64 +func (_e *MockMirrorTaskStore_Expecter) CancelMirrorTaskByID(ctx interface{}, taskID interface{}) *MockMirrorTaskStore_CancelMirrorTaskByID_Call { + return &MockMirrorTaskStore_CancelMirrorTaskByID_Call{Call: _e.mock.On("CancelMirrorTaskByID", ctx, taskID)} +} + +func (_c *MockMirrorTaskStore_CancelMirrorTaskByID_Call) Run(run func(ctx context.Context, taskID int64)) 
*MockMirrorTaskStore_CancelMirrorTaskByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockMirrorTaskStore_CancelMirrorTaskByID_Call) Return(_a0 bool, _a1 error) *MockMirrorTaskStore_CancelMirrorTaskByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorTaskStore_CancelMirrorTaskByID_Call) RunAndReturn(run func(context.Context, int64) (bool, error)) *MockMirrorTaskStore_CancelMirrorTaskByID_Call { + _c.Call.Return(run) + return _c +} + // CancelOtherTasksAndCreate provides a mock function with given fields: ctx, task func (_m *MockMirrorTaskStore) CancelOtherTasksAndCreate(ctx context.Context, task database.MirrorTask) (database.MirrorTask, error) { ret := _m.Called(ctx, task) @@ -583,9 +640,66 @@ func (_c *MockMirrorTaskStore_Update_Call) RunAndReturn(run func(context.Context return _c } -// UpdateStatusAndRepoSyncStatus provides a mock function with given fields: ctx, task, syncStatus -func (_m *MockMirrorTaskStore) UpdateStatusAndRepoSyncStatus(ctx context.Context, task database.MirrorTask, syncStatus types.RepositorySyncStatus) (database.MirrorTask, error) { - ret := _m.Called(ctx, task, syncStatus) +// UpdateProgress provides a mock function with given fields: ctx, task +func (_m *MockMirrorTaskStore) UpdateProgress(ctx context.Context, task database.MirrorTask) (database.MirrorTask, error) { + ret := _m.Called(ctx, task) + + if len(ret) == 0 { + panic("no return value specified for UpdateProgress") + } + + var r0 database.MirrorTask + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask) (database.MirrorTask, error)); ok { + return rf(ctx, task) + } + if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask) database.MirrorTask); ok { + r0 = rf(ctx, task) + } else { + r0 = ret.Get(0).(database.MirrorTask) + } + + if rf, ok := ret.Get(1).(func(context.Context, database.MirrorTask) error); ok { + r1 = rf(ctx, task) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorTaskStore_UpdateProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateProgress' +type MockMirrorTaskStore_UpdateProgress_Call struct { + *mock.Call +} + +// UpdateProgress is a helper method to define mock.On call +// - ctx context.Context +// - task database.MirrorTask +func (_e *MockMirrorTaskStore_Expecter) UpdateProgress(ctx interface{}, task interface{}) *MockMirrorTaskStore_UpdateProgress_Call { + return &MockMirrorTaskStore_UpdateProgress_Call{Call: _e.mock.On("UpdateProgress", ctx, task)} +} + +func (_c *MockMirrorTaskStore_UpdateProgress_Call) Run(run func(ctx context.Context, task database.MirrorTask)) *MockMirrorTaskStore_UpdateProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.MirrorTask)) + }) + return _c +} + +func (_c *MockMirrorTaskStore_UpdateProgress_Call) Return(_a0 database.MirrorTask, _a1 error) *MockMirrorTaskStore_UpdateProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorTaskStore_UpdateProgress_Call) RunAndReturn(run func(context.Context, database.MirrorTask) (database.MirrorTask, error)) *MockMirrorTaskStore_UpdateProgress_Call { + _c.Call.Return(run) + return _c +} + +// UpdateStatusAndRepoSyncStatus provides a mock function with given fields: ctx, task, statusAction +func (_m *MockMirrorTaskStore) UpdateStatusAndRepoSyncStatus(ctx context.Context, task database.MirrorTask, statusAction string) (database.MirrorTask, error) { + ret := _m.Called(ctx, task, statusAction) if len(ret) == 0 { panic("no return value specified for UpdateStatusAndRepoSyncStatus") @@ -593,17 +707,17 @@ func (_m *MockMirrorTaskStore) UpdateStatusAndRepoSyncStatus(ctx context.Context var r0 database.MirrorTask var r1 error - if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask, types.RepositorySyncStatus) (database.MirrorTask, error)); ok { - 
return rf(ctx, task, syncStatus) + if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask, string) (database.MirrorTask, error)); ok { + return rf(ctx, task, statusAction) } - if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask, types.RepositorySyncStatus) database.MirrorTask); ok { - r0 = rf(ctx, task, syncStatus) + if rf, ok := ret.Get(0).(func(context.Context, database.MirrorTask, string) database.MirrorTask); ok { + r0 = rf(ctx, task, statusAction) } else { r0 = ret.Get(0).(database.MirrorTask) } - if rf, ok := ret.Get(1).(func(context.Context, database.MirrorTask, types.RepositorySyncStatus) error); ok { - r1 = rf(ctx, task, syncStatus) + if rf, ok := ret.Get(1).(func(context.Context, database.MirrorTask, string) error); ok { + r1 = rf(ctx, task, statusAction) } else { r1 = ret.Error(1) } @@ -619,14 +733,14 @@ type MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call struct { // UpdateStatusAndRepoSyncStatus is a helper method to define mock.On call // - ctx context.Context // - task database.MirrorTask -// - syncStatus types.RepositorySyncStatus -func (_e *MockMirrorTaskStore_Expecter) UpdateStatusAndRepoSyncStatus(ctx interface{}, task interface{}, syncStatus interface{}) *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { - return &MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call{Call: _e.mock.On("UpdateStatusAndRepoSyncStatus", ctx, task, syncStatus)} +// - statusAction string +func (_e *MockMirrorTaskStore_Expecter) UpdateStatusAndRepoSyncStatus(ctx interface{}, task interface{}, statusAction interface{}) *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { + return &MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call{Call: _e.mock.On("UpdateStatusAndRepoSyncStatus", ctx, task, statusAction)} } -func (_c *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call) Run(run func(ctx context.Context, task database.MirrorTask, syncStatus types.RepositorySyncStatus)) 
*MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { +func (_c *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call) Run(run func(ctx context.Context, task database.MirrorTask, statusAction string)) *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(database.MirrorTask), args[2].(types.RepositorySyncStatus)) + run(args[0].(context.Context), args[1].(database.MirrorTask), args[2].(string)) }) return _c } @@ -636,7 +750,7 @@ func (_c *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call) Return(_a0 dat return _c } -func (_c *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call) RunAndReturn(run func(context.Context, database.MirrorTask, types.RepositorySyncStatus) (database.MirrorTask, error)) *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { +func (_c *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call) RunAndReturn(run func(context.Context, database.MirrorTask, string) (database.MirrorTask, error)) *MockMirrorTaskStore_UpdateStatusAndRepoSyncStatus_Call { _c.Call.Return(run) return _c } diff --git a/builder/rpc/mirror_svc_client.go b/builder/rpc/mirror_svc_client.go index 792b1e719..a261c7431 100644 --- a/builder/rpc/mirror_svc_client.go +++ b/builder/rpc/mirror_svc_client.go @@ -10,7 +10,7 @@ import ( ) type MirrorSvcClient interface { - CancelMirror(ctx context.Context, mirrorID int64) error + CancelMirror(ctx context.Context, taskID int64) error } type MirrorSvcClientImpl struct { @@ -23,12 +23,12 @@ func NewMirrorSvcClient(endpoint string, opts ...RequestOption) MirrorSvcClient } } -func (c *MirrorSvcClientImpl) CancelMirror(ctx context.Context, mirrorID int64) error { +func (c *MirrorSvcClientImpl) CancelMirror(ctx context.Context, taskID int64) error { type CancelReq struct { - MirrorID int64 `json:"mirror_id"` + TaskID int64 `json:"task_id"` } req := CancelReq{ - MirrorID: mirrorID, + TaskID: taskID, } path := "/api/v1/lfs_sync_internal/cancel" 
diff --git a/builder/store/database/mirror_task.go b/builder/store/database/mirror_task.go index 35aa99fd9..194e4ef12 100644 --- a/builder/store/database/mirror_task.go +++ b/builder/store/database/mirror_task.go @@ -17,9 +17,11 @@ type mirrorTaskStoreImpl struct { type MirrorTaskStore interface { CancelOtherTasksAndCreate(ctx context.Context, task MirrorTask) (MirrorTask, error) + CancelMirrorTaskByID(ctx context.Context, taskID int64) (bool, error) Create(ctx context.Context, task MirrorTask) (MirrorTask, error) Update(ctx context.Context, task MirrorTask) (MirrorTask, error) - UpdateStatusAndRepoSyncStatus(ctx context.Context, task MirrorTask, syncStatus types.RepositorySyncStatus) (MirrorTask, error) + UpdateProgress(ctx context.Context, task MirrorTask) (MirrorTask, error) + UpdateStatusAndRepoSyncStatus(ctx context.Context, task MirrorTask, statusAction string) (MirrorTask, error) FindByMirrorID(ctx context.Context, mirrorID int64) (*MirrorTask, error) Delete(ctx context.Context, ID int64) error GetHighestPriorityByTaskStatus(ctx context.Context, status []types.MirrorTaskStatus) (MirrorTask, error) @@ -29,6 +31,21 @@ type MirrorTaskStore interface { ResetRunningTasks(ctx context.Context, fromStatus types.MirrorTaskStatus, toStatus types.MirrorTaskStatus) (int, error) } +var mirrorTaskStatusToRepoStatusMap = map[types.MirrorTaskStatus]types.RepositorySyncStatus{ + types.MirrorQueued: types.SyncStatusPending, + types.MirrorRepoSyncStart: types.SyncStatusInProgress, + types.MirrorRepoSyncFailed: types.SyncStatusFailed, + types.MirrorRepoSyncFinished: types.SyncStatusInProgress, + types.MirrorRepoSyncFatal: types.SyncStatusFailed, + types.MirrorLfsSyncStart: types.SyncStatusInProgress, + types.MirrorLfsSyncFailed: types.SyncStatusFailed, + types.MirrorLfsSyncFinished: types.SyncStatusCompleted, + types.MirrorLfsSyncFatal: types.SyncStatusFailed, + types.MirrorLfsIncomplete: types.SyncStatusFailed, + types.MirrorCanceled: types.SyncStatusCanceled, + 
types.MirrorRepoTooLarge: types.SyncStatusFailed, +} + func NewMirrorTaskStore() MirrorTaskStore { return &mirrorTaskStoreImpl{ db: defaultDB, @@ -196,6 +213,15 @@ func (m *mirrorTaskStoreImpl) Update(ctx context.Context, task MirrorTask) (Mirr return task, errorx.HandleDBError(err, nil) } +func (m *mirrorTaskStoreImpl) UpdateProgress(ctx context.Context, task MirrorTask) (MirrorTask, error) { + _, err := m.db.Operator.Core.NewUpdate(). + Model(&task). + Column("progress", "error_message"). + WherePK(). + Exec(ctx) + return task, errorx.HandleDBError(err, nil) +} + func (m *mirrorTaskStoreImpl) FindByMirrorID(ctx context.Context, mirrorID int64) (*MirrorTask, error) { var task MirrorTask err := m.db.Operator.Core.NewSelect().Model(&task).Where("mirror_id = ?", mirrorID).Scan(ctx) @@ -277,7 +303,7 @@ func (m *mirrorTaskStoreImpl) ListByStatusWithPriority(ctx context.Context, stat Relation("Mirror"). Relation("Mirror.Repository"). Where("mirror_task.status IN (?)", bun.In(status)). - OrderExpr("mirror_task.priority DESC, mirror_task.created_at DESC"). + OrderExpr("mirror_task.priority DESC, mirror_task.updated_at DESC"). Limit(per). Offset((page - 1) * per). Scan(ctx) @@ -306,6 +332,33 @@ func (m *mirrorTaskStoreImpl) CancelOtherTasksAndCreate(ctx context.Context, tas return task, errorx.HandleDBError(err, nil) } +func (m *mirrorTaskStoreImpl) CancelMirrorTaskByID(ctx context.Context, ID int64) (bool, error) { + var cancelled bool + err := m.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var task MirrorTask + err := tx.NewSelect(). + Model(&task). + Where("id = ?", ID). + For("UPDATE"). 
+ Scan(ctx) + if err != nil { + return err + } + + tFSM := NewMirrorTaskWithFSM(&task) + if tFSM.SubmitEvent(ctx, MirrorCancel) { + task.Status = types.MirrorTaskStatus(tFSM.Current()) + _, err = tx.NewUpdate().Model(&task).WherePK().Exec(ctx) + if err != nil { + return err + } + cancelled = true + } + return nil + }) + return cancelled, errorx.HandleDBError(err, nil) +} + func (m *mirrorTaskStoreImpl) ResetRunningTasks(ctx context.Context, fromStatus types.MirrorTaskStatus, toStatus types.MirrorTaskStatus) (int, error) { var task MirrorTask result, err := m.db.Operator.Core.NewUpdate(). @@ -323,10 +376,32 @@ func (m *mirrorTaskStoreImpl) ResetRunningTasks(ctx context.Context, fromStatus return int(rowsAffected), nil } -func (m *mirrorTaskStoreImpl) UpdateStatusAndRepoSyncStatus(ctx context.Context, task MirrorTask, syncStatus types.RepositorySyncStatus) (MirrorTask, error) { +func (m *mirrorTaskStoreImpl) UpdateStatusAndRepoSyncStatus(ctx context.Context, task MirrorTask, statusAction string) (MirrorTask, error) { err := m.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - _, err := tx.NewUpdate(). + // Lock the mirror_task row to prevent concurrent status updates + var current MirrorTask + err := tx.NewSelect(). + Model(&current). + Where("id = ?", task.ID). + For("UPDATE"). + Scan(ctx) + if err != nil { + return err + } + + // Validate FSM transition on the locked row to prevent TOCTOU race + tFSM := NewMirrorTaskWithFSM(&current) + if !tFSM.SubmitEvent(ctx, statusAction) { + return fmt.Errorf("mirror task status %s not allow action %s", current.Status, statusAction) + } + task.Status = types.MirrorTaskStatus(tFSM.Current()) + syncStatus := mirrorTaskStatusToRepoStatusMap[task.Status] + + // Only update status and related fields, avoid overwriting fields changed + // by other processes (e.g. priority) + _, err = tx.NewUpdate(). Model(&task). 
+ Column("status", "error_message", "progress", "updated_at", "retry_count", "before_last_commit_id", "after_last_commit_id"). WherePK(). Exec(ctx) if err != nil { diff --git a/builder/store/database/mirror_task_test.go b/builder/store/database/mirror_task_test.go index dbababc22..4c4db02a2 100644 --- a/builder/store/database/mirror_task_test.go +++ b/builder/store/database/mirror_task_test.go @@ -175,3 +175,246 @@ func TestMirrorTaskStore_ResetRunningTasks(t *testing.T) { require.Nil(t, err) require.Equal(t, types.MirrorQueued, queuedTask.Status) } + +func TestMirrorTaskStore_UpdateStatusAndRepoSyncStatus(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + taskStore := database.NewMirrorTaskStoreWithDB(db) + mirrorStore := database.NewMirrorStoreWithDB(db) + repoStore := database.NewRepoStoreWithDB(db) + + repo, err := repoStore.CreateRepo(ctx, database.Repository{ + UserID: 1, + Path: "test/repo", + GitPath: "test/repo.git", + Name: "repo", + Nickname: "Test Repo", + DefaultBranch: "main", + Private: false, + SyncStatus: types.SyncStatusPending, + }) + require.Nil(t, err) + + mirror, err := mirrorStore.Create(ctx, &database.Mirror{ + Interval: "1h", + SourceUrl: "https://example.com/test/repo.git", + RepositoryID: repo.ID, + MirrorSourceID: 1, + }) + require.Nil(t, err) + + task, err := taskStore.Create(ctx, database.MirrorTask{ + MirrorID: mirror.ID, + Status: types.MirrorRepoSyncStart, + Priority: types.LowMirrorPriority, + Mirror: &database.Mirror{ + RepositoryID: repo.ID, + }, + }) + require.Nil(t, err) + + updatedTask, err := taskStore.UpdateStatusAndRepoSyncStatus(ctx, task, database.MirrorSuccess) + require.Nil(t, err) + require.Equal(t, types.MirrorRepoSyncFinished, updatedTask.Status) + + var updatedRepo database.Repository + err = db.Core.NewSelect().Model(&updatedRepo).Where("id = ?", repo.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, types.SyncStatusInProgress, updatedRepo.SyncStatus) + + var 
updatedTaskFromDB database.MirrorTask + err = db.Core.NewSelect().Model(&updatedTaskFromDB).Where("id = ?", task.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, types.MirrorRepoSyncFinished, updatedTaskFromDB.Status) +} + +func TestMirrorTaskStore_UpdateStatusAndRepoSyncStatus_MultipleSyncStatuses(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + taskStore := database.NewMirrorTaskStoreWithDB(db) + mirrorStore := database.NewMirrorStoreWithDB(db) + repoStore := database.NewRepoStoreWithDB(db) + + repo, err := repoStore.CreateRepo(ctx, database.Repository{ + UserID: 1, + Path: "test/repo2", + GitPath: "test/repo2.git", + Name: "repo2", + Nickname: "Test Repo 2", + DefaultBranch: "main", + Private: false, + SyncStatus: types.SyncStatusPending, + }) + require.Nil(t, err) + + mirror, err := mirrorStore.Create(ctx, &database.Mirror{ + Interval: "1h", + SourceUrl: "https://example.com/test/repo2.git", + RepositoryID: repo.ID, + MirrorSourceID: 1, + }) + require.Nil(t, err) + + testCases := []struct { + name string + initialStatus types.MirrorTaskStatus + action string + expectedStatus types.MirrorTaskStatus + expectedSync types.RepositorySyncStatus + }{ + {"continue from queued", types.MirrorQueued, database.MirrorContinue, types.MirrorRepoSyncStart, types.SyncStatusInProgress}, + {"success from repo_sync_start", types.MirrorRepoSyncStart, database.MirrorSuccess, types.MirrorRepoSyncFinished, types.SyncStatusInProgress}, + {"continue from repo_sync_finished", types.MirrorRepoSyncFinished, database.MirrorContinue, types.MirrorLfsSyncStart, types.SyncStatusInProgress}, + {"success from lfs_sync_start", types.MirrorLfsSyncStart, database.MirrorSuccess, types.MirrorLfsSyncFinished, types.SyncStatusCompleted}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + task, err := taskStore.Create(ctx, database.MirrorTask{ + MirrorID: mirror.ID, + Status: tc.initialStatus, + Priority: types.HighMirrorPriority, 
+ Mirror: &database.Mirror{ + RepositoryID: repo.ID, + }, + }) + require.Nil(t, err) + + updatedTask, err := taskStore.UpdateStatusAndRepoSyncStatus(ctx, task, tc.action) + require.Nil(t, err) + require.Equal(t, tc.expectedStatus, updatedTask.Status) + + var updatedRepo database.Repository + err = db.Core.NewSelect().Model(&updatedRepo).Where("id = ?", repo.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, tc.expectedSync, updatedRepo.SyncStatus) + }) + } +} + +func TestMirrorTaskStore_UpdateStatusAndRepoSyncStatus_FailedStatus(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + taskStore := database.NewMirrorTaskStoreWithDB(db) + mirrorStore := database.NewMirrorStoreWithDB(db) + repoStore := database.NewRepoStoreWithDB(db) + + repo, err := repoStore.CreateRepo(ctx, database.Repository{ + UserID: 1, + Path: "test/repo3", + GitPath: "test/repo3.git", + Name: "repo3", + Nickname: "Test Repo 3", + DefaultBranch: "main", + Private: false, + SyncStatus: types.SyncStatusInProgress, + }) + require.Nil(t, err) + + mirror, err := mirrorStore.Create(ctx, &database.Mirror{ + Interval: "1h", + SourceUrl: "https://example.com/test/repo3.git", + RepositoryID: repo.ID, + MirrorSourceID: 1, + }) + require.Nil(t, err) + + task, err := taskStore.Create(ctx, database.MirrorTask{ + MirrorID: mirror.ID, + Status: types.MirrorRepoSyncStart, + Priority: types.HighMirrorPriority, + ErrorMessage: "sync failed", + Mirror: &database.Mirror{ + RepositoryID: repo.ID, + }, + }) + require.Nil(t, err) + + updatedTask, err := taskStore.UpdateStatusAndRepoSyncStatus(ctx, task, database.MirrorFail) + require.Nil(t, err) + require.Equal(t, types.MirrorRepoSyncFailed, updatedTask.Status) + + var updatedRepo database.Repository + err = db.Core.NewSelect().Model(&updatedRepo).Where("id = ?", repo.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, types.SyncStatusFailed, updatedRepo.SyncStatus) +} + +func TestMirrorTaskStore_UpdateProgress(t *testing.T) { + 
db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorTaskStoreWithDB(db) + + task, err := store.Create(ctx, database.MirrorTask{ + MirrorID: 1, + Status: types.MirrorLfsSyncStart, + Priority: types.HighMirrorPriority, + Progress: 50, + ErrorMessage: "previous error", + }) + require.Nil(t, err) + require.Greater(t, task.ID, int64(0)) + + // Update progress and error message + task.Progress = 80 + task.ErrorMessage = "new error" + updatedTask, err := store.UpdateProgress(ctx, task) + require.Nil(t, err) + require.Equal(t, 80, updatedTask.Progress) + require.Equal(t, "new error", updatedTask.ErrorMessage) + require.Equal(t, types.MirrorLfsSyncStart, updatedTask.Status) + + // Verify in DB: progress and error_message updated, status preserved + var dbTask database.MirrorTask + err = db.Core.NewSelect().Model(&dbTask).Where("id = ?", task.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, 80, dbTask.Progress) + require.Equal(t, "new error", dbTask.ErrorMessage) + require.Equal(t, types.MirrorLfsSyncStart, dbTask.Status, + "UpdateProgress must not change the status field") +} + +func TestMirrorTaskStore_UpdateProgress_DoesNotOverwritePriority(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewMirrorTaskStoreWithDB(db) + + task, err := store.Create(ctx, database.MirrorTask{ + MirrorID: 1, + Status: types.MirrorLfsSyncStart, + Priority: types.ASAPMirrorPriority, + Progress: 0, + }) + require.Nil(t, err) + + // Simulate: another process changes priority concurrently + _, err = db.Core.NewUpdate(). + Model(&database.MirrorTask{}). + Set("priority = ?", types.LowMirrorPriority). + Where("id = ?", task.ID). 
+ Exec(ctx) + require.Nil(t, err) + + // UpdateProgress should NOT overwrite priority + task.Progress = 100 + task.Priority = types.ASAPMirrorPriority // original value in struct + _, err = store.UpdateProgress(ctx, task) + require.Nil(t, err) + + var dbTask database.MirrorTask + err = db.Core.NewSelect().Model(&dbTask).Where("id = ?", task.ID).Scan(ctx) + require.Nil(t, err) + require.Equal(t, 100, dbTask.Progress) + require.Equal(t, types.LowMirrorPriority, dbTask.Priority, + "UpdateProgress must not overwrite priority set by another process") +} diff --git a/common/types/mirror.go b/common/types/mirror.go index e48d89b9b..ea1f991d4 100644 --- a/common/types/mirror.go +++ b/common/types/mirror.go @@ -59,7 +59,7 @@ type CreateMirrorRepoReq struct { SourceNamespace string `json:"source_namespace" binding:"required"` SourceName string `json:"source_name" binding:"required"` // source id for HF,github etc - MirrorSourceID int64 `json:"mirror_source_id" binding:"required"` + MirrorSourceID int64 `json:"mirror_source_id"` // repo basic info RepoType RepositoryType `json:"repo_type" binding:"required"` @@ -198,6 +198,7 @@ type MirrorListResp struct { type MirrorTask struct { MirrorID int64 `json:"mirror_id"` + TaskID int64 `json:"task_id"` SourceUrl string `json:"source_url"` Priority int `json:"priority"` RepoPath string `json:"repo_path"` @@ -224,3 +225,8 @@ type UpdateMirrorNamespaceMappingReq struct { Enabled *bool `json:"enabled"` ID int64 `json:"id"` } + +type MirrorFilter struct { + Search string `json:"search"` + Status *MirrorTaskStatus `json:"status"` +} diff --git a/component/repo.go b/component/repo.go index 535860b4a..a3ba1f51a 100644 --- a/component/repo.go +++ b/component/repo.go @@ -119,8 +119,6 @@ type RepoComponent interface { DownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) InternalDownloadFile(ctx context.Context, req *types.GetFileReq) (io.ReadCloser, int64, string, error) Branches(ctx 
context.Context, req *types.GetBranchesReq) ([]types.Branch, error) - CreateBranch(ctx context.Context, req *types.CreateBranchReq) error - DeleteBranch(ctx context.Context, req *types.DeleteBranchReq) error Tags(ctx context.Context, req *types.GetTagsReq) ([]database.Tag, error) UpdateTags(ctx context.Context, namespace, name string, repoType types.RepositoryType, category, currentUser string, tags []string) error Tree(ctx context.Context, req *types.GetFileReq) ([]*types.File, error) @@ -461,6 +459,14 @@ func (c *repoComponentImpl) DeleteRepo(ctx context.Context, req types.DeleteRepo return nil, fmt.Errorf("fail to find mirror, %w", err) } + // If the repository is a mirror, cancel the mirror task before deletion + if mirror != nil { + err = c.mirrorSvcClient.CancelMirror(ctx, mirror.CurrentTaskID) + if err != nil { + return nil, fmt.Errorf("fail to cancel mirror, %w", err) + } + } + // fetch lfs metas before database deletion lfsMetas, err := c.lfsMetaObjectStore.FindByRepoID(ctx, repo.ID) if err != nil { @@ -472,13 +478,6 @@ func (c *repoComponentImpl) DeleteRepo(ctx context.Context, req types.DeleteRepo return nil, fmt.Errorf("fail to clean repo relations, %w", err) } - if mirror != nil { - err = c.mirrorSvcClient.CancelMirror(ctx, mirror.ID) - if err != nil { - return nil, fmt.Errorf("fail to cancel mirror, %w", err) - } - } - err = c.git.DeleteRepo(ctx, repo.GitalyPath()) if err != nil && status.Code(err) != codes.NotFound { slog.Error("fail to update repo in git ", slog.Any("req", req), slog.String("error", err.Error())) @@ -581,7 +580,14 @@ func (c *repoComponentImpl) CreateFork(ctx context.Context, req types.CreateFork } func (c *repoComponentImpl) cleanLfsStorage(ctx context.Context, repoID int64, migrated bool, lfsMetas []database.LfsMetaObject) { - slog.Info("Cleaning LFS storage for repo", slog.Int64("repo_id", repoID), slog.Bool("migrated", migrated), slog.Int("file_count", len(lfsMetas))) + defer func() { + if r := recover(); r != nil { + 
slog.Error("cleanLfsStorage recovered from panic", + slog.Any("panic", r), slog.Int64("repo_id", repoID)) + } + }() + + slog.InfoContext(ctx, "Cleaning LFS storage for repo", slog.Int64("repo_id", repoID), slog.Bool("migrated", migrated), slog.Int("file_count", len(lfsMetas))) objectsCh := make(chan minio.ObjectInfo) go func() { @@ -591,11 +597,11 @@ func (c *repoComponentImpl) cleanLfsStorage(ctx context.Context, repoID int64, m // For non-migrated (shared) storage, check if other repos use this OID exists, err := c.lfsMetaObjectStore.ExistsByOidExclRepo(ctx, meta.Oid, repoID) if err != nil { - slog.Error("Failed to check OID references", slog.String("oid", meta.Oid), slog.Any("error", err)) + slog.ErrorContext(ctx, "Failed to check OID references", slog.String("oid", meta.Oid), slog.Any("error", err)) continue } if exists { - slog.Debug("Skipping shared LFS file", slog.String("oid", meta.Oid), slog.Int64("repo_id", repoID)) + slog.DebugContext(ctx, "Skipping shared LFS file", slog.String("oid", meta.Oid), slog.Int64("repo_id", repoID)) continue } } @@ -608,9 +614,9 @@ func (c *repoComponentImpl) cleanLfsStorage(ctx context.Context, repoID int64, m }() for rErr := range c.s3Client.RemoveObjects(ctx, c.config.S3.Bucket, objectsCh, minio.RemoveObjectsOptions{}) { - slog.Error("Failed to remove LFS object", slog.String("key", rErr.ObjectName), slog.Any("error", rErr.Err)) + slog.ErrorContext(ctx, "Failed to remove LFS object", slog.String("key", rErr.ObjectName), slog.Any("error", rErr.Err)) } - slog.Info("Completed LFS storage cleanup for repo", slog.Int64("repo_id", repoID)) + slog.InfoContext(ctx, "Completed LFS storage cleanup for repo", slog.Int64("repo_id", repoID)) } func (c *repoComponentImpl) copyLfsObjects(ctx context.Context, sourceRepoID, targetRepoID int64) error { @@ -660,91 +666,6 @@ func (c *repoComponentImpl) copyLfsObjects(ctx context.Context, sourceRepoID, ta } // PublicToUser gets visible repos of the given user and user's orgs -func (c 
*repoComponentImpl) CreateBranch(ctx context.Context, req *types.CreateBranchReq) error { - repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) - if err != nil { - return fmt.Errorf("failed to find repo, error: %w", err) - } - - permission, err := c.GetUserRepoPermission(ctx, req.CurrentUser, repo) - if err != nil { - return fmt.Errorf("failed to check user permission, error: %w", err) - } - if !permission.CanWrite { - return errorx.ErrForbidden - } - - sourceRef := req.CommitID - if sourceRef == "" { - sourceRef = repo.DefaultBranch - } - - lastCommit, err := c.git.GetRepoLastCommit(ctx, gitserver.GetRepoLastCommitReq{ - Namespace: req.Namespace, - Name: req.Name, - RepoType: req.RepoType, - Ref: sourceRef, - }) - if err != nil { - return fmt.Errorf("failed to get last commit for ref %s, error: %w", sourceRef, err) - } - - createBranchReq := gitserver.CreateBranchReq{ - Namespace: req.Namespace, - Name: req.Name, - BranchName: req.BranchName, - CommitID: lastCommit.ID, - RepoType: req.RepoType, - } - - err = c.git.CreateBranch(ctx, createBranchReq) - if err != nil { - return fmt.Errorf("failed to create branch in git server, error: %w", err) - } - - return nil -} - -func (c *repoComponentImpl) DeleteBranch(ctx context.Context, req *types.DeleteBranchReq) error { - repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) - if err != nil { - return fmt.Errorf("failed to find repo, error: %w", err) - } - - permission, err := c.GetUserRepoPermission(ctx, req.CurrentUser, repo) - if err != nil { - return fmt.Errorf("failed to check user permission, error: %w", err) - } - if !permission.CanWrite { - return errorx.ErrForbidden - } - - if req.BranchName == repo.DefaultBranch { - return fmt.Errorf("cannot delete default branch") - } - - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return fmt.Errorf("failed to find user, error: %w", err) - } - - deleteBranchReq := 
gitserver.DeleteBranchReq{ - Namespace: req.Namespace, - Name: req.Name, - Ref: req.BranchName, - RepoType: req.RepoType, - Username: user.Username, - Email: user.Email, - } - - err = c.git.DeleteRepoBranch(ctx, deleteBranchReq) - if err != nil { - return fmt.Errorf("failed to delete branch in git server, error: %w", err) - } - - return nil -} - func (c *repoComponentImpl) PublicToUser(ctx context.Context, repoType types.RepositoryType, userName string, filter *types.RepoFilter, per, page int) (repos []*database.Repository, count int, err error) { var repoOwnerIDs []int64 var isAdmin bool @@ -2344,10 +2265,6 @@ func (c *repoComponentImpl) UpdateMirror(ctx context.Context, req types.UpdateMi if err != nil { return nil, fmt.Errorf("failed to find mirror, error: %w", err) } - mirrorSource, err := c.mirrorSourceStore.Get(ctx, req.MirrorSourceID) - if err != nil { - return nil, fmt.Errorf("failed to get mirror source, err: %w, id: %d", err, req.MirrorSourceID) - } pushAccessToken, err := c.tokenStore.GetUserGitToken(ctx, req.CurrentUser) if err != nil { @@ -2363,7 +2280,7 @@ func (c *repoComponentImpl) UpdateMirror(ctx context.Context, req types.UpdateMi mirror.PushUsername = req.CurrentUser mirror.PushAccessToken = pushAccessToken.Token mirror.SourceRepoPath = req.SourceRepoPath - mirror.LocalRepoPath = fmt.Sprintf("%s_%s_%s_%s", mirrorSource.SourceName, req.RepoType, req.Namespace, req.Name) + mirror.LocalRepoPath = fmt.Sprintf("%s_%s_%s", req.RepoType, req.Namespace, req.Name) err = c.mirrorStore.Update(ctx, mirror) if err != nil { return nil, fmt.Errorf("failed to update mirror, error: %w", err) @@ -2720,15 +2637,7 @@ func (c *repoComponentImpl) Preupload(ctx context.Context, req types.PreuploadRe Paths: paths, }) if err != nil { - // If the branch doesn't exist yet, treat it as if there are no existing files. - // This allows uploading to a new branch that hasn't been created yet. 
- if errors.Is(err, errorx.ErrGitCommitNotFound) || status.Code(err) == codes.NotFound || status.Code(err) == codes.InvalidArgument { - slog.InfoContext(ctx, "branch not found when getting existing files for preupload, treating as empty", - slog.String("revision", req.Revision), - slog.String("repo", req.Namespace+"/"+req.Name)) - } else { - return nil, fmt.Errorf("failed to get repo files, err: %w", err) - } + return nil, fmt.Errorf("failed to get repo files, err: %w", err) } for _, file := range existFiles { @@ -2742,7 +2651,7 @@ func (c *repoComponentImpl) Preupload(ctx context.Context, req types.PreuploadRe Ref: req.Revision, Path: GitAttributesFileName, }) - if err != nil && status.Code(err) != codes.InvalidArgument && !errors.Is(err, errorx.ErrGitCommitNotFound) { + if err != nil && status.Code(err) != codes.InvalidArgument { return nil, fmt.Errorf("failed to get gitattributes file, err: %w", err) } @@ -2761,7 +2670,7 @@ func (c *repoComponentImpl) Preupload(ctx context.Context, req types.PreuploadRe Path: GitIgnoreFileName, }) code := status.Code(err) - if err != nil && code != codes.InvalidArgument && !errors.Is(err, errorx.ErrGitCommitNotFound) { + if err != nil && code != codes.InvalidArgument { return nil, fmt.Errorf("failed to get .gitignore file, err: %w", err) } diff --git a/component/repo_test.go b/component/repo_test.go index a3fb8e0b8..8969e0776 100644 --- a/component/repo_test.go +++ b/component/repo_test.go @@ -1026,15 +1026,12 @@ func TestRepoComponent_UpdateMirror(t *testing.T) { AccessToken: "ak", PushUsername: "user", PushAccessToken: "foo", - LocalRepoPath: "a_model_ns_n", + LocalRepoPath: "model_ns_n", MirrorSourceID: 111, } mi := m repo.mocks.stores.MirrorMock().EXPECT().FindByRepoID(ctx, int64(123)).Return(&mi, nil) repo.mocks.stores.AccessTokenMock().EXPECT().GetUserGitToken(ctx, "user").Return(&database.AccessToken{Token: "foo"}, nil) - repo.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(111)).Return(&database.MirrorSource{ - 
SourceName: "a", - }, nil) repo.mocks.stores.MirrorMock().EXPECT().Update(ctx, &m).Return(nil) mm, err := repo.UpdateMirror(ctx, types.UpdateMirrorReq{ @@ -1379,7 +1376,7 @@ func TestRepoComponent_DeployDetail(t *testing.T) { mockUserRepoAdminPermission(ctx, repo.mocks.stores, "user") repo.mocks.stores.ClusterInfoMock().EXPECT().ByClusterID(ctx, "cluster").Return(database.ClusterInfo{ - Zone: "z", + Zone: "z", }, nil) repo.mocks.stores.DeployTaskMock().EXPECT().GetDeployByID(ctx, int64(1)).Return(&database.Deploy{ RepoID: 1, diff --git a/mirror/component/manager.go b/mirror/component/manager.go index 05ab8fe38..dcd96a869 100644 --- a/mirror/component/manager.go +++ b/mirror/component/manager.go @@ -55,12 +55,19 @@ func (c *managerComponentImpl) SyncNow(ctx context.Context, workerID int, mtID i return nil } -func (c *managerComponentImpl) Cancel(ctx context.Context, mirrorID int64) (bool, error) { - found, err := c.manager.StopWorkerByMirrorID(mirrorID) +func (c *managerComponentImpl) Cancel(ctx context.Context, taskID int64) (bool, error) { + dbCancelled, err := c.mirrorTaskStore.CancelMirrorTaskByID(ctx, taskID) if err != nil { - return found, fmt.Errorf("fail to stop worker: %w", err) + return false, fmt.Errorf("fail to cancel mirror task in db: %w", err) } - return found, nil + + workerStopped, _ := c.manager.StopWorkerByTaskID(taskID) + + if !dbCancelled && !workerStopped { + return false, fmt.Errorf("no task found for mirror %d", taskID) + } + + return true, nil } func (c *managerComponentImpl) ListTasks(ctx context.Context, per, page int) (types.MirrorListResp, error) { @@ -74,6 +81,7 @@ func (c *managerComponentImpl) ListTasks(ctx context.Context, per, page int) (ty if task.Mirror != nil && task.Mirror.Repository != nil { taskResp[id] = types.MirrorTask{ MirrorID: task.MirrorID, + TaskID: task.ID, SourceUrl: task.Mirror.SourceUrl, Priority: int(task.Priority), RepoPath: task.Mirror.RepoPath(), @@ -93,7 +101,8 @@ func (c *managerComponentImpl) ListTasks(ctx 
context.Context, per, page int) (ty for _, task := range waittingTasks { if task.Mirror != nil && task.Mirror.Repository != nil { lfsTasks = append(lfsTasks, types.MirrorTask{ - MirrorID: task.ID, + MirrorID: task.MirrorID, + TaskID: task.ID, SourceUrl: task.Mirror.SourceUrl, Priority: int(task.Priority), RepoPath: task.Mirror.RepoPath(), diff --git a/mirror/handler/manager.go b/mirror/handler/manager.go index 843d318cf..d812e2d05 100644 --- a/mirror/handler/manager.go +++ b/mirror/handler/manager.go @@ -40,12 +40,12 @@ type StopWorkerByIDReq struct { } type SyncNowReq struct { - MirrorID int64 `json:"mirror_id" binding:"required"` + TaskID int64 `json:"task_id" binding:"required"` WorkerID int `json:"worker_id"` } type CancelReq struct { - MirrorID int64 `json:"mirror_id" binding:"required"` + TaskID int64 `json:"task_id" binding:"required"` } func (h *ManagerHandler) StopWorkerByID(c *gin.Context) { @@ -68,7 +68,7 @@ func (h *ManagerHandler) SyncNow(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - err := h.managerComponent.SyncNow(c, req.WorkerID, req.MirrorID) + err := h.managerComponent.SyncNow(c, req.WorkerID, req.TaskID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -83,7 +83,7 @@ func (h *ManagerHandler) Cancel(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - found, err := h.managerComponent.Cancel(c, req.MirrorID) + found, err := h.managerComponent.Cancel(c, req.TaskID) if err != nil { if found { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/mirror/lfssyncer/lfs.go b/mirror/lfssyncer/lfs.go index 411200681..84a4b8de3 100644 --- a/mirror/lfssyncer/lfs.go +++ b/mirror/lfssyncer/lfs.go @@ -47,10 +47,10 @@ var ( rk repoPathKey = "repoPath" suk sourceUrlKey = "sourceUrl" dbk defaultBranchKey = "defaultBranch" - maxRetries int = 3 + maxRetries = 3 + MaxGroupCount = 15 + maxPartNum = 1000 MaxGroupSize 
int64 = 10 * 1024 * 1024 * 1024 // 10GB - MaxGroupCount int = 15 - maxPartNum int = 1000 ) type LfsSyncWorker struct { @@ -160,7 +160,9 @@ func (w *LfsSyncWorker) Run(mt *database.MirrorTask) { ) mt.ErrorMessage = "mirror not found" mt.Status = types.MirrorLfsSyncFailed - _, updateErr := w.mirrorTaskStore.Update(w.ctx, *mt) + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + _, updateErr := w.mirrorTaskStore.Update(updateCtx, *mt) if updateErr != nil { slog.Error("fail to update mirror task", slog.Int("workerID", w.id), @@ -176,6 +178,17 @@ func (w *LfsSyncWorker) Run(mt *database.MirrorTask) { slog.Int("workerID", w.id), slog.Any("error", err), ) + mt.ErrorMessage = err.Error() + mt.Status = types.MirrorLfsSyncFailed + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + _, updateErr := w.mirrorTaskStore.Update(updateCtx, *mt) + if updateErr != nil { + slog.Error("fail to update mirror task", + slog.Int("workerID", w.id), + slog.Any("error", updateErr), + ) + } return } @@ -185,6 +198,17 @@ func (w *LfsSyncWorker) Run(mt *database.MirrorTask) { slog.Int("workerId", w.id), slog.Any("error", err), ) + mt.ErrorMessage = err.Error() + mt.Status = types.MirrorLfsSyncFailed + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + _, updateErr := w.mirrorTaskStore.Update(updateCtx, *mt) + if updateErr != nil { + slog.Error("fail to update mirror task", + slog.Int("workerID", w.id), + slog.Any("error", updateErr), + ) + } return } @@ -244,41 +268,21 @@ func (w *LfsSyncWorker) Run(mt *database.MirrorTask) { mt.Progress = 100 } - mtFSM := database.NewMirrorTaskWithFSM(mt) // Can not use w.ctx cause it could be canceled ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - canContinue := mtFSM.SubmitEvent(ctx, action) - if !canContinue { - slog.Error("fail to 
submit event", + updatedMt, err := w.mirrorTaskStore.UpdateStatusAndRepoSyncStatus(ctx, *mt, action) + if err != nil { + slog.Error("fail to update mirror task status and repository status", slog.Int("workerID", w.id), slog.Any("status", mt.Status), slog.Any("action", action), - ) - - mt.ErrorMessage = fmt.Sprintf("fail to submit event, status: %s, action: %s", mt.Status, action) - mt.Status = types.MirrorLfsSyncFailed - repoSyncStatus := common.MirrorTaskStatusToRepoStatus(mt.Status) - _, updateErr := w.mirrorTaskStore.UpdateStatusAndRepoSyncStatus(w.ctx, *mt, repoSyncStatus) - if updateErr != nil { - slog.Error("fail to update mirror task", - slog.Int("workerID", w.id), - slog.Any("error", updateErr), - ) - } - return - } - mt.Status = types.MirrorTaskStatus(mtFSM.Current()) - repoSyncStatus := common.MirrorTaskStatusToRepoStatus(mt.Status) - _, err = w.mirrorTaskStore.UpdateStatusAndRepoSyncStatus(ctx, *mt, repoSyncStatus) - if err != nil { - slog.Error("fail to update mirror task", - slog.Int("workerID", w.id), slog.Any("error", err), ) return } + *mt = updatedMt err = w.sendMessage(ctx, mirror, mt.Status) if err != nil { @@ -425,7 +429,24 @@ func (w *LfsSyncWorker) getSyncPointers( return pointers, fmt.Errorf("fail to get lfs meta objects: %w", err) } repo := mt.Mirror.Repository - var toBeUpdateLfsMetaObjects []database.LfsMetaObject + + slog.Info( + "fetched lfs meta objects", + slog.Int("workerID", w.id), + slog.Int64("repoID", repo.ID), + slog.String("repoPath", repoPath), + slog.Int("lfsCount", len(lfsMetaObjects)), + ) + + if len(lfsMetaObjects) == 0 { + slog.Info("no lfs files to sync, finish sync lfs", slog.Int("workerId", w.id), slog.String("repoPath", repoPath)) + return pointers, nil + } + + var ( + existingOIDs []string + missingOIDs []string + ) for _, lfsMetaObject := range lfsMetaObjects { objectKey := common.BuildLfsPath(repo.ID, lfsMetaObject.Oid, repo.Migrated) exists, err := w.CheckIfLFSFileExists(ctx, objectKey, lfsMetaObject.Size) @@ 
-438,26 +459,37 @@ func (w *LfsSyncWorker) getSyncPointers( slog.Any("repoPath", repo.Path), slog.Any("repoType", repo.RepositoryType), ) + return pointers, fmt.Errorf("failed to check if lfs file exists: %w", err) } if exists { - lfsMetaObject.Existing = true - toBeUpdateLfsMetaObjects = append(toBeUpdateLfsMetaObjects, lfsMetaObject) + existingOIDs = append(existingOIDs, lfsMetaObject.Oid) } else { + missingOIDs = append(missingOIDs, lfsMetaObject.Oid) pointers = append(pointers, &types.Pointer{ Oid: lfsMetaObject.Oid, Size: lfsMetaObject.Size, }) } } - if len(toBeUpdateLfsMetaObjects) > 0 { - err = w.lfsMetaObjectStore.BulkUpdateOrCreate(ctx, repo.ID, toBeUpdateLfsMetaObjects) - if err != nil { - slog.Error( - "failed to update lfs meta objects", - slog.Int("workerID", w.id), - slog.Any("error", err), - ) - } + + slog.Info( + "checked lfs meta objects", + slog.Int("workerID", w.id), + slog.Int64("repoID", repo.ID), + slog.String("repoPath", repoPath), + slog.Int("lfsCount", len(lfsMetaObjects)), + slog.Int("existingCount", len(existingOIDs)), + slog.Int("missingCount", len(missingOIDs)), + ) + + if err = w.lfsMetaObjectStore.BulkUpdateExistingByOIDs(ctx, repo.ID, existingOIDs, missingOIDs); err != nil { + slog.Error( + "failed to update lfs meta objects", + slog.String("repoPath", repoPath), + slog.Int("workerID", w.id), + slog.Any("error", err), + ) + return pointers, fmt.Errorf("failed to update lfs meta objects: %w", err) } if len(pointers) == 0 { @@ -512,6 +544,7 @@ func (w *LfsSyncWorker) downloadAndUploadLFSFiles( pointerGroups [][]*types.Pointer, repo *database.Repository, ) error { + var finalErr error totalPointerCount := 0 syncedPointerCount := 0 @@ -520,6 +553,12 @@ func (w *LfsSyncWorker) downloadAndUploadLFSFiles( } for _, pointers := range pointerGroups { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + pointers, err := w.GetLFSDownloadURLs(ctx, mirror.SourceUrl, repo.DefaultBranch, pointers) if err != nil { slog.Error( @@ 
-534,10 +573,19 @@ func (w *LfsSyncWorker) downloadAndUploadLFSFiles( } for _, pointer := range pointers { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + err := w.downloadAndUploadLFSFile(ctx, repo, pointer) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } + finalErr = err slog.Error("failed to download and upload lfs file", - slog.Any("error", err), slog.Int("workerID", w.id), slog.Any("error", err), slog.Any("sourceURL", mirror.SourceUrl), @@ -545,17 +593,21 @@ func (w *LfsSyncWorker) downloadAndUploadLFSFiles( slog.Any("repoType", repo.RepositoryType), slog.Any("pointer", pointer), ) + } else { + syncedPointerCount++ } - syncedPointerCount++ // Update the progress of the mirror task mt.Progress = int(math.Ceil(float64(syncedPointerCount) / float64(totalPointerCount) * 100)) - _, err = w.mirrorTaskStore.Update(ctx, *mt) + _, err = w.mirrorTaskStore.UpdateProgress(ctx, *mt) if err != nil { return fmt.Errorf("failed to update mirror task progress: %w", err) } } } + if finalErr != nil { + return finalErr + } return nil } @@ -577,6 +629,7 @@ func (w *LfsSyncWorker) downloadAndUploadLFSFile( slog.Any("repoPath", repo.Path), slog.Any("repoType", repo.RepositoryType), ) + return fmt.Errorf("failed to check if lfs file exists: %w", err) } lmo := database.LfsMetaObject{ Size: pointer.Size, @@ -1076,6 +1129,12 @@ func (w *LfsSyncWorker) downloadAndUploadSmallFile( pointer *types.Pointer, objectKey string, ) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + slog.Info( "downloading small file directly", slog.Int("workerID", w.id), @@ -1083,7 +1142,7 @@ func (w *LfsSyncWorker) downloadAndUploadSmallFile( slog.Any("size", pointer.Size), ) - resp, err := w.downloadRange(pointer.DownloadURL, 0, pointer.Size-1) + resp, err := w.downloadRange(ctx, pointer.DownloadURL, 0, pointer.Size-1) if err != nil { return fmt.Errorf("failed to download small file: %w", err) } @@ -1151,7 +1210,7 @@ func (w *LfsSyncWorker) 
downloadAndUploadPartWithRetry( slog.Any("offset", start), slog.Any("end", end), ) - resp, err := w.downloadRange(downloadURL, start, end) + resp, err := w.downloadRange(ctx, downloadURL, start, end) if err != nil { slog.Error( "failed to download range", @@ -1163,7 +1222,7 @@ func (w *LfsSyncWorker) downloadAndUploadPartWithRetry( slog.Any("attempt", attempt), slog.Any("error", err), ) - if resp.StatusCode == http.StatusForbidden { + if resp != nil && resp.StatusCode == http.StatusForbidden { sourceURL := ctx.Value(suk).(string) defaultBranch := ctx.Value(dbk).(string) pointers, err := w.GetLFSDownloadURLs(ctx, sourceURL, defaultBranch, []*types.Pointer{pointer}) @@ -1176,7 +1235,9 @@ func (w *LfsSyncWorker) downloadAndUploadPartWithRetry( return part, fmt.Errorf("failed to download range: %w", err) } - defer resp.Body.Close() + if resp != nil { + defer resp.Body.Close() + } slog.Info( "uploading range", @@ -1228,10 +1289,11 @@ func (w *LfsSyncWorker) downloadAndUploadPartWithRetry( } func (w *LfsSyncWorker) downloadRange( + ctx context.Context, downloadURL string, start, end int64, ) (*http.Response, error) { - req, err := http.NewRequest("GET", downloadURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) if err != nil { return nil, err } @@ -1250,11 +1312,11 @@ func (w *LfsSyncWorker) downloadRange( resp, err := w.httpClient.Do(req) if err != nil { - return nil, err + return resp, err } if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) + return resp, fmt.Errorf("unexpected status code %d", resp.StatusCode) } return resp, nil @@ -1267,8 +1329,7 @@ func (w *LfsSyncWorker) CheckIfLFSFileExists( ) (bool, error) { objInfo, err := w.ossClient.StatObject(ctx, w.config.S3.Bucket, objectKey, minio.StatObjectOptions{}) if err != nil { - // Check if it's a "not found" error - if strings.Contains(err.Error(), "NoSuchKey") || 
strings.Contains(err.Error(), "not found") { + if isLFSObjectNotFound(err) { return false, nil } return false, err @@ -1286,6 +1347,27 @@ func (w *LfsSyncWorker) CheckIfLFSFileExists( return true, nil } +// isLFSObjectNotFound detects missing LFS objects from structured S3 errors first, with message matching as a compatibility fallback. +func isLFSObjectNotFound(err error) bool { + if err == nil { + return false + } + + minioErr := minio.ToErrorResponse(err) + if minioErr.Code == "NoSuchKey" { + return true + } + if minioErr.Code == "" && minioErr.StatusCode == http.StatusNotFound { + return true + } + + errMsg := strings.ToLower(err.Error()) + return strings.Contains(errMsg, "nosuchkey") || + strings.Contains(errMsg, "key does not exist") || + strings.Contains(errMsg, "object does not exist") || + strings.Contains(errMsg, "object not found") +} + func (w *LfsSyncWorker) GetLFSDownloadURLs( ctx context.Context, repoCloneURL, branch string, @@ -1419,6 +1501,7 @@ func (w *LfsSyncWorker) triggerGitCallback( workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, + ID: fmt.Sprintf("mirror-lfs-%s-%s-%s-%s", repo.RepositoryType, namespace, name, commit.ID), } _, err = w.workflowClient.ExecuteWorkflow( diff --git a/mirror/lfssyncer/lfs_test.go b/mirror/lfssyncer/lfs_test.go index c6090db38..fd635a93b 100644 --- a/mirror/lfssyncer/lfs_test.go +++ b/mirror/lfssyncer/lfs_test.go @@ -192,7 +192,7 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_MirrorNotFound() { suite.mocks.mirrorStore.EXPECT().FindByID(suite.ctx, task.MirrorID). 
Return(nil, errors.New("mirror not found")) - suite.mocks.mirrorTaskStore.EXPECT().Update(suite.ctx, mock.MatchedBy(func(mt database.MirrorTask) bool { + suite.mocks.mirrorTaskStore.EXPECT().Update(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { return mt.Status == types.MirrorLfsSyncFailed && mt.ErrorMessage == "mirror not found" })).Return(*task, nil) @@ -211,8 +211,12 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_RepoNotFound() { suite.mocks.repoStore.EXPECT().FindById(suite.ctx, mirror.RepositoryID). Return(nil, errors.New("repo not found")) + suite.mocks.mirrorTaskStore.EXPECT().Update(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { + return mt.Status == types.MirrorLfsSyncFailed + })).Return(*task, nil) + suite.worker.Run(task) - // Should return early without further processing + // Should return early but updates DB with failed status } func (suite *LfsSyncWorkerTestSuite) TestRun_ShouldNotSync() { @@ -259,8 +263,11 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_ContextCanceled() { suite.mocks.lfsMetaObjectStore.EXPECT().FindByRepoID(mock.Anything, repo.ID).Return(nil, context.Canceled) suite.mocks.mirrorTaskStore.EXPECT().UpdateStatusAndRepoSyncStatus(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { - return mt.Status == types.MirrorCanceled - }), types.SyncStatusCanceled).Return(*task, nil) + return strings.Contains(mt.ErrorMessage, "canceled") + }), database.MirrorCancel).RunAndReturn(func(ctx context.Context, mt database.MirrorTask, action string) (database.MirrorTask, error) { + mt.Status = types.MirrorCanceled + return mt, nil + }) suite.mocks.msgSender.EXPECT().Send(mock.Anything, mock.Anything).Return(hook.Response{}, nil) suite.mocks.recomComponent.EXPECT().SetOpWeight(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -305,8 +312,11 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_Success_NoLfsFiles() { // Final updates 
suite.mocks.mirrorTaskStore.EXPECT().UpdateStatusAndRepoSyncStatus(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { - return mt.Status == types.MirrorRepoSyncFinished && mt.Progress == 100 - }), types.SyncStatusInProgress).Return(*task, nil) + return mt.Progress == 100 + }), database.MirrorSuccess).RunAndReturn(func(ctx context.Context, mt database.MirrorTask, action string) (database.MirrorTask, error) { + mt.Status = types.MirrorRepoSyncFinished + return mt, nil + }) suite.mocks.msgSender.EXPECT().Send(mock.Anything, mock.MatchedBy(func(req types.MessageRequest) bool { return strings.Contains(req.Parameters, "finished") @@ -344,9 +354,11 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_SyncLfsError() { // Expect failure status update suite.mocks.mirrorTaskStore.EXPECT().UpdateStatusAndRepoSyncStatus(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { - return mt.Status == types.MirrorLfsSyncFailed && - strings.Contains(mt.ErrorMessage, expectedError.Error()) - }), types.SyncStatusFailed).Return(*task, nil) + return strings.Contains(mt.ErrorMessage, expectedError.Error()) + }), database.MirrorFail).RunAndReturn(func(ctx context.Context, mt database.MirrorTask, action string) (database.MirrorTask, error) { + mt.Status = types.MirrorLfsSyncFailed + return mt, nil + }) // Failure message suite.mocks.msgSender.EXPECT().Send(mock.Anything, mock.MatchedBy(func(req types.MessageRequest) bool { @@ -420,7 +432,7 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_Success_HasLfsFiles() { Size: 1024, }, }, nil) - suite.mocks.lfsMetaObjectStore.EXPECT().BulkUpdateOrCreate(mock.Anything, mock.Anything, mock.Anything).Return(nil) + suite.mocks.lfsMetaObjectStore.EXPECT().BulkUpdateExistingByOIDs(mock.Anything, repo.ID, []string{"oid1"}, []string(nil)).Return(nil) suite.mocks.git.EXPECT().GetDiffBetweenTwoCommits(mock.Anything, mock.Anything).Return(&types.GiteaCallbackPushReq{}, nil) suite.mocks.ossClient.EXPECT().StatObject(mock.Anything, 
"test-bucket", "lfs/oi/d1", mock.Anything).Return(minio.ObjectInfo{Size: 1024}, nil) suite.mocks.lfsMetaObjectStore.EXPECT().UpdateOrCreate(mock.Anything, mock.Anything).Return(nil, nil) @@ -448,8 +460,11 @@ func (suite *LfsSyncWorkerTestSuite) TestRun_Success_HasLfsFiles() { // Final updates suite.mocks.mirrorTaskStore.EXPECT().UpdateStatusAndRepoSyncStatus(mock.Anything, mock.MatchedBy(func(mt database.MirrorTask) bool { - return mt.Status == types.MirrorRepoSyncFinished && mt.Progress == 100 - }), types.SyncStatusInProgress).Return(*task, nil) + return mt.Progress == 100 + }), database.MirrorSuccess).RunAndReturn(func(ctx context.Context, mt database.MirrorTask, action string) (database.MirrorTask, error) { + mt.Status = types.MirrorRepoSyncFinished + return mt, nil + }) suite.mocks.msgSender.EXPECT().Send(mock.Anything, mock.MatchedBy(func(req types.MessageRequest) bool { return strings.Contains(req.Parameters, "finished") @@ -473,10 +488,14 @@ func (suite *LfsSyncWorkerTestSuite) TestGetSyncPointers() { ctx := context.WithValue(suite.ctx, rk, "models/test/repo") tests := []struct { - name string - lfsObjects []database.LfsMetaObject - expectedCount int - expectedError bool + name string + lfsObjects []database.LfsMetaObject + storageExisting map[string]bool + expectedCount int + findError error + statError error + expectedError bool + expectedErrorMsg string }{ { name: "no lfs objects", @@ -487,8 +506,12 @@ func (suite *LfsSyncWorkerTestSuite) TestGetSyncPointers() { { name: "all existing objects", lfsObjects: []database.LfsMetaObject{ - {Oid: "oid1", Size: 100, Existing: true}, - {Oid: "oid2", Size: 200, Existing: true}, + {Oid: "oid1", Size: 100, Existing: false}, + {Oid: "oid2", Size: 200, Existing: false}, + }, + storageExisting: map[string]bool{ + "oid1": true, + "oid2": true, }, expectedCount: 0, expectedError: false, @@ -496,18 +519,36 @@ func (suite *LfsSyncWorkerTestSuite) TestGetSyncPointers() { { name: "mixed existing and non-existing objects", 
lfsObjects: []database.LfsMetaObject{ - {Oid: "oid1", Size: 100, Existing: true}, - {Oid: "oid2", Size: 200, Existing: false}, - {Oid: "oid3", Size: 300, Existing: false}, + {Oid: "oid1", Size: 100, Existing: false}, + {Oid: "oid2", Size: 200, Existing: true}, + {Oid: "oid3", Size: 300, Existing: true}, + }, + storageExisting: map[string]bool{ + "oid1": true, + "oid2": false, + "oid3": false, }, expectedCount: 2, expectedError: false, }, { - name: "database error", - lfsObjects: nil, - expectedCount: 0, - expectedError: true, + name: "database error", + lfsObjects: nil, + expectedCount: 0, + findError: errors.New("database error"), + expectedError: true, + expectedErrorMsg: "fail to get lfs meta objects", + }, + { + name: "object storage check error", + lfsObjects: []database.LfsMetaObject{ + {Oid: "oid1", Size: 100, Existing: false}, + }, + storageExisting: nil, + expectedCount: 0, + statError: errors.New("network error"), + expectedError: true, + expectedErrorMsg: "failed to check if lfs file exists", }, } @@ -516,25 +557,35 @@ func (suite *LfsSyncWorkerTestSuite) TestGetSyncPointers() { // Create a new worker and mocks for each test case worker, mocks := newTestLfsSyncWorker(t, "") - if tt.expectedError { - mocks.lfsMetaObjectStore.EXPECT().FindByRepoID(ctx, repo.ID).Return(nil, errors.New("database error")) + if tt.findError != nil { + mocks.lfsMetaObjectStore.EXPECT().FindByRepoID(ctx, repo.ID).Return(nil, tt.findError) } else { mocks.lfsMetaObjectStore.EXPECT().FindByRepoID(ctx, repo.ID).Return(tt.lfsObjects, nil) + var existingOIDs []string + var missingOIDs []string for _, obj := range tt.lfsObjects { - if obj.Existing { + if tt.statError != nil { + mocks.ossClient.EXPECT().StatObject(ctx, "test-bucket", "lfs/"+obj.Oid[:2]+"/"+obj.Oid[2:], mock.Anything).Return(minio.ObjectInfo{}, tt.statError) + continue + } + if tt.storageExisting[obj.Oid] { mocks.ossClient.EXPECT().StatObject(ctx, "test-bucket", "lfs/"+obj.Oid[:2]+"/"+obj.Oid[2:], 
mock.Anything).Return(minio.ObjectInfo{Size: obj.Size}, nil) + existingOIDs = append(existingOIDs, obj.Oid) } else { mocks.ossClient.EXPECT().StatObject(ctx, "test-bucket", "lfs/"+obj.Oid[:2]+"/"+obj.Oid[2:], mock.Anything).Return(minio.ObjectInfo{}, errors.New("NoSuchKey")) + missingOIDs = append(missingOIDs, obj.Oid) } } - mocks.lfsMetaObjectStore.EXPECT().BulkUpdateOrCreate(mock.Anything, mock.Anything, mock.Anything).Return(nil) + if len(tt.lfsObjects) > 0 && !tt.expectedError { + mocks.lfsMetaObjectStore.EXPECT().BulkUpdateExistingByOIDs(ctx, repo.ID, existingOIDs, missingOIDs).Return(nil) + } } pointers, err := worker.getSyncPointers(ctx, task) if tt.expectedError { assert.Error(t, err) - assert.Contains(t, err.Error(), "fail to get lfs meta objects") + assert.Contains(t, err.Error(), tt.expectedErrorMsg) } else { assert.NoError(t, err) assert.Len(t, pointers, tt.expectedCount) @@ -593,6 +644,30 @@ func (suite *LfsSyncWorkerTestSuite) TestCheckIfLFSFileExists() { expectedExists: false, expectedError: false, }, + { + name: "file not found from minio error code", + statError: minio.ErrorResponse{Code: "NoSuchKey"}, + expectedSize: 1024, + expectDelete: false, + expectedExists: false, + expectedError: false, + }, + { + name: "file not found from empty minio error code", + statError: minio.ErrorResponse{StatusCode: http.StatusNotFound}, + expectedSize: 1024, + expectDelete: false, + expectedExists: false, + expectedError: false, + }, + { + name: "bucket not found is an error", + statError: minio.ErrorResponse{Code: "NoSuchBucket", StatusCode: http.StatusNotFound}, + expectedSize: 1024, + expectDelete: false, + expectedExists: false, + expectedError: true, + }, { name: "other error", statError: errors.New("network error"), @@ -648,6 +723,23 @@ func (suite *LfsSyncWorkerTestSuite) TestDownloadAndUploadLFSFile_FileExists() { assert.NoError(suite.T(), err) } +func (suite *LfsSyncWorkerTestSuite) TestDownloadAndUploadLFSFile_CheckExistsError() { + ctx := 
context.WithValue(suite.ctx, rk, "models/test/repo") + repo := createTestRepository() + pointer := &types.Pointer{ + Oid: "test-oid-123", + Size: 1024, + } + + // The metadata must not be updated when object storage verification fails. + suite.mocks.ossClient.EXPECT().StatObject(ctx, suite.worker.config.S3.Bucket, mock.Anything, mock.Anything).Return(minio.ObjectInfo{}, errors.New("network error")) + + err := suite.worker.downloadAndUploadLFSFile(ctx, repo, pointer) + + assert.Error(suite.T(), err) + assert.Contains(suite.T(), err.Error(), "failed to check if lfs file exists") +} + func (suite *LfsSyncWorkerTestSuite) TestDownloadAndUploadLFSFile_EmptyDownloadURL() { ctx := context.WithValue(suite.ctx, rk, "models/test/repo") repo := createTestRepository() @@ -742,7 +834,7 @@ func (suite *LfsSyncWorkerTestSuite) TestDownloadRange() { })) defer server.Close() - resp, err := suite.worker.downloadRange(server.URL, 0, 9) + resp, err := suite.worker.downloadRange(suite.ctx, server.URL, 0, 9) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), resp) @@ -921,7 +1013,7 @@ func TestLfsSyncWorker_SetContext(t *testing.T) { func TestLfsSyncWorker_DownloadRange_InvalidURL(t *testing.T) { worker, _ := newTestLfsSyncWorker(t, "") - resp, err := worker.downloadRange("invalid-url", 0, 10) + resp, err := worker.downloadRange(context.Background(), "invalid-url", 0, 10) if resp != nil { resp.Body.Close() } @@ -937,7 +1029,7 @@ func TestLfsSyncWorker_DownloadRange_ServerError(t *testing.T) { })) defer server.Close() - resp, err := worker.downloadRange(server.URL, 0, 10) + resp, err := worker.downloadRange(context.Background(), server.URL, 0, 10) if resp != nil { resp.Body.Close() } @@ -964,3 +1056,105 @@ func TestLfsSyncWorker_ConcurrentOperations(t *testing.T) { // Worker should still be functional assert.Equal(t, 1, worker.ID()) } + +func TestLfsSyncWorker_DownloadRange_NotFoundError(t *testing.T) { + worker, _ := newTestLfsSyncWorker(t, "") + + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + resp, err := worker.downloadRange(context.Background(), server.URL, 0, 10) + if resp != nil { + resp.Body.Close() + } + + assert.Error(t, err) +} + +func TestLfsSyncWorker_DownloadRange_ContextCanceled(t *testing.T) { + worker, _ := newTestLfsSyncWorker(t, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + resp, err := worker.downloadRange(ctx, server.URL, 0, 10) + elapsed := time.Since(start) + + if resp != nil { + resp.Body.Close() + } + + assert.Error(t, err) + assert.True(t, elapsed < 100*time.Millisecond, + "downloadRange should fail fast on cancelled context, took %v", elapsed) +} + +func TestLfsSyncWorker_DownloadRange_Success(t *testing.T) { + worker, _ := newTestLfsSyncWorker(t, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "4") + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + resp, err := worker.downloadRange(context.Background(), server.URL, 0, 3) + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + assert.Equal(t, "data", string(body)) +} + +func TestLfsSyncWorker_DownloadAndUploadLFSFiles_ContextCanceled(t *testing.T) { + worker, mocks := newTestLfsSyncWorker(t, "") + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, rk, "models/test/repo") + ctx = context.WithValue(ctx, suk, "http://example.com") + ctx = context.WithValue(ctx, dbk, "main") + cancel() + + repo := 
createTestRepository() + mirror := createTestMirror(repo, "http://example.com/test/repo.git") + task := createTestMirrorTask(mirror, types.MirrorLfsSyncStart) + + pointerGroups := [][]*types.Pointer{ + {{Oid: "oid1", Size: 1024, DownloadURL: "http://example.com/file"}}, + } + + err := worker.downloadAndUploadLFSFiles(ctx, task, mirror, pointerGroups, repo) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled), + "expected context.Canceled, got %v", err) + _ = mocks // unused mocks but needed for worker setup +} + +func TestLfsSyncWorker_DownloadAndUploadSmallFile_ContextCanceled(t *testing.T) { + worker, _ := newTestLfsSyncWorker(t, "") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + repo := createTestRepository() + pointer := &types.Pointer{ + Oid: "oid1", + Size: 1024, + DownloadURL: "http://example.com/file", + } + + err := worker.downloadAndUploadSmallFile(ctx, repo, pointer, "test-key") + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled), + "expected context.Canceled, got %v", err) +} diff --git a/mirror/manager/lfs_worker_manager.go b/mirror/manager/lfs_worker_manager.go index ae0bce7ab..a104cd086 100644 --- a/mirror/manager/lfs_worker_manager.go +++ b/mirror/manager/lfs_worker_manager.go @@ -29,7 +29,6 @@ func InitManger(cfg *config.Config) error { once.Do(func() { manager = &Manager{ workerNumber: cfg.Mirror.WorkerNumber, - taskChan: make(chan database.MirrorTask), priorityTaskChan: make(chan database.MirrorTask), mirrorTaskStore: database.NewMirrorTaskStore(), config: cfg, @@ -42,7 +41,6 @@ func InitManger(cfg *config.Config) error { type Manager struct { config *config.Config - taskChan chan database.MirrorTask priorityTaskChan chan database.MirrorTask workerNumber int workers map[int]*Worker @@ -75,7 +73,6 @@ func (m *Manager) StopWorker(id int) error { if worker, ok := m.workers[id]; ok { worker.cancel() delete(m.workers, id) - m.conChan <- id } else { return fmt.Errorf("worker %d not 
found", id) } @@ -83,21 +80,20 @@ func (m *Manager) StopWorker(id int) error { return nil } -func (m *Manager) StopWorkerByMirrorID(mirrorID int64) (bool, error) { +func (m *Manager) StopWorkerByTaskID(taskID int64) (bool, error) { var found bool m.mu.Lock() defer m.mu.Unlock() for id, worker := range m.workers { - if worker.RunningTask.MirrorID == mirrorID { + if worker.RunningTask != nil && worker.RunningTask.ID == taskID { found = true worker.cancel() delete(m.workers, id) - m.conChan <- id } } if !found { - return false, fmt.Errorf("worker for mirror %d not found", mirrorID) + return false, fmt.Errorf("worker for mirror %d not found", taskID) } return true, nil @@ -115,9 +111,7 @@ func (m *Manager) ReRun(id int, mt *database.MirrorTask) error { } m.mu.Unlock() - go func() { - m.priorityTaskChan <- *mt - }() + m.priorityTaskChan <- *mt return nil } @@ -135,14 +129,12 @@ func (m *Manager) Start() { m.conChan <- i } - go m.dispatcher() - for id := range m.conChan { select { case mt := <-m.priorityTaskChan: go m.startWorker(id, &mt) - case mt := <-m.taskChan: - go m.startWorker(id, &mt) + default: + go m.claimAndStartWorker(id) } } } @@ -151,6 +143,7 @@ func (m *Manager) startWorker(id int, mt *database.MirrorTask) { lfsSyncWorker, err := mirror.NewLFSSyncWorker(m.config, id) if err != nil { slog.Error("failed to create lfs sync worker", slog.Any("error", err)) + m.conChan <- id return } @@ -159,8 +152,17 @@ func (m *Manager) startWorker(id int, mt *database.MirrorTask) { lfsSyncWorker.SetContext(ctx) m.mu.Lock() + + currentTask, err := m.mirrorTaskStore.FindByID(context.Background(), mt.ID) + if err == nil && currentTask.Status == types.MirrorCanceled { + m.mu.Unlock() + cancel() + m.conChan <- id + return + } + for workerID, worker := range m.workers { - if worker.RunningTask.MirrorID == mt.MirrorID { + if worker.RunningTask.ID == mt.ID { slog.Warn("worker for mirror is running, cancel it", slog.Any("mirrorID", mt.MirrorID), slog.Any("workerID", workerID)) 
worker.cancel() delete(m.workers, workerID) @@ -184,25 +186,23 @@ func (m *Manager) startWorker(id int, mt *database.MirrorTask) { m.conChan <- id } -func (m *Manager) dispatcher() { - for { - ctx := context.Background() - task, err := m.mirrorTaskStore.GetHighestPriorityByTaskStatus(ctx, expectedMirrorTaskStatus) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - slog.Info("no tasks to dispatch, sleep 5s") - time.Sleep(5 * time.Second) - continue - } - slog.Error("failed to get task from db", slog.Any("error", err)) - time.Sleep(5 * time.Second) - continue +func (m *Manager) claimAndStartWorker(id int) { + ctx := context.Background() + task, err := m.mirrorTaskStore.GetHighestPriorityByTaskStatus(ctx, expectedMirrorTaskStatus) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + slog.Error("failed to claim task from db", slog.Any("error", err)) } - m.taskChan <- task + time.Sleep(5 * time.Second) + m.conChan <- id + return } + m.startWorker(id, &task) } func (m *Manager) RunningTasks() map[int]database.MirrorTask { + m.mu.Lock() + defer m.mu.Unlock() tasks := make(map[int]database.MirrorTask) for id, worker := range m.workers { tasks[id] = *worker.RunningTask diff --git a/mirror/manager/lfs_worker_manager_test.go b/mirror/manager/lfs_worker_manager_test.go new file mode 100644 index 000000000..dd81dc03b --- /dev/null +++ b/mirror/manager/lfs_worker_manager_test.go @@ -0,0 +1,312 @@ +package manager + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/mirror" +) + +func newTestManager(workerNum int) *Manager { + if workerNum <= 0 { + workerNum = 3 + } + return &Manager{ + workerNumber: workerNum, + priorityTaskChan: make(chan database.MirrorTask), + conChan: make(chan int, workerNum), + workers: make(map[int]*Worker), + } +} + +type testWorker struct { + ctx context.Context + runCh chan 
struct{} // signals when Run is called and blocks until test releases + doneCh chan struct{} // closed when Run should complete +} + +func (w *testWorker) SetContext(ctx context.Context) { + w.ctx = ctx +} + +func (w *testWorker) Run(mt *database.MirrorTask) { + w.runCh <- struct{}{} + <-w.doneCh +} + +func newTestWorker() *testWorker { + return &testWorker{ + runCh: make(chan struct{}, 1), + doneCh: make(chan struct{}), + } +} + +func TestRunningTasks_HoldsLock(t *testing.T) { + m := newTestManager(3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m.workers[1] = &Worker{ + ID: 1, + ctx: ctx, + cancel: cancel, + RunningTask: &database.MirrorTask{ID: 100, MirrorID: 10}, + } + m.workers[2] = &Worker{ + ID: 2, + ctx: ctx, + cancel: cancel, + RunningTask: &database.MirrorTask{ID: 200, MirrorID: 20}, + } + + // Concurrent reads should not panic + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tasks := m.RunningTasks() + assert.NotNil(t, tasks) + }() + } + + // Concurrent writes + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 50; j++ { + m.mu.Lock() + m.workers[3] = &Worker{ + ID: 3, + ctx: ctx, + cancel: cancel, + RunningTask: &database.MirrorTask{ID: int64(j)}, + } + m.mu.Unlock() + } + }() + + wg.Wait() + + tasks := m.RunningTasks() + assert.Len(t, tasks, 3) +} + +func TestStopWorker_RemovesWorkerFromMap(t *testing.T) { + m := newTestManager(3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tw := newTestWorker() + m.workers[1] = &Worker{ + ID: 1, + ctx: ctx, + cancel: cancel, + Worker: tw, + RunningTask: &database.MirrorTask{ID: 100, MirrorID: 10}, + } + + // Verify worker exists before stop + tasks := m.RunningTasks() + assert.Len(t, tasks, 1) + + err := m.StopWorker(1) + assert.NoError(t, err) + + // Verify worker is removed from map + tasks = m.RunningTasks() + assert.Len(t, tasks, 0) + + // Verify stopping non-existent worker returns 
error + err = m.StopWorker(99) + assert.Error(t, err) + assert.Contains(t, err.Error(), "worker 99 not found") +} + +func TestStopWorkerByTaskID_HandlesNilRunningTask(t *testing.T) { + m := newTestManager(3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Worker with nil RunningTask should not cause panic + m.workers[1] = &Worker{ + ID: 1, + ctx: ctx, + cancel: cancel, + RunningTask: nil, + } + + // This should not panic even with nil RunningTask + found, err := m.StopWorkerByTaskID(10) + assert.False(t, found) + assert.Error(t, err) + // Worker should still be in map (no match, no delete) + assert.Len(t, m.workers, 1) +} + +func TestStopWorkerByTaskID_FindsAndRemovesWorker(t *testing.T) { + m := newTestManager(3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tw := newTestWorker() + tw2 := newTestWorker() + + m.workers[1] = &Worker{ + ID: 1, + ctx: ctx, + cancel: cancel, + Worker: tw, + RunningTask: &database.MirrorTask{ID: 100, MirrorID: 10}, + } + m.workers[2] = &Worker{ + ID: 2, + ctx: ctx, + cancel: cancel, + Worker: tw2, + RunningTask: &database.MirrorTask{ID: 200, MirrorID: 20}, + } + + // Stop worker for task 100 + found, err := m.StopWorkerByTaskID(100) + assert.True(t, found) + assert.NoError(t, err) + + // Worker 1 should be gone, worker 2 should remain + tasks := m.RunningTasks() + assert.Len(t, tasks, 1) + _, exists := tasks[2] + assert.True(t, exists) +} + +func TestStopWorker_DoesNotDoubleSendToConChan(t *testing.T) { + m := newTestManager(3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tw := newTestWorker() + m.workers[1] = &Worker{ + ID: 1, + ctx: ctx, + cancel: cancel, + Worker: tw, + RunningTask: &database.MirrorTask{ID: 100, MirrorID: 10}, + } + + // Drain conChan (it starts empty since we created manager directly) + // Simulate startWorker's end-of-life send by starting a goroutine + go func() { + // Simulate the cleanup at end of 
startWorker: m.conChan <- id + time.Sleep(50 * time.Millisecond) + m.conChan <- 1 + }() + + // Call StopWorker - it should NOT send to conChan + err := m.StopWorker(1) + assert.NoError(t, err) + + // Wait for goroutine to send + select { + case id := <-m.conChan: + assert.Equal(t, 1, id) // Exactly one ID received + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout waiting for conChan") + } + + // Verify no second send + select { + case id := <-m.conChan: + t.Fatalf("unexpected second send to conChan: %d", id) + case <-time.After(100 * time.Millisecond): + // Expected - no second send + } +} + +func TestRunningTasks_EmptyMap(t *testing.T) { + m := newTestManager(3) + tasks := m.RunningTasks() + assert.Empty(t, tasks) +} + +func TestConChan_CapacityPreserved(t *testing.T) { + m := newTestManager(5) + + // Fill conChan with initial IDs + for i := 1; i <= 5; i++ { + m.conChan <- i + } + + // Consume and verify all IDs are present + ids := make(map[int]bool) + for i := 0; i < 5; i++ { + select { + case id := <-m.conChan: + ids[id] = true + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for conChan") + } + } + + assert.Len(t, ids, 5) + for i := 1; i <= 5; i++ { + assert.True(t, ids[i], "missing id %d", i) + } +} + +func TestManager_RunningTasks_ConcurrentStopWorker(t *testing.T) { + m := newTestManager(5) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Add multiple workers + for i := 1; i <= 5; i++ { + tw := newTestWorker() + m.workers[i] = &Worker{ + ID: i, + ctx: ctx, + cancel: cancel, + Worker: tw, + RunningTask: &database.MirrorTask{ID: int64(i * 100), MirrorID: int64(i * 10)}, + } + } + + var wg sync.WaitGroup + // Concurrent stops and reads + for i := 0; i < 20; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _ = m.StopWorker(idx%5 + 1) + }(i) + + wg.Add(1) + go func() { + defer wg.Done() + _ = m.RunningTasks() + }() + } + + wg.Wait() + // If we reach here without panic, the lock is 
correctly preventing concurrent map access +} + +// Ensure Worker implements mirror.LFSSyncWorker +var _ mirror.LFSSyncWorker = (*testWorker)(nil) + +// Ensure testWorker satisfies the interface +func TestTestWorkerSatisfiesInterface(t *testing.T) { + var w mirror.LFSSyncWorker = newTestWorker() + require.NotNil(t, w) +} diff --git a/mirror/reposyncer/repo.go b/mirror/reposyncer/repo.go index cbfba5e8b..6959057f9 100644 --- a/mirror/reposyncer/repo.go +++ b/mirror/reposyncer/repo.go @@ -205,21 +205,9 @@ func (w *RepoSyncWorker) handleTask( } } - mtFSM := database.NewMirrorTaskWithFSM(mt) - canContinue := mtFSM.SubmitEvent(ctx, statusAction) - if !canContinue { - slog.Error( - "failed to transition to next status", - slog.Any("before status", mt.Status), - slog.Any("action", statusAction), - ) - return - } - mt.Status = types.MirrorTaskStatus(mtFSM.Current()) - repoSyncStatus := common.MirrorTaskStatusToRepoStatus(mt.Status) - _, err = w.mirrorTaskStore.UpdateStatusAndRepoSyncStatus(ctx, *mt, repoSyncStatus) + _, err = w.mirrorTaskStore.UpdateStatusAndRepoSyncStatus(ctx, *mt, statusAction) if err != nil { - slog.Error("failed to update mirror task status and repository status", slog.Any("error", err)) + slog.Error("failed to update mirror task status and repository status", slog.Any("error", err), slog.Any("action", statusAction)) } }