-
Notifications
You must be signed in to change notification settings - Fork 46
support branch parallel for evoformer #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| import torch | ||
| import torch.distributed as dist | ||
| from torch._C._distributed_c10d import BroadcastOptions, AllreduceOptions, ReduceOp | ||
| from .comm_group import scg | ||
|
|
||
def broadcast(tensor, src):
    """Broadcast ``tensor`` in place from rank ``src`` within the branch-parallel (bp) group.

    Args:
        tensor: tensor to broadcast; overwritten in place on non-src ranks.
        src: rank inside the bp group holding the source data (must be 0 or 1,
            since branch parallelism currently supports only bp_degree == 2).

    Returns:
        The tensor, so callers can use the return value uniformly whether or
        not branch parallelism is enabled.
    """
    # Single-rank bp group: nothing to communicate.
    if scg.get_bp_world_size() == 1:
        return tensor

    assert src in [0, 1], "Branch Parallel is only support bp_degree=2 now!"

    group = scg.get_bp_group()

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = 0
    work = group.broadcast([tensor], opts)
    work.wait()
    # Fix: previously this path fell through and returned None, while the
    # bp_world_size == 1 path returned the tensor; return it here too so the
    # function has a consistent contract (mirrors all_reduce below).
    return tensor
|
|
||
def all_reduce(tensor):
    """ allreduce a tensor in bp group """
    # No peers to reduce with when branch parallelism is disabled.
    if scg.get_bp_world_size() == 1:
        return tensor

    bp_group = scg.get_bp_group()

    # Sum-reduce across the bp group; the result lands in-place in `tensor`
    # on every participating rank.
    reduce_opts = AllreduceOptions()
    reduce_opts.reduceOp = ReduceOp.SUM

    handle = bp_group.allreduce([tensor], reduce_opts)
    handle.wait()

    return tensor
|
|
||
class SyncEvoformerResults(torch.autograd.Function):
    """ A PyLayer Op broadcast gradient in backward stage """
    # NOTE(review): the forward assumes a 2-way bp group — rank 0 contributes
    # `outer` and `msa`, rank 1 combines `outer` into `pair`; broadcasts then
    # synchronize both results across the group. Confirm bp_degree == 2.
    @staticmethod
    def forward(ctx, outer, msa, pair, training):
        # Share rank 0's outer-product result with rank 1 (in-place broadcast).
        broadcast(outer, 0)
        if scg.get_bp_rank_in_group() == 1:
            if training:
                # Out-of-place add when training — presumably to avoid an
                # in-place update on a tensor autograd still needs; confirm.
                pair = pair + outer
            else:
                pair += outer
        # Rank 1 owns the combined pair; rank 0 owns msa — sync both ways.
        broadcast(pair, 1)
        broadcast(msa, 0)
        # Clones decouple the returned tensors from the broadcast buffers.
        return msa.clone(), pair.clone()

    @staticmethod
    def backward(ctx, *grad_output):
        # grad_output mirrors forward's (msa, pair) return order.
        msa_grad = grad_output[0]
        pair_grad = grad_output[1]

        # Rank 0 did not produce the combined pair, so its pair gradient is
        # zeroed before deriving outer's gradient.
        if scg.get_bp_rank_in_group() == 0:
            pair_grad = torch.zeros_like(pair_grad)

        # outer contributed additively to pair on rank 1, so its gradient is
        # pair's gradient, broadcast back from rank 1.
        outer_grad = pair_grad.clone()
        broadcast(outer_grad, 1)

        # Gradients for (outer, msa, pair); `training` gets no gradient.
        return outer_grad.clone(), msa_grad.clone(), pair_grad.clone()
|
|
||
def sync_evoformer_results(outer, msa, pair, training):
    """ A wrapper that synchronizes evoformer results across the bp group and
    broadcasts gradients in the backward stage. """
    # Branch parallelism disabled: nothing to synchronize.
    if scg.get_bp_world_size() == 1:
        return msa, pair

    # NOTE(review): this early return fires exactly when autograd is active
    # AND every input requires grad — i.e. it skips the sync in the training
    # path, which contradicts the class's stated purpose of broadcasting
    # gradients in backward. The condition looks inverted (expected
    # `if not (...)`); confirm the intended semantics with the author.
    if torch.is_grad_enabled() and outer.requires_grad and msa.requires_grad and pair.requires_grad:
        return msa, pair

    msa, pair = SyncEvoformerResults.apply(outer, msa, pair, training)

    return msa, pair
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Communication group manager | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import torch.distributed as dist | ||
|
|
||
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    remainder = numerator % denominator
    assert remainder == 0, '{} is not divisible by {}'.format(numerator, denominator)
