diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 1fb8f397e0..b30bc1eab4 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -103,21 +103,27 @@ endif()
 
 set(OPTIMIZERS
     adagrad
-    adam
-    approx_rowwise_adagrad
-    approx_rowwise_adagrad_with_weight_decay
-    approx_rowwise_adagrad_with_counter
     approx_sgd
-    lamb
-    lars_sgd
-    partial_rowwise_adam
-    partial_rowwise_lamb
     rowwise_adagrad
-    rowwise_adagrad_with_weight_decay
     rowwise_adagrad_with_counter
     rowwise_weighted_adagrad
     sgd)
 
+set(CPU_ONLY_OPTIMIZERS
+    approx_rowwise_adagrad
+    approx_rowwise_adagrad_with_counter)
+
+set(GPU_ONLY_OPTIMIZERS
+    adam
+    lamb
+    lars_sgd
+    partial_rowwise_adam
+    partial_rowwise_lamb)
+
+set(CPU_OPTIMIZERS ${OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS})
+set(GPU_OPTIMIZERS ${OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
+set(ALL_OPTIMIZERS ${OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
+
 set(gen_gpu_source_files
     "gen_embedding_forward_dense_weighted_codegen_cuda.cu"
     "gen_embedding_forward_dense_unweighted_codegen_cuda.cu"
@@ -137,16 +143,16 @@ set(gen_cpu_source_files
 
 set(gen_python_files ${CMAKE_BINARY_DIR}/__init__.py)
 
-foreach(optimizer ${OPTIMIZERS})
-  list(APPEND gen_gpu_host_source_files
-       "gen_embedding_backward_split_${optimizer}.cpp")
-
+foreach(optimizer ${CPU_OPTIMIZERS})
   list(APPEND gen_cpu_source_files
       "gen_embedding_backward_split_${optimizer}_cpu.cpp")
   list(APPEND gen_cpu_source_files
      "gen_embedding_backward_${optimizer}_split_cpu.cpp")
+endforeach()
 
-  list(APPEND gen_python_files "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py")
+foreach(optimizer ${GPU_OPTIMIZERS})
+  list(APPEND gen_gpu_host_source_files
+       "gen_embedding_backward_split_${optimizer}.cpp")
 
   foreach(weight weighted unweighted)
     list(APPEND gen_gpu_source_files
@@ -154,6 +160,10 @@ foreach(optimizer ${OPTIMIZERS})
   endforeach()
 endforeach()
 
+foreach(optimizer ${ALL_OPTIMIZERS})
+  list(APPEND gen_python_files "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py")
+endforeach()
+
 set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen)
 
 set(codegen_dependencies
diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
index fd69a22f6e..aa832947c3 100644
--- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py
+++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -127,53 +127,60 @@ def int_arg(name: str, default: int = 0) -> str:
 
 def generate(**kwargs: Any) -> None:
     gen_args = kwargs["args"]
 
-    # Generates CUDA variants.
     kwargs["args"] = gen_args["cuda"]
 
+    if kwargs.get("has_gpu_support"):
+        # Generates CUDA variants.
+        template = env.get_template("embedding_backward_split_template.cu")
+        src_cu = template.render(weighted=False, **kwargs)
+        write(
+            f"gen_embedding_backward_{kwargs.get('optimizer')}_split_unweighted_cuda.cu",
+            src_cu,
+        )
+        src_cu = template.render(weighted=True, **kwargs)
+        write(
+            f"gen_embedding_backward_{kwargs.get('optimizer')}_split_weighted_cuda.cu",
+            src_cu,
+        )
+        if not kwargs.get("dense"):
+            template = env.get_template("embedding_backward_split_host_template.cpp")
+            src_cpp = template.render(**kwargs)
+            write(
+                f"gen_embedding_backward_split_{kwargs.get('optimizer')}.cpp", src_cpp
+            )
-    template = env.get_template("embedding_backward_split_template.cu")
-    src_cu = template.render(weighted=False, **kwargs)
-    write(
-        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_unweighted_cuda.cu",
-        src_cu,
-    )
-    src_cu = template.render(weighted=True, **kwargs)
-    write(
-        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_weighted_cuda.cu",
-        src_cu,
-    )
 
     if not kwargs.get("dense"):
-        template = env.get_template("embedding_backward_split_host_template.cpp")
-        src_cpp = template.render(**kwargs)
-        write(f"gen_embedding_backward_split_{kwargs.get('optimizer')}.cpp", src_cpp)
-
         # Generates Python invoker for CUDA + CPU
         template = env.get_template("split_embedding_codegen_lookup_invoker.template")
         src_py = template.render(is_fbcode=args.is_fbcode, **kwargs)
         write(f"lookup_{kwargs.get('optimizer')}.py", src_py)
 
-    # Generates CPU variants.
-    kwargs["args"] = gen_args["cpu"]
+    if kwargs.get("has_cpu_support"):
+        # Generates CPU variants.
+        kwargs["args"] = gen_args["cpu"]
 
-    is_approx = "approx" in kwargs.get("optimizer")
-    template = (
-        env.get_template("embedding_backward_split_cpu_approx_template.cpp")
-        if is_approx
-        else env.get_template("embedding_backward_split_cpu_template.cpp")
-    )
-
-    src_cpp = template.render(**kwargs)
-    write(
-        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_cpu.cpp",
-        src_cpp,
-    )
+        is_approx = "approx" in kwargs.get("optimizer")
+        template = (
+            env.get_template("embedding_backward_split_cpu_approx_template.cpp")
+            if is_approx
+            else env.get_template("embedding_backward_split_cpu_template.cpp")
+        )
 
-    if not kwargs.get("dense"):
-        template = env.get_template("embedding_backward_split_host_cpu_template.cpp")
         src_cpp = template.render(**kwargs)
         write(
-            f"gen_embedding_backward_split_{kwargs.get('optimizer')}_cpu.cpp", src_cpp
+            f"gen_embedding_backward_{kwargs.get('optimizer')}_split_cpu.cpp",
+            src_cpp,
         )
 
+        if not kwargs.get("dense"):
+            template = env.get_template(
+                "embedding_backward_split_host_cpu_template.cpp"
+            )
+            src_cpp = template.render(**kwargs)
+            write(
+                f"gen_embedding_backward_split_{kwargs.get('optimizer')}_cpu.cpp",
+                src_cpp,
+            )
+
 
 @dataclass
 class Args:
@@ -369,6 +376,8 @@ def adagrad() -> None:
         split_precomputation="",
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
 
@@ -490,6 +499,8 @@ def rowwise_adagrad() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
     approx_split_weight_update = """
@@ -512,6 +523,8 @@ def rowwise_adagrad() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=approx_split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=False,
    )
 
 
@@ -611,6 +624,9 @@ def rowwise_adagrad_with_weight_decay() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        # Disable both CPU and GPU support
+        has_cpu_support=False,
+        has_gpu_support=False,
     )
 
     approx_split_weight_update = """
@@ -633,6 +649,9 @@ def rowwise_adagrad_with_weight_decay() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=approx_split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        # Disable both CPU and GPU support
+        has_cpu_support=False,
+        has_gpu_support=False,
     )
 
 
@@ -771,6 +790,8 @@ def rowwise_adagrad_with_counter() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
     approx_split_weight_update = """
@@ -804,6 +825,8 @@ def rowwise_adagrad_with_counter() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=approx_split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=False,
     )
 
 
@@ -874,6 +897,8 @@ def rowwise_weighted_adagrad() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
 
@@ -893,6 +918,8 @@ def sgd() -> None:
         split_precomputation="",
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
     approx_split_weight_update = """
@@ -908,6 +935,8 @@ def sgd() -> None:
         split_precomputation="",
         split_weight_update=approx_split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
 
@@ -978,6 +1007,8 @@ def lamb() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=False,
+        has_gpu_support=True,
     )
 
 
@@ -1064,6 +1095,8 @@ def partial_rowwise_lamb() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=False,
+        has_gpu_support=True,
     )
 
 
@@ -1114,6 +1147,8 @@ def adam() -> None:
         split_precomputation="",
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=False,
+        has_gpu_support=True,
     )
 
 
@@ -1174,6 +1209,8 @@ def partial_rowwise_adam() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=False,
+        has_gpu_support=True,
     )
 
 
@@ -1232,6 +1269,8 @@ def lars_sgd() -> None:
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
         split_weight_update_cpu=split_weight_update_cpu,
+        has_cpu_support=False,
+        has_gpu_support=True,
     )
 
 
@@ -1296,6 +1335,8 @@ def backward_dense() -> None:
                 (FLOAT, "unused"),
             ]
         ),
+        has_cpu_support=True,
+        has_gpu_support=True,
     )
 
 
@@ -1323,7 +1364,7 @@ def emb_codegen(
     partial_rowwise_adam()
    partial_rowwise_lamb()
     rowwise_adagrad()
-    rowwise_adagrad_with_weight_decay()
+    # rowwise_adagrad_with_weight_decay()  # Disabled
     rowwise_adagrad_with_counter()
     rowwise_weighted_adagrad()
     sgd()
diff --git a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template
index bd406d39fa..844f04782b 100644
--- a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template
@@ -49,6 +49,7 @@ def invoke(
     max_counter: float,
     {% endif %}
 ) -> torch.Tensor:
+    {% if has_cpu_support %}
     if (common_args.host_weights.numel() > 0):
         return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
             # common_args
@@ -147,112 +148,119 @@ def invoke(
             max_counter=max_counter,
             {% endif %}
         )
+    {% if not has_gpu_support %}
     else:
-        return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function(
-            # common_args
-            {% if not dense %}
-            placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-            {% endif %}
-            dev_weights=common_args.dev_weights,
-            uvm_weights=common_args.uvm_weights,
-            lxu_cache_weights=common_args.lxu_cache_weights,
-            weights_placements=common_args.weights_placements,
-            weights_offsets=common_args.weights_offsets,
-            D_offsets=common_args.D_offsets,
-            total_D=common_args.total_D,
-            max_D=common_args.max_D,
-            hash_size_cumsum=common_args.hash_size_cumsum,
-            total_hash_size_bits=common_args.total_hash_size_bits,
-            indices=common_args.indices,
-            offsets=common_args.offsets,
-            pooling_mode=common_args.pooling_mode,
-            indice_weights=common_args.indice_weights,
-            feature_requires_grad=common_args.feature_requires_grad,
-            lxu_cache_locations=common_args.lxu_cache_locations,
-            # optimizer_args
-            gradient_clipping = optimizer_args.gradient_clipping,
-            max_gradient=optimizer_args.max_gradient,
-            stochastic_rounding=optimizer_args.stochastic_rounding,
-            {% if "learning_rate" in args.split_function_arg_names %}
-            learning_rate=optimizer_args.learning_rate,
-            {% endif %}
-            {% if "eps" in args.split_function_arg_names %}
-            eps=optimizer_args.eps,
-            {% endif %}
-            {% if "beta1" in args.split_function_arg_names %}
-            beta1=optimizer_args.beta1,
-            {% endif %}
-            {% if "beta2" in args.split_function_arg_names %}
-            beta2=optimizer_args.beta2,
-            {% endif %}
-            {% if "weight_decay" in args.split_function_arg_names %}
-            weight_decay=optimizer_args.weight_decay,
-            {% endif %}
-            {% if "weight_decay_mode" in args.split_function_arg_names %}
-            weight_decay_mode=optimizer_args.weight_decay_mode,
-            {% endif %}
-            {% if "eta" in args.split_function_arg_names %}
-            eta=optimizer_args.eta,
-            {% endif %}
-            {% if "momentum" in args.split_function_arg_names %}
-            momentum=optimizer_args.momentum,
-            {% endif %}
-            {% if "counter_halflife" in args.split_function_arg_names %}
-            counter_halflife=optimizer_args.counter_halflife,
-            {% endif %}
-            {% if "adjustment_iter" in args.split_function_arg_names %}
-            adjustment_iter=optimizer_args.adjustment_iter,
-            {% endif %}
-            {% if "adjustment_ub" in args.split_function_arg_names %}
-            adjustment_ub=optimizer_args.adjustment_ub,
-            {% endif %}
-            {% if "learning_rate_mode" in args.split_function_arg_names %}
-            learning_rate_mode=optimizer_args.learning_rate_mode,
-            {% endif %}
-            {% if "grad_sum_decay" in args.split_function_arg_names %}
-            grad_sum_decay=optimizer_args.grad_sum_decay,
-            {% endif %}
-            {% if "tail_id_threshold" in args.split_function_arg_names %}
-            tail_id_threshold=optimizer_args.tail_id_threshold,
-            {% endif %}
-            {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
-            is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
-            {% endif %}
-            # momentum1
-            {% if "momentum1_dev" in args.split_function_arg_names %}
-            momentum1_dev=momentum1.dev,
-            momentum1_uvm=momentum1.uvm,
-            momentum1_offsets=momentum1.offsets,
-            momentum1_placements=momentum1.placements,
-            {% endif %}
-            # momentum2
-            {% if "momentum2_dev" in args.split_function_arg_names %}
-            momentum2_dev=momentum2.dev,
-            momentum2_uvm=momentum2.uvm,
-            momentum2_offsets=momentum2.offsets,
-            momentum2_placements=momentum2.placements,
-            {% endif %}
-            # prev_iter
-            {% if "prev_iter_dev" in args.split_function_arg_names %}
-            prev_iter_dev=prev_iter.dev,
-            prev_iter_uvm=prev_iter.uvm,
-            prev_iter_offsets=prev_iter.offsets,
-            prev_iter_placements=prev_iter.placements,
-            {% endif %}
-            # row_counter
-            {% if "row_counter_dev" in args.split_function_arg_names %}
-            row_counter_dev=row_counter.dev,
-            row_counter_uvm=row_counter.uvm,
-            row_counter_offsets=row_counter.offsets,
-            row_counter_placements=row_counter.placements,
-            {% endif %}
-            # iter
-            {% if "iter" in args.split_function_arg_names %}
-            iter=iter,
-            {% endif %}
-            # max counter
-            {% if "max_counter" in args.split_function_arg_names %}
-            max_counter=max_counter,
-            {% endif %}
-            output_dtype=common_args.output_dtype,
-        )
+        assert False, "{{ optimizer }} has only CPU support. host_weights.numel() must be greater than 0."
+    {% endif %}
+    {% endif %}
+
+    {% if has_gpu_support %}
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function(
+        # common_args
+        {% if not dense %}
+        placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
+        {% endif %}
+        dev_weights=common_args.dev_weights,
+        uvm_weights=common_args.uvm_weights,
+        lxu_cache_weights=common_args.lxu_cache_weights,
+        weights_placements=common_args.weights_placements,
+        weights_offsets=common_args.weights_offsets,
+        D_offsets=common_args.D_offsets,
+        total_D=common_args.total_D,
+        max_D=common_args.max_D,
+        hash_size_cumsum=common_args.hash_size_cumsum,
+        total_hash_size_bits=common_args.total_hash_size_bits,
+        indices=common_args.indices,
+        offsets=common_args.offsets,
+        pooling_mode=common_args.pooling_mode,
+        indice_weights=common_args.indice_weights,
+        feature_requires_grad=common_args.feature_requires_grad,
+        lxu_cache_locations=common_args.lxu_cache_locations,
+        # optimizer_args
+        gradient_clipping = optimizer_args.gradient_clipping,
+        max_gradient=optimizer_args.max_gradient,
+        stochastic_rounding=optimizer_args.stochastic_rounding,
+        {% if "learning_rate" in args.split_function_arg_names %}
+        learning_rate=optimizer_args.learning_rate,
+        {% endif %}
+        {% if "eps" in args.split_function_arg_names %}
+        eps=optimizer_args.eps,
+        {% endif %}
+        {% if "beta1" in args.split_function_arg_names %}
+        beta1=optimizer_args.beta1,
+        {% endif %}
+        {% if "beta2" in args.split_function_arg_names %}
+        beta2=optimizer_args.beta2,
+        {% endif %}
+        {% if "weight_decay" in args.split_function_arg_names %}
+        weight_decay=optimizer_args.weight_decay,
+        {% endif %}
+        {% if "weight_decay_mode" in args.split_function_arg_names %}
+        weight_decay_mode=optimizer_args.weight_decay_mode,
+        {% endif %}
+        {% if "eta" in args.split_function_arg_names %}
+        eta=optimizer_args.eta,
+        {% endif %}
+        {% if "momentum" in args.split_function_arg_names %}
+        momentum=optimizer_args.momentum,
+        {% endif %}
+        {% if "counter_halflife" in args.split_function_arg_names %}
+        counter_halflife=optimizer_args.counter_halflife,
+        {% endif %}
+        {% if "adjustment_iter" in args.split_function_arg_names %}
+        adjustment_iter=optimizer_args.adjustment_iter,
+        {% endif %}
+        {% if "adjustment_ub" in args.split_function_arg_names %}
+        adjustment_ub=optimizer_args.adjustment_ub,
+        {% endif %}
+        {% if "learning_rate_mode" in args.split_function_arg_names %}
+        learning_rate_mode=optimizer_args.learning_rate_mode,
+        {% endif %}
+        {% if "grad_sum_decay" in args.split_function_arg_names %}
+        grad_sum_decay=optimizer_args.grad_sum_decay,
+        {% endif %}
+        {% if "tail_id_threshold" in args.split_function_arg_names %}
+        tail_id_threshold=optimizer_args.tail_id_threshold,
+        {% endif %}
+        {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
+        is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
+        {% endif %}
+        # momentum1
+        {% if "momentum1_dev" in args.split_function_arg_names %}
+        momentum1_dev=momentum1.dev,
+        momentum1_uvm=momentum1.uvm,
+        momentum1_offsets=momentum1.offsets,
+        momentum1_placements=momentum1.placements,
+        {% endif %}
+        # momentum2
+        {% if "momentum2_dev" in args.split_function_arg_names %}
+        momentum2_dev=momentum2.dev,
+        momentum2_uvm=momentum2.uvm,
+        momentum2_offsets=momentum2.offsets,
+        momentum2_placements=momentum2.placements,
+        {% endif %}
+        # prev_iter
+        {% if "prev_iter_dev" in args.split_function_arg_names %}
+        prev_iter_dev=prev_iter.dev,
+        prev_iter_uvm=prev_iter.uvm,
+        prev_iter_offsets=prev_iter.offsets,
+        prev_iter_placements=prev_iter.placements,
+        {% endif %}
+        # row_counter
+        {% if "row_counter_dev" in args.split_function_arg_names %}
+        row_counter_dev=row_counter.dev,
+        row_counter_uvm=row_counter.uvm,
+        row_counter_offsets=row_counter.offsets,
+        row_counter_placements=row_counter.placements,
+        {% endif %}
+        # iter
+        {% if "iter" in args.split_function_arg_names %}
+        iter=iter,
+        {% endif %}
+        # max counter
+        {% if "max_counter" in args.split_function_arg_names %}
+        max_counter=max_counter,
+        {% endif %}
+        output_dtype=common_args.output_dtype,
+    )
+    {% endif %}
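---

For illustration, a minimal hand-written sketch (not generated output) of the dispatch logic that split_embedding_codegen_lookup_invoker.template now renders for a CPU-only optimizer such as approx_rowwise_adagrad (has_cpu_support=True, has_gpu_support=False). The full common_args/optimizer_args plumbing shown in the hunks above is elided:

    import torch

    def invoke(common_args, optimizer_args, momentum1) -> torch.Tensor:
        # Rendered because has_cpu_support is True: host weights -> CPU op.
        if common_args.host_weights.numel() > 0:
            return torch.ops.fbgemm.split_embedding_codegen_lookup_approx_rowwise_adagrad_function_cpu(
                # ... full common_args / optimizer_args / momentum1 fields ...
            )
        # Rendered because has_gpu_support is False: no CUDA kernel was
        # generated for this optimizer, so empty host weights is a hard error.
        else:
            assert False, "approx_rowwise_adagrad has only CPU support. host_weights.numel() must be greater than 0."

For an optimizer with has_gpu_support=True the error branch is not rendered; control instead falls through to torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function. Conversely, the outer {% if has_cpu_support %} guard keeps GPU-only optimizers (adam, lamb, lars_sgd, partial_rowwise_adam, partial_rowwise_lamb) from emitting a CPU branch at all.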