Prune CPU/GPU TBE optimizer codegen (pytorch#1659)
Summary:
Pull Request resolved: pytorch#1659

This diff aims to reduce the build time and library size of
`//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops`.

The diff modifies the build target to generate and compile only the
necessary files, since CPU and GPU do not support the same set of
optimizers in `SplitTBE`. (Before this diff, all optimizers were
generated and compiled for both CPU and GPU.)

The following table lists which optimizers are supported on each backend:

| OptimType | Generated optimizer | Supported on CPU | Supported on GPU |
|---|---|---|---|
| EXACT_ADAGRAD | adagrad | x | x |
| EXACT_ROWWISE_ADAGRAD | rowwise_adagrad_with_counter | x | x |
| | rowwise_adagrad | x | x |
| EXACT_ROWWISE_WEIGHTED_ADAGRAD | rowwise_weighted_adagrad | x | x |
| EXACT_SGD | sgd | x | x |
| SGD | approx_sgd | x | x |
| ROWWISE_ADAGRAD | approx_rowwise_adagrad_with_counter | x | |
| | approx_rowwise_adagrad | x | |
| ADAM | adam | | x |
| LAMB | lamb | | x |
| LARS_SGD | lars_sgd | | x |
| PARTIAL_ROWWISE_ADAM | partial_rowwise_adam | | x |
| PARTIAL_ROWWISE_LAMB | partial_rowwise_lamb | | x |
| - | rowwise_adagrad_with_weight_decay | | |
| - | approx_rowwise_adagrad_with_weight_decay | | |
Note: x = supported
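
As a rough illustration of the pruning, here is a minimal Python sketch. It is not part of the diff: the `SUPPORT` mapping is abridged from the table above, and `generate` is a hypothetical stand-in for the real logic in `embedding_backward_code_generator.py`; the printed file names follow the diff below.

```python
# Minimal sketch (not the actual codegen) of how per-optimizer support
# flags prune the generated backward sources.
SUPPORT = {
    # optimizer: (has_cpu_support, has_gpu_support)
    "rowwise_adagrad": (True, True),
    "approx_rowwise_adagrad": (True, False),              # CPU-only
    "adam": (False, True),                                # GPU-only
    "rowwise_adagrad_with_weight_decay": (False, False),  # pruned entirely
}

def generate(optimizer: str, has_cpu_support: bool, has_gpu_support: bool) -> None:
    if has_gpu_support:
        # CUDA kernels plus the GPU host binding
        print(f"gen_embedding_backward_{optimizer}_split_unweighted_cuda.cu")
        print(f"gen_embedding_backward_{optimizer}_split_weighted_cuda.cu")
        print(f"gen_embedding_backward_split_{optimizer}.cpp")
    if has_cpu_support:
        # CPU kernel plus the CPU host binding
        print(f"gen_embedding_backward_{optimizer}_split_cpu.cpp")
        print(f"gen_embedding_backward_split_{optimizer}_cpu.cpp")

for opt, (cpu, gpu) in SUPPORT.items():
    generate(opt, has_cpu_support=cpu, has_gpu_support=gpu)
```

In the actual build, a `lookup_<optimizer>.py` invoker is still generated for every optimizer in `ALL_OPTIMIZERS` (see the CMake change below).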

Reviewed By: jianyuh

Differential Revision: D44326540

fbshipit-source-id: f955cc566de6a2e67cd5014f0398a7370c7bad80
sryap authored and facebook-github-bot committed Mar 24, 2023
1 parent d62b5cf commit 2b60742
Showing 3 changed files with 215 additions and 156 deletions.
38 changes: 24 additions & 14 deletions fbgemm_gpu/CMakeLists.txt
@@ -103,21 +103,27 @@ endif()

set(OPTIMIZERS
adagrad
adam
approx_rowwise_adagrad
approx_rowwise_adagrad_with_weight_decay
approx_rowwise_adagrad_with_counter
approx_sgd
lamb
lars_sgd
partial_rowwise_adam
partial_rowwise_lamb
rowwise_adagrad
rowwise_adagrad_with_weight_decay
rowwise_adagrad_with_counter
rowwise_weighted_adagrad
sgd)

set(CPU_ONLY_OPTIMIZERS
approx_rowwise_adagrad
approx_rowwise_adagrad_with_counter)

set(GPU_ONLY_OPTIMIZERS
adam
lamb
lars_sgd
partial_rowwise_adam
partial_rowwise_lamb)

set(CPU_OPTIMIZERS ${OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS})
set(GPU_OPTIMIZERS ${OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
set(ALL_OPTIMIZERS ${OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})

set(gen_gpu_source_files
"gen_embedding_forward_dense_weighted_codegen_cuda.cu"
"gen_embedding_forward_dense_unweighted_codegen_cuda.cu"
@@ -137,23 +143,27 @@ set(gen_cpu_source_files

set(gen_python_files ${CMAKE_BINARY_DIR}/__init__.py)

foreach(optimizer ${OPTIMIZERS})
list(APPEND gen_gpu_host_source_files
"gen_embedding_backward_split_${optimizer}.cpp")

foreach(optimizer ${CPU_OPTIMIZERS})
list(APPEND gen_cpu_source_files
"gen_embedding_backward_split_${optimizer}_cpu.cpp")
list(APPEND gen_cpu_source_files
"gen_embedding_backward_${optimizer}_split_cpu.cpp")
endforeach()

list(APPEND gen_python_files "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py")
foreach(optimizer ${GPU_OPTIMIZERS})
list(APPEND gen_gpu_host_source_files
"gen_embedding_backward_split_${optimizer}.cpp")

foreach(weight weighted unweighted)
list(APPEND gen_gpu_source_files
"gen_embedding_backward_${optimizer}_split_${weight}_cuda.cu")
endforeach()
endforeach()

foreach(optimizer ${ALL_OPTIMIZERS})
list(APPEND gen_python_files "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py")
endforeach()

set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen)

set(codegen_dependencies
109 changes: 75 additions & 34 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -127,53 +127,60 @@ def int_arg(name: str, default: int = 0) -> str:
def generate(**kwargs: Any) -> None:
gen_args = kwargs["args"]

# Generates CUDA variants.
kwargs["args"] = gen_args["cuda"]
if kwargs.get("has_gpu_support"):
# Generates CUDA variants.
template = env.get_template("embedding_backward_split_template.cu")
src_cu = template.render(weighted=False, **kwargs)
write(
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_unweighted_cuda.cu",
src_cu,
)
src_cu = template.render(weighted=True, **kwargs)
write(
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_weighted_cuda.cu",
src_cu,
)
if not kwargs.get("dense"):
template = env.get_template("embedding_backward_split_host_template.cpp")
src_cpp = template.render(**kwargs)
write(
f"gen_embedding_backward_split_{kwargs.get('optimizer')}.cpp", src_cpp
)

template = env.get_template("embedding_backward_split_template.cu")
src_cu = template.render(weighted=False, **kwargs)
write(
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_unweighted_cuda.cu",
src_cu,
)
src_cu = template.render(weighted=True, **kwargs)
write(
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_weighted_cuda.cu",
src_cu,
)
if not kwargs.get("dense"):
template = env.get_template("embedding_backward_split_host_template.cpp")
src_cpp = template.render(**kwargs)
write(f"gen_embedding_backward_split_{kwargs.get('optimizer')}.cpp", src_cpp)

# Generates Python invoker for CUDA + CPU
template = env.get_template("split_embedding_codegen_lookup_invoker.template")
src_py = template.render(is_fbcode=args.is_fbcode, **kwargs)
write(f"lookup_{kwargs.get('optimizer')}.py", src_py)

# Generates CPU variants.
kwargs["args"] = gen_args["cpu"]
if kwargs.get("has_cpu_support"):
# Generates CPU variants.
kwargs["args"] = gen_args["cpu"]

is_approx = "approx" in kwargs.get("optimizer")
template = (
env.get_template("embedding_backward_split_cpu_approx_template.cpp")
if is_approx
else env.get_template("embedding_backward_split_cpu_template.cpp")
)

src_cpp = template.render(**kwargs)
write(
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_cpu.cpp",
src_cpp,
)
is_approx = "approx" in kwargs.get("optimizer")
template = (
env.get_template("embedding_backward_split_cpu_approx_template.cpp")
if is_approx
else env.get_template("embedding_backward_split_cpu_template.cpp")
)

if not kwargs.get("dense"):
template = env.get_template("embedding_backward_split_host_cpu_template.cpp")
src_cpp = template.render(**kwargs)
write(
f"gen_embedding_backward_split_{kwargs.get('optimizer')}_cpu.cpp", src_cpp
f"gen_embedding_backward_{kwargs.get('optimizer')}_split_cpu.cpp",
src_cpp,
)

if not kwargs.get("dense"):
template = env.get_template(
"embedding_backward_split_host_cpu_template.cpp"
)
src_cpp = template.render(**kwargs)
write(
f"gen_embedding_backward_split_{kwargs.get('optimizer')}_cpu.cpp",
src_cpp,
)


@dataclass
class Args:
@@ -369,6 +376,8 @@ def adagrad() -> None:
split_precomputation="",
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)


@@ -490,6 +499,8 @@ def rowwise_adagrad() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)

approx_split_weight_update = """
@@ -512,6 +523,8 @@ def rowwise_adagrad() -> None:
split_precomputation=split_precomputation,
split_weight_update=approx_split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=False,
)


@@ -611,6 +624,9 @@ def rowwise_adagrad_with_weight_decay() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
# Disable both CPU and GPU support
has_cpu_support=False,
has_gpu_support=False,
)

approx_split_weight_update = """
Expand All @@ -633,6 +649,9 @@ def rowwise_adagrad_with_weight_decay() -> None:
split_precomputation=split_precomputation,
split_weight_update=approx_split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
# Disable both CPU and GPU support
has_cpu_support=False,
has_gpu_support=False,
)


@@ -771,6 +790,8 @@ def rowwise_adagrad_with_counter() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)

approx_split_weight_update = """
@@ -804,6 +825,8 @@ def rowwise_adagrad_with_counter() -> None:
split_precomputation=split_precomputation,
split_weight_update=approx_split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=False,
)


@@ -874,6 +897,8 @@ def rowwise_weighted_adagrad() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)


@@ -893,6 +918,8 @@ def sgd() -> None:
split_precomputation="",
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)

approx_split_weight_update = """
@@ -908,6 +935,8 @@ def sgd() -> None:
split_precomputation="",
split_weight_update=approx_split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=True,
has_gpu_support=True,
)


@@ -978,6 +1007,8 @@ def lamb() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=False,
has_gpu_support=True,
)


@@ -1064,6 +1095,8 @@ def partial_rowwise_lamb() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=False,
has_gpu_support=True,
)


@@ -1114,6 +1147,8 @@ def adam() -> None:
split_precomputation="",
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=False,
has_gpu_support=True,
)


@@ -1174,6 +1209,8 @@ def partial_rowwise_adam() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=False,
has_gpu_support=True,
)


@@ -1232,6 +1269,8 @@ def lars_sgd() -> None:
split_precomputation=split_precomputation,
split_weight_update=split_weight_update,
split_weight_update_cpu=split_weight_update_cpu,
has_cpu_support=False,
has_gpu_support=True,
)


@@ -1296,6 +1335,8 @@ def backward_dense() -> None:
(FLOAT, "unused"),
]
),
has_cpu_support=True,
has_gpu_support=True,
)


@@ -1323,7 +1364,7 @@ def emb_codegen(
partial_rowwise_adam()
partial_rowwise_lamb()
rowwise_adagrad()
rowwise_adagrad_with_weight_decay()
# rowwise_adagrad_with_weight_decay() # Disabled
rowwise_adagrad_with_counter()
rowwise_weighted_adagrad()
sgd()