|
|
||
class SingletonCommunicationGroup(object):
    """ A singleton communication group for bp, dap, ddp hybrid parallel. """

    def __init__(self):
        # Groups are created lazily by init_group(); every accessor below
        # falls back to single-process defaults until then.
        self.initialized = False

    def init_group(self, bp_degree=1, dap_degree=1, dap_comm_sync=True):
        """ init the hybrid parallel, it will auto calculate ddp_degree using bp_degree, dap_degree and world_size """
        assert self.initialized == False, "Communication group is already initialized!"

        # check valid config
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        inner_degree = bp_degree * dap_degree
        ensure_divisibility(world_size, bp_degree)
        ensure_divisibility(world_size, dap_degree)
        ensure_divisibility(world_size, inner_degree)

        self.dp_degree = world_size // inner_degree
        self.bp_degree = bp_degree
        self.dap_degree = dap_degree
        self.dap_comm_sync = dap_comm_sync

        # Lay global ranks out as (dp, dap, bp): consecutive ranks share a
        # bp group first, then a dap group, then a dp group.
        arr = np.arange(0, world_size).reshape(self.dp_degree, self.dap_degree, self.bp_degree)

        # build bp group — rows of bp_degree consecutive ranks.
        # (The original transpose((0, 1, 2)) was a no-op and is dropped.)
        bp_arr = arr.reshape(-1, self.bp_degree)
        for i in range(world_size // self.bp_degree):
            ranks = bp_arr[i].tolist()
            # dist.new_group is collective: every rank must create every
            # group, but only remembers the one it belongs to.
            group = dist.new_group(ranks)
            print('> bp ranks:', ranks, 'bp group:', group)
            if rank in ranks:
                self.bp_group = group

        # build dap group — ranks with the same (dp, bp) coordinates.
        dap_arr = arr.transpose((0, 2, 1)).reshape(-1, self.dap_degree)
        for i in range(world_size // self.dap_degree):
            ranks = dap_arr[i].tolist()
            group = dist.new_group(ranks)
            print('> dap ranks:', ranks, 'dap group:', group)
            if rank in ranks:
                self.dap_group = group

        # build dp group — ranks with the same (dap, bp) coordinates.
        dp_arr = arr.transpose((1, 2, 0)).reshape(-1, self.dp_degree)
        for i in range(world_size // self.dp_degree):
            ranks = dp_arr[i].tolist()
            group = dist.new_group(ranks)
            print('> dp ranks:', ranks, 'dp group:', group)
            if rank in ranks:
                self.dp_group = group

        self.initialized = True
        if dist.get_rank() == 0:
            print('> initialize branch parallel with size {}'.format(self.bp_degree))
            print('> initialize dynamic axial parallel with size {}'.format(self.dap_degree))
            print('> initialize data parallel with size {}'.format(self.dp_degree))

    def dap_is_comm_sync(self):
        """ get dap whether use sync or async communication """
        return self.dap_comm_sync

    def bp_is_initialized(self):
        """ get bp communication group whether is initialized """
        return self.initialized

    def dap_is_initialized(self):
        """ get dap communication group whether is initialized """
        return self.initialized

    def dp_is_initialized(self):
        """ get dp communication group whether is initialized """
        return self.initialized

    def is_initialized(self):
        """ get hybrid communication group whether is initialized """
        return self.initialized

    def get_bp_group(self):
        """ get bp communication group """
        assert self.initialized == True, "bp group is not initialized!"
        return self.bp_group

    def get_bp_rank(self):
        """ get bp rank id in global group """
        if not self.initialized:
            return 0
        # Fix: `self.bp_group.rank` accessed ProcessGroup.rank as an
        # attribute, yielding a bound method rather than an int. The
        # docstring asks for the rank in the *global* group.
        return dist.get_rank()

    def get_bp_rank_in_group(self):
        """ get bp rank id in bp group """
        if not self.initialized:
            return -1
        return dist.get_rank(self.bp_group)

    def get_bp_world_size(self):
        """ get bp world size in bp group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.bp_group)

    def get_dap_group(self):
        """ get dap communication group """
        assert self.initialized == True, "dap group is not initialized!"
        return self.dap_group

    def get_dap_rank(self):
        """ get dap rank id in global group """
        if not self.initialized:
            return 0
        # Fix: same method-vs-attribute bug as get_bp_rank.
        return dist.get_rank()

    def get_dap_rank_in_group(self):
        """ get dap rank id in dap group """
        if not self.initialized:
            return -1
        return dist.get_rank(self.dap_group)

    def get_dap_world_size(self):
        """ get dap world size in dap group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.dap_group)

    def get_dp_group(self):
        """ get ddp communication group """
        assert self.initialized == True, "dp group is not initialized!"
        return self.dp_group

    def get_dp_rank(self):
        """ get ddp rank id in global group """
        if not self.initialized:
            return 0
        # Fix: same method-vs-attribute bug as get_bp_rank.
        return dist.get_rank()

    def get_dp_rank_in_group(self):
        """ get ddp rank id in ddp group """
        if not self.initialized:
            return -1
        # Dropped a dead `rank = dist.get_rank()` assignment here.
        return dist.get_rank(self.dp_group)

    def get_dp_world_size(self):
        """ get ddp world size in ddp group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.dp_group)


# Process-wide singleton used by the bp communication helpers.
scg = SingletonCommunicationGroup()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,8 @@ | |
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from .comm_group import scg | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -137,6 +139,9 @@ def distributed_init(args): | |
| if torch.cuda.is_available(): | ||
| dist.all_reduce(torch.zeros(1).cuda()) | ||
|
|
||
| scg.init_group(bp_degree=args.bp_degree, dap_degree=1) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer question: will this affect the normal c10d / no_c10d modes?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not quite sure about this question. This PR is just to show how to use BP, not to merge this PR into UniCore.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, I may miss some contexts. |
||
|
|
||
| args.dp_rank = scg.get_dp_rank_in_group() if torch.distributed.get_world_size() > 1 else 0 | ||
| args.distributed_rank = torch.distributed.get_rank() | ||
|
|
||
| if is_master(args): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,6 +49,7 @@ def main(args) -> None: | |
| ), "Must specify batch size either with --batch-size" | ||
| metrics.reset() | ||
|
|
||
| args.seed += args.dp_rank | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer question: is this change needed?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When using a hybrid distributed parallel strategy, such as DP-BP, the parameters and data in the same BP group need to be the same, so the seeds need to be the same. |
||
| np.random.seed(args.seed) | ||
| torch.manual_seed(args.seed) | ||
| if torch.cuda.is_available(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like the functions in this file are better to be in Uni-Fold.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same problem as above. It is necessary to design the code together and merge them into UniFold and UniCore respectively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dptech-corp/Uni-Fold#73