Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions unicore/distributed/bp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import BroadcastOptions, AllreduceOptions, ReduceOp
from .comm_group import scg

def broadcast(tensor, src):
    """Broadcast ``tensor`` in-place from rank ``src`` within the bp group.

    Args:
        tensor: the tensor to broadcast; receiving ranks are overwritten
            in-place with the data held by ``src``.
        src: rank inside the bp group holding the source data. Only 0 or 1
            is valid because branch parallel currently supports bp_degree=2.

    Returns:
        The (possibly updated) tensor. When the bp group has a single rank
        there is nothing to communicate and the tensor is returned as-is.
    """
    if scg.get_bp_world_size() == 1:
        return tensor

    assert src in [0, 1], "Branch Parallel is only support bp_degree=2 now!"

    group = scg.get_bp_group()

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = 0
    work = group.broadcast([tensor], opts)
    work.wait()
    # Fix: the original returned the tensor on the single-rank path but None
    # after a real broadcast; return consistently so callers can rely on it.
    return tensor

def all_reduce(tensor):
    """Sum-allreduce ``tensor`` in-place across the bp group and return it."""
    if scg.get_bp_world_size() == 1:
        # bp disabled: single rank, nothing to reduce.
        return tensor

    reduce_opts = AllreduceOptions()
    reduce_opts.reduceOp = ReduceOp.SUM

    bp_group = scg.get_bp_group()
    handle = bp_group.allreduce([tensor], reduce_opts)
    handle.wait()

    return tensor

class SyncEvoformerResults(torch.autograd.Function):
    """Autograd op that syncs evoformer branch results across the bp group.

    Forward broadcasts the partial results between the two bp ranks and
    merges ``outer`` into ``pair`` on rank 1; backward broadcasts the
    gradients back so each rank receives the gradient for the branch it
    computed.
    """
    @staticmethod
    def forward(ctx, outer, msa, pair, training):
        # rank 0 owns `outer` and `msa`; rank 1 owns `pair`.
        broadcast(outer, 0)
        if scg.get_bp_rank_in_group() == 1:
            if training:
                # out-of-place add so the local autograd graph stays intact
                pair = pair + outer
            else:
                pair += outer
        broadcast(pair, 1)
        broadcast(msa, 0)
        return msa.clone(), pair.clone()

    @staticmethod
    def backward(ctx, *grad_output):
        msa_grad = grad_output[0]
        pair_grad = grad_output[1]

        # rank 0 did not compute `pair`, so its incoming pair grad is dropped
        # and replaced by the grad broadcast from rank 1 below.
        if scg.get_bp_rank_in_group() == 0:
            pair_grad = torch.zeros_like(pair_grad)

        outer_grad = pair_grad.clone()
        broadcast(outer_grad, 1)

        # Fix: forward takes 4 inputs (outer, msa, pair, training), so
        # backward must return 4 gradients; `training` is a plain bool and
        # receives None. Returning only 3 raises
        # "function returned an incorrect number of gradients" at runtime.
        return outer_grad.clone(), msa_grad.clone(), pair_grad.clone(), None

def sync_evoformer_results(outer, msa, pair, training):
    """A wrapper around SyncEvoformerResults that broadcasts gradients in
    the backward stage.

    No-op when branch parallel is disabled (bp world size == 1); otherwise
    routes through the autograd-aware op so the backward pass broadcasts
    the gradients across the bp group. Under no_grad the op still performs
    the forward synchronization.
    """
    if scg.get_bp_world_size() == 1:
        return msa, pair

    # Fix: the original early-returned (msa, pair) unchanged exactly when
    # gradients were being tracked, skipping the sync in the one case the
    # custom backward of SyncEvoformerResults exists for. Always apply the
    # op when bp is enabled; without grad it simply runs the forward sync.
    msa, pair = SyncEvoformerResults.apply(outer, msa, pair, training)

    return msa, pair
173 changes: 173 additions & 0 deletions unicore/distributed/comm_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Communication group manager
"""

import numpy as np
import torch.distributed as dist

def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is evenly divisible by ``denominator``."""
    remainder = numerator % denominator
    assert remainder == 0, '{} is not divisible by {}'.format(numerator, denominator)

