Skip to content

Commit

Permalink
support pp-sharding reshard
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Sep 19, 2024
1 parent 83cecc7 commit 0c972d7
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions paddlenlp/trainer/utils/reshard/pp_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict

from paddle.distributed.fleet.model import PipelineParallel
Expand Down Expand Up @@ -46,6 +45,20 @@ def get_index_layer_func():
return _GLOBAL_INDEX_LAYER_FUNC


# Module-level registry slot for the structure-name -> tensor-name mapping
# function; populated once via register_sname_to_tname_func().
_GLOBAL_SNAME_TO_TNAME_FUNC = None


def register_sname_to_tname_func(func):
    """Register the callable used to map a structure (param) name to a tensor name.

    The registered callable is later retrieved with ``get_sname_to_tname_func``.
    """
    global _GLOBAL_SNAME_TO_TNAME_FUNC
    _GLOBAL_SNAME_TO_TNAME_FUNC = func

Check warning on line 53 in paddlenlp/trainer/utils/reshard/pp_reshard.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/pp_reshard.py#L53

Added line #L53 was not covered by tests


def get_sname_to_tname_func():
    """Return the registered structure-name -> tensor-name mapping function.

    Raises:
        AssertionError: if no function was registered via
            ``register_sname_to_tname_func`` beforehand.
    """
    # Read-only access to the module global: a ``global`` declaration is
    # only needed for assignment, so the original one was dropped.
    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
    return _GLOBAL_SNAME_TO_TNAME_FUNC

Check warning on line 59 in paddlenlp/trainer/utils/reshard/pp_reshard.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/pp_reshard.py#L58-L59

Added lines #L58 - L59 were not covered by tests


class LayerNameScope:
"""
layer name scope for a layer, layer name of the same kind of layer will be named consecutively
Expand Down Expand Up @@ -206,6 +219,7 @@ def __init__(self):
self._segments = OrderedDict()
self._layer_to_segment = OrderedDict()
self._param_to_tname = OrderedDict()
self._wname_to_rname = OrderedDict()

Check warning on line 222 in paddlenlp/trainer/utils/reshard/pp_reshard.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/pp_reshard.py#L222

Added line #L222 was not covered by tests

def add_segment(self, start_index, end_index):
segment = PipeLineSegment(start_index, end_index)
Expand All @@ -218,19 +232,24 @@ def add_layer(self, layer_index, layer_name, param_names):
segment = self._layer_to_segment[layer_index]
segment.add_layer(layer_name, param_names)

def build_name_mapping(self):
def build_name_mapping(self, sname_to_tname=None):
    """Build the param-name -> (tensor name, renamed tensor name) mapping.

    Walks every layer of every segment in this stage and asks the rename
    manager for the layer-scoped new name of each parameter tensor.

    Args:
        sname_to_tname: optional dict overriding the target name for specific
            params; matches are recorded in ``self._wname_to_rname`` so that
            ``map_name`` returns the override instead of the generated name.
    """
    # Only the values are needed; iterating .values() instead of .items()
    # drops the unused keys of both ordered dicts.
    for segment in self._segments.values():
        for layer in segment.layers.values():
            for param_name, tensor_name in layer.params.items():
                # Consecutive renaming within the layer's name scope.
                n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
                # Merged the nested ifs; ``in dict`` replaces ``in dict.keys()``.
                if sname_to_tname is not None and param_name in sname_to_tname:
                    self._wname_to_rname[param_name] = sname_to_tname[param_name]
                self._param_to_tname[param_name] = (tensor_name, n_name)

def map_name(self, param_name, t_name):
    """Return the destination name for *param_name*, verifying its tensor name.

    ``self._wname_to_rname`` overrides (if present) take precedence over the
    generated name stored in ``self._param_to_tname``.
    """
    assert param_name in self._param_to_tname
    expected_tname, generated_name = self._param_to_tname[param_name]
    assert expected_tname == t_name
    # Prefer an externally supplied rename when one was registered.
    return self._wname_to_rname.get(param_name, generated_name)

Expand Down Expand Up @@ -261,6 +280,8 @@ def __init__(
self._index_layers()

stage_segments = self._segment()
self._sname_to_tname = get_sname_to_tname_func()(pp_model)

Check warning on line 283 in paddlenlp/trainer/utils/reshard/pp_reshard.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/pp_reshard.py#L283

Added line #L283 was not covered by tests

for (i, stage_seg) in enumerate(stage_segments):
pipe_stage = PipeLineStage()
self._stages.append(pipe_stage)
Expand All @@ -275,7 +296,7 @@ def __init__(
self._layer_name_to_stage[layer_name] = i

for stage in self._stages:
stage.build_name_mapping()
stage.build_name_mapping(self._sname_to_tname)

Check warning on line 299 in paddlenlp/trainer/utils/reshard/pp_reshard.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/pp_reshard.py#L299

Added line #L299 was not covered by tests

def _index_layers(self):
for layer_name in self._param_names_by_layer.keys():
Expand Down

0 comments on commit 0c972d7

Please sign in to comment.