【AutoParallelism】Add refined_ops_patterns in refined-recompute (Paddle#58533)

* add refined-recompute support

* fix bug in recompute_pass

* fix coverage

* add ops_pattern

* add testcase

* add sr
heavyrain-lzy authored and zeroRains committed Nov 8, 2023
1 parent 804f0b6 commit 95b6a85
Showing 6 changed files with 195 additions and 12 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -23,11 +23,19 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server
}

message RefinedOpsPattern {
repeated string main_ops = 1;
optional int32 num = 2 [default = 0];
repeated string pre_ops = 3;
repeated string suf_ops = 4;
}

message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel
repeated RefinedOpsPattern refined_ops_patterns = 5;
}

message ShardingConfig {
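For reference, a minimal sketch of how these messages can be populated directly through the generated protobuf bindings. The pb2 import path is an assumption (not part of this commit); in normal use the messages are filled in from a plain dict via DistributedStrategy.recompute_configs rather than by hand.

# Hedged sketch: module path assumed, not taken from this commit.
from paddle.distributed.fleet.proto import distributed_strategy_pb2 as ds_pb2

config = ds_pb2.RecomputeConfig()
config.checkpoints.extend(["tmp_3", "tmp_6"])

# One RefinedOpsPattern: exclude the first matmul_v2 + elementwise_add pair from recompute.
pattern = config.refined_ops_patterns.add()
pattern.main_ops.extend(["matmul_v2", "elementwise_add"])
pattern.num = 1  # number of matches to exclude; the Python pass treats -1 as "all matches"
# pre_ops / suf_ops left empty: the main_ops sequence may match anywhere.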
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -56,6 +56,8 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "sr", 0)
set_field_default_config(RECOMPUTE, "refined_ops_patterns", []) # List[Dict]
set_field_default_config(RECOMPUTE, "enable_tuning", False)

#########################################
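These defaults surface as fields of the recompute group on the auto-parallel Strategy object. A hedged sketch of setting them there (the attribute names mirror the fields registered above; the engine wiring in the last line is assumed, not taken from this commit):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.recompute.enable = True
strategy.recompute.sr = 1  # number of leading chunks per stage excluded from refined-recompute (see the recompute pass below)
strategy.recompute.refined_ops_patterns = [
    {"main_ops": ["matmul_v2", "elementwise_add"], "num": -1, "pre_ops": [], "suf_ops": []}
]
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)  # assumed usage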
58 changes: 56 additions & 2 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -41,6 +41,23 @@ def __impl__(*args, **kwargs):
is_strict_auto = wrap_decorator(__non_auto_func_called__)


def get_repeated_msg_dict(msg):
res_list = []
for item in msg:
fields = item.DESCRIPTOR.fields
res_dict = {}
for f in fields:
v = getattr(item, f.name)
if (
f.label
== google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED
):
v = list(v)
res_dict[f.name] = v
res_list.append(res_dict)
return res_list


def get_msg_dict(msg):
res_dict = {}
fields = msg.DESCRIPTOR.fields
@@ -52,11 +69,40 @@ def get_msg_dict(msg):
# I guess the type or value of protobuf item is NULL when
# dealloc.
if f.label == google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED:
v = list(v)
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
v = list(v)
else:
v = get_repeated_msg_dict(v)
res_dict[f.name] = v
return res_dict


def assign_repeated_msg(msg, config):
for key in config:
new_item = msg.add()
fields = new_item.DESCRIPTOR.fields
for f in fields:
if key == f.name:
# LABEL_OPTIONAL = 1
# LABEL_REPEATED = 3
# LABEL_REQUIRED = 2
if f.label == 3:
if config[f.name] is not None:
new_item = getattr(msg, f.name)
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
new_item.extend(config[f.name])
else:
assign_configs_value(new_item, config[f.name])
elif f.label == 1 or f.label == 2:
setattr(msg, f.name, config[f.name])


def assign_configs_value(msg, config):
fields = msg.DESCRIPTOR.fields
for key in config:
@@ -67,7 +113,15 @@ def assign_configs_value(msg, config):
# LABEL_REQUIRED = 2
if f.label == 3:
if config[f.name] is not None:
getattr(msg, f.name).extend(config[f.name])
new_item = getattr(msg, f.name)
# deal with repeated message
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
new_item.extend(config[f.name])
else:
assign_repeated_msg(new_item, config[f.name])
elif f.label == 1 or f.label == 2:
setattr(msg, f.name, config[f.name])

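A minimal sketch of the round trip these helpers enable through the public property: the setter expands the nested dicts into repeated RefinedOpsPattern messages via assign_configs_value/assign_repeated_msg, and the getter rebuilds them as a list of dicts via get_msg_dict/get_repeated_msg_dict.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {
    "checkpoints": ["tmp_3", "tmp_6"],
    "refined_ops_patterns": [
        {"main_ops": ["matmul_v2", "elementwise_add"], "num": 1, "pre_ops": [], "suf_ops": []}
    ],
}

# The getter now returns the nested patterns as a list of dicts again, e.g.
# [{'main_ops': ['matmul_v2', 'elementwise_add'], 'num': 1, 'pre_ops': [], 'suf_ops': []}]
print(strategy.recompute_configs["refined_ops_patterns"])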
119 changes: 110 additions & 9 deletions python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -37,8 +37,11 @@
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from ..utils.log_utils import get_logger
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO)


class RecomputeState(ProgramStats):
def __init__(self, block, ops):
@@ -282,54 +285,152 @@ def _check_self(self):
def _check_conflict(self, other_pass):
return True

def get_ops_per_device(self, ops, all_ops_process_meshs, sr=0):
"""
Get ops and op_names of each process mesh excluding ops within the first "sr" chunks
"""

def reset_recompute_op(op):
if is_recompute_op(op) or is_recompute_exclude_op(op):
op._set_attr("op_namescope", "")

all_process_meshes_count = len(all_ops_process_meshs)
ops_of_stages = [[] for _ in range(all_process_meshes_count)]
op_names_of_stages = [[] for _ in range(all_process_meshes_count)]
pushed_ops_count = 0
reset_ops_count = 0
chunk_id = 0
for op_id, op in enumerate(ops):
if chunk_id // all_process_meshes_count < sr:
reset_ops_count += 1
reset_recompute_op(op)
if (
op_id < len(ops) - 1
and op.dist_attr.process_mesh
!= ops[op_id + 1].dist_attr.process_mesh
):
chunk_id += 1
if chunk_id // all_process_meshes_count < sr:
continue

for id, process_mesh in enumerate(all_ops_process_meshs):
if op.dist_attr.process_mesh == process_mesh:
pushed_ops_count += 1
ops_of_stages[id].append(op)
op_names_of_stages[id].append(op.type)
assert (
len(ops) == reset_ops_count + pushed_ops_count
), "The sum of pushed_ops_count and reset_ops_count must be the same as lenght of ops, but the sum is {} while lenght of ops is {}".format(
reset_ops_count + pushed_ops_count, len(ops)
)
return ops_of_stages, op_names_of_stages

def _apply_single_impl(self, main_program, startup_program, context):
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
no_recompute_segments = self.get_attr("no_recompute_segments")
self._dist_context = self.get_attr("dist_context")
self._sr = self.get_attr("sr", 0)
self._refined_ops_patterns = self.get_attr("refined_ops_patterns", [])

# 0. get op_path which is related to loss
main_block = main_program.global_block()
op_path = _find_op_path(main_program, loss, no_grad_set)

# 1. build recompute state
# 1. mark the ops excluded from refined-recompute according to the ops patterns (mainly linear and flash_attn)
# 1.1 get all process_meshs in op_path
all_ops_process_meshs = []
for op in op_path:
if op.dist_attr.process_mesh not in all_ops_process_meshs:
all_ops_process_meshs.append(op.dist_attr.process_mesh)

# 1.2 get ops_devices and op_names_devices
ops_devices, op_names_devices = self.get_ops_per_device(
op_path, all_ops_process_meshs, self._sr
)
all_ops_len = len(op_path)
all_exclude_ops_ids = [[] for _ in op_names_devices]
# 1.3 find the ops excluded from refined-recompute according to the ops patterns
for refined_ops_pattern in self._refined_ops_patterns:
num = refined_ops_pattern['num']
num = (
num if num >= 0 else all_ops_len
) # 'num == -1' means all matched ops
main_ops = refined_ops_pattern['main_ops']
pre_ops = refined_ops_pattern['pre_ops']
suf_ops = refined_ops_pattern['suf_ops']
main_start_id = len(pre_ops)
main_ops_len = len(main_ops)
pattern_ops = pre_ops + main_ops + suf_ops
pattern_ops_len = len(pattern_ops)

for id, op_names_device in enumerate(op_names_devices):
pattern_count = 0
ops_len_device = len(op_names_device)
for i in range(ops_len_device - pattern_ops_len + 1):
if (
op_names_device[i : i + pattern_ops_len] == pattern_ops
and pattern_count < num
):
pattern_count += 1
all_exclude_ops_ids[id].extend(
list(
range(
i + main_start_id,
i + main_start_id + main_ops_len,
)
)
)
logger.info(
f"The excluded ops in recompute segments are:\n{all_exclude_ops_ids}"
)
# 1.4 mark exclude ops in exclude_ops_ids
for id, exclude_ops_ids in enumerate(all_exclude_ops_ids):
for op_id in exclude_ops_ids:
if is_recompute_op(ops_devices[id][op_id]):
rc_mark_str = ops_devices[id][op_id].attr("op_namescope")
ops_devices[id][op_id]._set_attr(
"op_namescope", rc_mark_str + "_exclude_rc"
)

# 2. build recompute state
rc_state = RecomputeState(main_block, op_path)
if not rc_state.is_recompute():
return

# 2. get the segments to be recomputed
# 3. get the segments to be recomputed
rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_states()
segments = rc_state.get_recompute_segments(no_recompute_segments)
if segments == []:
return

for i, (idx1, idx2) in enumerate(segments):
logging.info(f"recompute segment[{i + 1}/{len(segments)}]")
logging.info(
logger.debug(f"recompute segment[{i + 1}/{len(segments)}]")
logger.debug(
"segment start op: [{}]: [{}] [{}]".format(
rc_state.ops[idx1].type,
rc_state.ops[idx1].input_arg_names,
rc_state.ops[idx1].output_arg_names,
)
)
logging.info(
logger.debug(
"segment end op: [{}]: [{}] [{}]".format(
rc_state.ops[idx2 - 1].type,
rc_state.ops[idx2 - 1].input_arg_names,
rc_state.ops[idx2 - 1].output_arg_names,
)
)

# 3. get vars that should be hold in memory
# 4. get vars that should be hold in memory
# list of var_names
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints)
logging.info(
logger.debug(
"found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars
Expand All @@ -341,7 +442,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
set(vars_should_be_hold) | set(rc_state.checkpoints)
)

# 4. get the fwd ops desc to be recomputed.
# 5. get the fwd ops desc to be recomputed.
var_name_dict = {} # varname --> varname.subprog_XXX
ckpt_ops_dict = {} # ckpt_op_id --> segment_descs
buffer_block = main_block.program._create_block()
@@ -412,7 +513,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
ckpt_op = op_path[segment[1] - 1]
ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

# 5. insert recomputed fwd ops into backward parse
# 6. insert recomputed fwd ops into backward parse
ops = main_block.ops
loss_op = get_loss_op(main_block)
loss_op_idx = _find_op_index(main_block, loss_op)
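To make step 1.3 concrete, here is a standalone toy reimplementation of the sliding-window match (illustrative only, not the pass itself): given one device's op-type sequence and a single pattern, it returns the indices of the main_ops occurrences that would be marked with "_exclude_rc".

def find_exclude_ids(op_names, pattern):
    """Toy version of the refined-recompute pattern match in step 1.3."""
    pre, main, suf = pattern["pre_ops"], pattern["main_ops"], pattern["suf_ops"]
    limit = pattern["num"] if pattern["num"] >= 0 else len(op_names)  # num == -1: no limit
    seq = pre + main + suf
    exclude_ids, matches = [], 0
    for i in range(len(op_names) - len(seq) + 1):
        if op_names[i : i + len(seq)] == seq and matches < limit:
            matches += 1
            # only the main_ops part of the matched window is excluded from recompute
            exclude_ids.extend(range(i + len(pre), i + len(pre) + len(main)))
    return exclude_ids

ops = ["matmul_v2", "elementwise_add", "gelu", "matmul_v2", "elementwise_add", "softmax"]
pattern = {"main_ops": ["matmul_v2", "elementwise_add"], "num": 1, "pre_ops": [], "suf_ops": ["gelu"]}
print(find_exclude_ids(ops, pattern))  # [0, 1] -- only the pair followed by gelu, at most once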
8 changes: 8 additions & 0 deletions python/paddle/distributed/passes/pipeline_pass_base.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import paddle
from paddle.base import core

from ..utils.log_utils import get_logger
from .pass_base import PassBase
from .pass_utils import set_skip_gc_vars

logger = get_logger(logging.INFO)


class PipelinePassBase(PassBase):
def __init__(self):
@@ -54,6 +58,10 @@ def _apply_single_impl(self, main_program, startup_program, context):
to implement two interfaces above, 'create_job_list' and 'partial_programs'.
"""
job_types, sub_programs = self._partial_programs(main_program)
for i in range(len(job_types)):
logger.debug(
f"sub_program type: {job_types[i]}, sum_program:\n{sub_programs[i]}"
)
jobs = self._create_job_list()

type_to_program = set_skip_gc_vars(
12 changes: 11 additions & 1 deletion test/distributed_passes/test_auto_parallel_recompute_pass.py
@@ -37,7 +37,17 @@ def init(self):
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.recompute = True
dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
dist_strategy.recompute_configs = {
"checkpoints": ["tmp_3", "tmp_6"],
"refined_ops_patterns": [
{
"main_ops": ["matmul_v2", "elementwise_add"],
"num": -1,
"pre_ops": [],
"suf_ops": [],
}
],
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