class SingletonCommunicationGroup(object):
    """ A singleton communication group for bp, dap, ddp hybrid parallel. """
    def __init__(self):
        # Groups are created lazily by init_group(); every accessor below
        # falls back to safe defaults while uninitialized.
        self.initialized = False

    def init_group(self, bp_degree=1, dap_degree=1, dap_comm_sync=True):
        """ init the hybrid parallel, it will auto calculate ddp_degree using bp_degree, dap_degree and world_size """
        assert self.initialized == False, "Communication group is already initialized!"

        # check valid config
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        inner_degree = bp_degree * dap_degree
        ensure_divisibility(world_size, bp_degree)
        ensure_divisibility(world_size, dap_degree)
        ensure_divisibility(world_size, inner_degree)

        self.dp_degree = world_size // inner_degree
        self.bp_degree = bp_degree
        self.dap_degree = dap_degree
        self.dap_comm_sync = dap_comm_sync

        # Global ranks laid out as a (dp, dap, bp) grid; reshaping along one
        # axis yields the rank lists for that parallel dimension.
        arr = np.arange(0, world_size).reshape(self.dp_degree, self.dap_degree, self.bp_degree)

        # build bp group (bp is the innermost/contiguous axis)
        bp_arr = arr.transpose((0, 1, 2)).reshape(-1, self.bp_degree)
        for i in range(world_size // self.bp_degree):
            ranks = bp_arr[i].tolist()
            group = dist.new_group(ranks)
            print('> bp ranks:', ranks, 'bp group:', group)
            if rank in ranks:
                self.bp_group = group

        # build dap group
        dap_arr = arr.transpose((0, 2, 1)).reshape(-1, self.dap_degree)
        for i in range(world_size // self.dap_degree):
            ranks = dap_arr[i].tolist()
            group = dist.new_group(ranks)
            print('> dap ranks:', ranks, 'dap group:', group)
            if rank in ranks:
                self.dap_group = group

        # build dp group
        dp_arr = arr.transpose((1, 2, 0)).reshape(-1, self.dp_degree)
        for i in range(world_size // self.dp_degree):
            ranks = dp_arr[i].tolist()
            group = dist.new_group(ranks)
            print('> dp ranks:', ranks, 'dp group:', group)
            if rank in ranks:
                self.dp_group = group

        self.initialized = True
        if dist.get_rank() == 0:
            print('> initialize branch parallel with size {}'.format(self.bp_degree))
            print('> initialize dynamic axial parallel with size {}'.format(self.dap_degree))
            print('> initialize data parallel with size {}'.format(self.dp_degree))

    def dap_is_comm_sync(self):
        """ get dap whether use sync or async communication """
        return self.dap_comm_sync

    def bp_is_initialized(self):
        """ get bp communication group whether is initialized """
        return self.initialized

    def dap_is_initialized(self):
        """ get dap communication group whether is initialized """
        return self.initialized

    def dp_is_initialized(self):
        """ get dp communication group whether is initialized """
        return self.initialized

    def is_initialized(self):
        """ get hybrid communication group whether is initialized """
        return self.initialized

    def get_bp_group(self):
        """ get bp communication group """
        assert self.initialized == True, "bp group is not initialized!"
        return self.bp_group

    def get_bp_rank(self):
        """ get bp rank id of the current process (0 when not initialized) """
        if not self.initialized:
            return 0
        # Fix: `self.bp_group.rank` is a Paddle-style attribute; a torch
        # ProcessGroup exposes rank as a method, so the attribute access
        # returned a bound method object instead of an int.
        return dist.get_rank(self.bp_group)

    def get_bp_rank_in_group(self):
        """ get bp rank id in bp group (-1 when not initialized) """
        if not self.initialized:
            return -1
        return dist.get_rank(self.bp_group)

    def get_bp_world_size(self):
        """ get bp world size in bp group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.bp_group)

    def get_dap_group(self):
        """ get dap communication group """
        assert self.initialized == True, "dap group is not initialized!"
        return self.dap_group

    def get_dap_rank(self):
        """ get dap rank id of the current process (0 when not initialized) """
        if not self.initialized:
            return 0
        # Fix: same Paddle-ism as get_bp_rank (see above).
        return dist.get_rank(self.dap_group)

    def get_dap_rank_in_group(self):
        """ get dap rank id in dap group (-1 when not initialized) """
        if not self.initialized:
            return -1
        return dist.get_rank(self.dap_group)

    def get_dap_world_size(self):
        """ get dap world size in dap group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.dap_group)

    def get_dp_group(self):
        """ get ddp communication group """
        assert self.initialized == True, "dp group is not initialized!"
        return self.dp_group

    def get_dp_rank(self):
        """ get ddp rank id of the current process (0 when not initialized) """
        if not self.initialized:
            return 0
        # Fix: same Paddle-ism as get_bp_rank (see above).
        return dist.get_rank(self.dp_group)

    def get_dp_rank_in_group(self):
        """ get ddp rank id in ddp group (-1 when not initialized) """
        if not self.initialized:
            return -1
        # Fix: removed an unused `rank = dist.get_rank()` local.
        return dist.get_rank(self.dp_group)

    def get_dp_world_size(self):
        """ get ddp world size in ddp group """
        if not self.initialized:
            return 1
        return dist.get_world_size(self.dp_group)

# module-level singleton used by bp.py and utils.py
scg = SingletonCommunicationGroup()
5 changes: 5 additions & 0 deletions unicore/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch
import torch.distributed as dist

from .comm_group import scg


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -137,6 +139,9 @@ def distributed_init(args):
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())

scg.init_group(bp_degree=args.bp_degree, dap_degree=1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this affect the normal c10d, no_c10d mode?
Can we make "bp" a choice, like currently c10d, no_c10d?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure about this question. This PR is just to show how to use BP, not to merge this PR into UniCore.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I may be missing some context.


args.dp_rank = scg.get_dp_rank_in_group() if torch.distributed.get_world_size() > 1 else 0
args.distributed_rank = torch.distributed.get_rank()

if is_master(args):
Expand Down
1 change: 1 addition & 0 deletions unicore/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def add_distributed_training_args(parser):
help="number of GPUs in each node. An allreduce operation across GPUs in "
"a node is very fast. Hence, we do allreduce across GPUs in a node, "
"and gossip across different nodes")
group.add_argument('--bp-degree', default=1, type=int)
# fmt: on
return group

Expand Down
1 change: 1 addition & 0 deletions unicore_cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def main(args) -> None:
), "Must specify batch size either with --batch-size"
metrics.reset()

args.seed += args.dp_rank
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this change needed?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using a hybrid distributed parallel strategy, such as DP-BP, the parameters and data in the same BP group need to be the same, so the seeds need to be the same.

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
Expand Down