【AutoParallelism】Add refined_ops_patterns in refined-recompute #58533

Merged · 8 commits · Nov 2, 2023
8 changes: 8 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -23,11 +23,19 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server
}

message RefinedOpsPattern {
repeated string main_ops = 1;
optional int32 num = 2 [default = 0];
repeated string pre_ops = 3;
repeated string suf_ops = 4;
}

message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel
repeated RefinedOpsPattern refined_ops_patterns = 5;
}

message ShardingConfig {
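For reference, each entry of recompute_configs["refined_ops_patterns"] maps one-to-one onto the new RefinedOpsPattern message above. A minimal sketch of one such entry (the op names are illustrative; the semantics follow the recompute pass further down: main_ops are the op types kept out of recompute, pre_ops and suf_ops are the surrounding op types the pattern must match, and num caps the number of matches per device, with -1 meaning no limit):

# Hypothetical pattern entry; keys mirror the RefinedOpsPattern proto fields.
refined_ops_pattern = {
    "main_ops": ["matmul_v2", "elementwise_add"],  # op types excluded from recompute
    "pre_ops": [],  # op types required immediately before main_ops
    "suf_ops": [],  # op types required immediately after main_ops
    "num": -1,      # maximum number of matches per device; -1 means unlimited
}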
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -56,6 +56,8 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "sr", 0)
set_field_default_config(RECOMPUTE, "refined_ops_patterns", []) # List[Dict]
set_field_default_config(RECOMPUTE, "enable_tuning", False)

#########################################
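The two defaults registered above surface the same knobs on the auto-parallel strategy. A hedged sketch of how they might be set, assuming paddle.distributed.fleet.auto.Strategy exposes the recompute fields by attribute as it does for the existing ones:

# Sketch only: assumes auto.Strategy surfaces the defaults registered above.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.recompute.enable = True
# "sr" skips the first chunks on each device; their recompute marks are reset
# by the pass (see get_ops_per_device in auto_parallel_recompute.py below).
strategy.recompute.sr = 1
strategy.recompute.refined_ops_patterns = [
    {"main_ops": ["matmul_v2", "elementwise_add"], "num": -1, "pre_ops": [], "suf_ops": []}
]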
58 changes: 56 additions & 2 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -41,6 +41,23 @@ def __impl__(*args, **kwargs):
is_strict_auto = wrap_decorator(__non_auto_func_called__)


def get_repeated_msg_dict(msg):
res_list = []
for item in msg:
fields = item.DESCRIPTOR.fields
res_dict = {}
for f in fields:
v = getattr(item, f.name)
if (
f.label
== google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED
):
v = list(v)
res_dict[f.name] = v
res_list.append(res_dict)
return res_list


def get_msg_dict(msg):
res_dict = {}
fields = msg.DESCRIPTOR.fields
@@ -52,11 +69,40 @@ def get_msg_dict(msg):
# I guess the type or value of protobuf item is NULL when
# dealloc.
if f.label == google.protobuf.descriptor.FieldDescriptor.LABEL_REPEATED:
v = list(v)
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
v = list(v)
else:
v = get_repeated_msg_dict(v)
res_dict[f.name] = v
return res_dict


def assign_repeated_msg(msg, config):
for key in config:
new_item = msg.add()
fields = new_item.DESCRIPTOR.fields
for f in fields:
if key == f.name:
# LABEL_OPTIONAL = 1
# LABEL_REPEATED = 3
# LABEL_REQUIRED = 2
if f.label == 3:
if config[f.name] is not None:
new_item = getattr(msg, f.name)
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
new_item.extend(config[f.name])
else:
assign_configs_value(new_item, config[f.name])
elif f.label == 1 or f.label == 2:
setattr(msg, f.name, config[f.name])


def assign_configs_value(msg, config):
fields = msg.DESCRIPTOR.fields
for key in config:
@@ -67,7 +113,15 @@ def assign_configs_value(msg, config):
# LABEL_REQUIRED = 2
if f.label == 3:
if config[f.name] is not None:
getattr(msg, f.name).extend(config[f.name])
new_item = getattr(msg, f.name)
# deal with repeated message
if (
f.type
!= google.protobuf.descriptor.FieldDescriptor.TYPE_MESSAGE
):
new_item.extend(config[f.name])
else:
assign_repeated_msg(new_item, config[f.name])
elif f.label == 1 or f.label == 2:
setattr(msg, f.name, config[f.name])

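The helpers above extend get_msg_dict and assign_configs_value to repeated message fields such as refined_ops_patterns. As a standalone sketch (not Paddle code) of the same descriptor-driven dispatch, valid for any compiled protobuf message: repeated scalars become plain lists, repeated messages become lists of dicts, recursively.

from google.protobuf.descriptor import FieldDescriptor

def message_to_dict(msg):
    """Convert a protobuf message to a dict, recursing into repeated messages."""
    result = {}
    for f in msg.DESCRIPTOR.fields:
        value = getattr(msg, f.name)
        if f.label == FieldDescriptor.LABEL_REPEATED:
            if f.type == FieldDescriptor.TYPE_MESSAGE:
                # repeated message field (e.g. refined_ops_patterns): recurse per item
                value = [message_to_dict(item) for item in value]
            else:
                # repeated scalar field (e.g. checkpoints): materialize as a list
                value = list(value)
        result[f.name] = value
    return result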
119 changes: 110 additions & 9 deletions python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -37,8 +37,11 @@
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from ..utils.log_utils import get_logger
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO)


class RecomputeState(ProgramStats):
def __init__(self, block, ops):
@@ -282,54 +285,152 @@ def _check_self(self):
def _check_conflict(self, other_pass):
return True

def get_ops_per_device(self, ops, all_ops_process_meshs, sr=0):
"""
Get the ops and op names of each process mesh, excluding ops within the first "sr" chunks.
"""

def reset_recompute_op(op):
if is_recompute_op(op) or is_recompute_exclude_op(op):
op._set_attr("op_namescope", "")

all_process_meshes_count = len(all_ops_process_meshs)
ops_of_stages = [[] for _ in range(all_process_meshes_count)]
op_names_of_stages = [[] for _ in range(all_process_meshes_count)]
pushed_ops_count = 0
reset_ops_count = 0
chunk_id = 0
for op_id, op in enumerate(ops):
if chunk_id // all_process_meshes_count < sr:
reset_ops_count += 1
reset_recompute_op(op)
if (
op_id < len(ops) - 1
and op.dist_attr.process_mesh
!= ops[op_id + 1].dist_attr.process_mesh
):
chunk_id += 1
if chunk_id // all_process_meshes_count < sr:
continue

for id, process_mesh in enumerate(all_ops_process_meshs):
if op.dist_attr.process_mesh == process_mesh:
pushed_ops_count += 1
ops_of_stages[id].append(op)
op_names_of_stages[id].append(op.type)
assert (
len(ops) == reset_ops_count + pushed_ops_count
), "The sum of pushed_ops_count and reset_ops_count must be the same as lenght of ops, but the sum is {} while lenght of ops is {}".format(
reset_ops_count + pushed_ops_count, len(ops)
)
return ops_of_stages, op_names_of_stages

def _apply_single_impl(self, main_program, startup_program, context):
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
no_recompute_segments = self.get_attr("no_recompute_segments")
self._dist_context = self.get_attr("dist_context")
self._sr = self.get_attr("sr", 0)
self._refined_ops_patterns = self.get_attr("refined_ops_patterns", [])

# 0. get op_path which is related to loss
main_block = main_program.global_block()
op_path = _find_op_path(main_program, loss, no_grad_set)

# 1. build recompute state
# 1. mark excluded ops for refined-recompute according to the ops patterns (mainly linear and flash_attn)
# 1.1 get all process meshes in op_path
all_ops_process_meshs = []
for op in op_path:
if op.dist_attr.process_mesh not in all_ops_process_meshs:
all_ops_process_meshs.append(op.dist_attr.process_mesh)

# 1.2 get ops_devices and op_names_devices
ops_devices, op_names_devices = self.get_ops_per_device(
op_path, all_ops_process_meshs, self._sr
)
all_ops_len = len(op_path)
all_exclude_ops_ids = [[] for _ in op_names_devices]
# 1.3 find excluded ops for refined-recompute according to the ops patterns
for refined_ops_pattern in self._refined_ops_patterns:
num = refined_ops_pattern['num']
num = (
num if num >= 0 else all_ops_len
) # a negative num (e.g. -1) means no limit on the number of matches
main_ops = refined_ops_pattern['main_ops']
pre_ops = refined_ops_pattern['pre_ops']
suf_ops = refined_ops_pattern['suf_ops']
main_start_id = len(pre_ops)
main_ops_len = len(main_ops)
pattern_ops = pre_ops + main_ops + suf_ops
pattern_ops_len = len(pattern_ops)

for id, op_names_device in enumerate(op_names_devices):
pattern_count = 0
ops_len_device = len(op_names_device)
for i in range(ops_len_device - pattern_ops_len + 1):
if (
op_names_device[i : i + pattern_ops_len] == pattern_ops
and pattern_count < num
):
pattern_count += 1
all_exclude_ops_ids[id].extend(
list(
range(
i + main_start_id,
i + main_start_id + main_ops_len,
)
)
)
logger.info(
f"The excluded ops in recompute segments are:\n{all_exclude_ops_ids}"
)
# 1.4 mark the excluded ops recorded in all_exclude_ops_ids
for id, exclude_ops_ids in enumerate(all_exclude_ops_ids):
for op_id in exclude_ops_ids:
if is_recompute_op(ops_devices[id][op_id]):
rc_mark_str = ops_devices[id][op_id].attr("op_namescope")
ops_devices[id][op_id]._set_attr(
"op_namescope", rc_mark_str + "_exclude_rc"
)

# 2. build recompute state
rc_state = RecomputeState(main_block, op_path)
if not rc_state.is_recompute():
return

# 2. get the segments to be recomputed
# 3. get the segments to be recomputed
rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_states()
segments = rc_state.get_recompute_segments(no_recompute_segments)
if segments == []:
return

for i, (idx1, idx2) in enumerate(segments):
logging.info(f"recompute segment[{i + 1}/{len(segments)}]")
logging.info(
logger.debug(f"recompute segment[{i + 1}/{len(segments)}]")
logger.debug(
"segment start op: [{}]: [{}] [{}]".format(
rc_state.ops[idx1].type,
rc_state.ops[idx1].input_arg_names,
rc_state.ops[idx1].output_arg_names,
)
)
logging.info(
logger.debug(
"segment end op: [{}]: [{}] [{}]".format(
rc_state.ops[idx2 - 1].type,
rc_state.ops[idx2 - 1].input_arg_names,
rc_state.ops[idx2 - 1].output_arg_names,
)
)

# 3. get vars that should be hold in memory
# 4. get vars that should be hold in memory
# list of var_names
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints)
logging.info(
logger.debug(
"found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars
@@ -341,7 +442,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
set(vars_should_be_hold) | set(rc_state.checkpoints)
)

# 4. get the fwd ops desc to be recomputed.
# 5. get the fwd ops desc to be recomputed.
var_name_dict = {} # varname --> varname.subprog_XXX
ckpt_ops_dict = {} # ckpt_op_id --> segment_descs
buffer_block = main_block.program._create_block()
@@ -412,7 +513,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
ckpt_op = op_path[segment[1] - 1]
ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

# 5. insert recomputed fwd ops into backward parse
# 6. insert recomputed fwd ops into backward parse
ops = main_block.ops
loss_op = get_loss_op(main_block)
loss_op_idx = _find_op_index(main_block, loss_op)
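To make step 1.3 of the pass concrete, the following self-contained sketch (separate from the pass itself) shows how a single refined_ops_pattern is searched for in one device's op-type sequence: a sliding window matches pre_ops + main_ops + suf_ops, and only the main_ops positions of at most num matches are marked for exclusion from recompute.

def find_excluded_op_ids(op_types, pattern, num=-1):
    # Sliding-window match of pre_ops + main_ops + suf_ops over op type names.
    pre, main, suf = pattern["pre_ops"], pattern["main_ops"], pattern["suf_ops"]
    full = pre + main + suf
    limit = len(op_types) if num < 0 else num  # num == -1 means no limit
    excluded, count = [], 0
    for i in range(len(op_types) - len(full) + 1):
        if count >= limit:
            break
        if op_types[i : i + len(full)] == full:
            # only the main_ops span is excluded from recompute
            excluded.extend(range(i + len(pre), i + len(pre) + len(main)))
            count += 1
    return excluded

# Example: exclude every matmul_v2 + elementwise_add pair, as in the test below.
op_types = ["matmul_v2", "elementwise_add", "gelu", "matmul_v2", "elementwise_add"]
pattern = {"main_ops": ["matmul_v2", "elementwise_add"], "pre_ops": [], "suf_ops": []}
print(find_excluded_op_ids(op_types, pattern))  # -> [0, 1, 3, 4]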
8 changes: 8 additions & 0 deletions python/paddle/distributed/passes/pipeline_pass_base.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import paddle
from paddle.base import core

from ..utils.log_utils import get_logger
from .pass_base import PassBase
from .pass_utils import set_skip_gc_vars

logger = get_logger(logging.INFO)


class PipelinePassBase(PassBase):
def __init__(self):
@@ -54,6 +58,10 @@ def _apply_single_impl(self, main_program, startup_program, context):
to implement two interfaces above, 'create_job_list' and 'partial_programs'.
"""
job_types, sub_programs = self._partial_programs(main_program)
for i in range(len(job_types)):
logger.debug(
f"sub_program type: {job_types[i]}, sum_program:\n{sub_programs[i]}"
)
jobs = self._create_job_list()

type_to_program = set_skip_gc_vars(
12 changes: 11 additions & 1 deletion test/distributed_passes/test_auto_parallel_recompute_pass.py
@@ -37,7 +37,17 @@ def init(self):
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.recompute = True
dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
dist_strategy.recompute_configs = {
"checkpoints": ["tmp_3", "tmp_6"],
"refined_ops_patterns": [
{
"main_ops": ["matmul_v2", "elementwise_add"],
"num": -1,
"pre_ops": [],
"suf_ops": [],
}
],
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
