๋ฒ์ญ: ์ ์ฉํ
PyTorch C++ ํ๋ก ํธ์๋๋ PyTorch ๋จธ์ ๋ฌ๋ ํ๋ ์์ํฌ์ ์์ C++ ์ธํฐํ์ด์ค์ ๋๋ค. PyTorch์ ์ฃผ๋ ์ธํฐํ์ด์ค๋ ๋ฌผ๋ก ํ์ด์ฌ์ด์ง๋ง ์ด ๊ณณ์ API๋ ํ ์(tensor)๋ ์๋ ๋ฏธ๋ถ๊ณผ ๊ฐ์ ๊ธฐ์ด์ ์ธ ์๋ฃ๊ตฌ์กฐ ๋ฐ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ C++ ์ฝ๋๋ฒ ์ด์ค ์์ ๊ตฌํ๋์์ต๋๋ค. C++ ํ๋ก ํธ์๋๋ ์ด๋ฌํ ๊ธฐ์ด์ ์ธ C++ ์ฝ๋๋ฒ ์ด์ค๋ฅผ ๋น๋กฏํด ๋จธ์ ๋ฌ๋ ํ์ต๊ณผ ์ถ๋ก ์ ์ํด ํ์ํ ๋๊ตฌ๋ค์ ์์ํ๋ ์์ C++11 API๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ฌ๊ธฐ์๋ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ง์ ์ํด ํ์ํ ๊ณต์ฉ ์ปดํฌ๋ํธ๋ค์ ๋นํธ์ธ ๋ชจ์, ๊ทธ๊ฒ์ ์์ํ๊ธฐ ์ํ ์ปค์คํ ๋ชจ๋, ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ๊ณผ ๊ฐ์ ์ ๋ช ํ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ ๋ผ์ด๋ธ๋ฌ๋ฆฌ, ๋ณ๋ ฌ ๋ฐ์ดํฐ ๋ก๋ ๋ฐ ๋ฐ์ดํฐ์ ์ ์ ์ํ๊ณ ๋ถ๋ฌ์ค๊ธฐ ์ํ API, ์ง๋ ฌํ ๋ฃจํด ๋ฑ์ด ํฌํจ๋ฉ๋๋ค.
์ด ํํ ๋ฆฌ์ผ์ C++ ํ๋ก ํธ์๋๋ก ๋ชจ๋ธ์ ํ์ตํ๋ ์๋ ํฌ ์๋ ์์ ๋ฅผ ์๋ดํฉ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก, ์ฐ๋ฆฌ๋ ์์ฑ ๋ชจ๋ธ ์ค ํ๋์ธ DCGAN ์ ํ์ต์์ผ MNIST ์ซ์ ์ด๋ฏธ์ง๋ค์ ์์ฑํ ๊ฒ์ ๋๋ค. ๊ฐ๋ ์ ์ผ๋ก ์ฌ์ด ์์์ด์ง๋ง, ์ฌ๋ฌ๋ถ์ด PyTorch C++ ํ๋ก ํธ์๋์ ๋ํ ๋๋ต์ ์ธ ๊ฐ์๋ฅผ ํ์ ํ๊ณ ๋ ๋ณต์กํ ๋ชจ๋ธ์ ํ์ต์ํค๊ณ ์ถ์ ์๊ตฌ๋ฅผ ๋ถ๋ฌ์ผ์ผํค๊ธฐ์ ์ถฉ๋ถํ ๊ฒ์ ๋๋ค. ๋จผ์ C++ ํ๋ก ํธ์๋ ์ฌ์ฉ์ ๋ํ ๋๊ธฐ๋ถ์ฌ๊ฐ ๋ ๋งํ ์ด์ผ๊ธฐ๋ก ์์ํ๊ณ , ๊ณง๋ฐ๋ก ๋ชจ๋ธ์ ์ ์ํ๊ณ ํ์ตํด ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
Tip
C++ ํ๋ก ํธ์๋์ ๋ํ ์งง๊ณ ์ฌ๋ฏธ์๋ ๋ฐํ๋ฅผ ๋ณด๋ ค๋ฉด CppCon 2018 ๋ผ์ดํธ๋ ํ ํฌ ๋ฅผ ์์ฒญํ์ธ์.
Tip
์ด ๋ ธํธ ๋ C++ ํ๋ก ํธ์๋์ ์ปดํฌ๋ํธ์ ๋์์ธ ์ฒ ํ์ ์ ๋ฐ์ ์ธ ๊ฐ์๋ฅผ ์ ๊ณตํฉ๋๋ค.
Tip
PyTorch C++ ์ํ๊ณ์ ๋ํ ๋ฌธ์๋ https://pytorch.org/cppdocs์์ ํ์ธํ ์ ์์ต๋๋ค. API ๋ ๋ฒจ์ ๋ฌธ์๋ฟ๋ง ์๋๋ผ ๊ฐ๊ด์ ์ธ ์ค๋ช ๋ ์ฐพ์ ์ ์์ ๊ฒ์ ๋๋ค.
GAN๊ณผ MNIST ์ซ์๋ก์ ์ค๋ ๋ ์ฌ์ ์ ์์ํ๊ธฐ์ ์์, ๋จผ์ ํ์ด์ฌ ๋์ C++ ํ๋ก ํธ์๋๋ฅผ ์ฌ์ฉํ๋ ์ด์ ์ ๋ํด ์ค๋ช ํ๊ฒ ์ต๋๋ค. ์ฐ๋ฆฌ(PyTorch ํ)๋ ํ์ด์ฌ์ ์ฌ์ฉํ ์ ์๊ฑฐ๋ ์ฌ์ฉํ๊ธฐ์ ์ ํฉํ์ง ์์ ํ๊ฒฝ์์ ์ฐ๊ตฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ์ํด C++ ํ๋ก ํธ์๋๋ฅผ ๋ง๋ค์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ ์ง์ฐ ์์คํ : ์ด๋น ํ๋ ์ ์๊ฐ ๋๊ณ ์ง์ฐ ์๊ฐ์ด ์งง์ ์์ C++ ๊ฒ์ ์์ง์์ ๊ฐํ ํ์ต ์ฐ๊ตฌ๋ฅผ ์ํํ ์ ์์ต๋๋ค. ๊ทธ๋ฌํ ํ๊ฒฝ์์๋ ํ์ด์ฌ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ณด๋ค ์์ C++ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ํจ์ฌ ๋ ์ ํฉํฉ๋๋ค. ํ์ด์ฌ์ ๋๋ฆฐ ์ธํฐํ๋ฆฌํฐ ๋๋ฌธ์ ๋ค๋ฃจ๊ธฐ๊ฐ ์ฝ์ง ์์ต๋๋ค.
- ๊ณ ๋์ ๋ฉํฐ์ฐ๋ ๋ฉ ํ๊ฒฝ: ๊ธ๋ก๋ฒ ์ธํฐํ๋ฆฌํฐ ๋ฝ(GIL)์ผ๋ก ์ธํด ํ์ด์ฌ์ ๋์์ ๋ ์ด์์ ์์คํ ์ฐ๋ ๋๋ฅผ ์คํํ ์ ์์ต๋๋ค. ๋์์ผ๋ก ๋ฉํฐํ๋ก์ธ์ฑ์ ์ฌ์ฉํ๋ฉด ํ์ฅ์ฑ์ด ๋จ์ด์ง๋ฉฐ ์ฌ๊ฐํ ํ๊ณ๊ฐ ์์ต๋๋ค. C++๋ ์ด๋ฌํ ์ ์ฝ ์กฐ๊ฑด์ด ์์ผ๋ฉฐ ์ฐ๋ ๋๋ฅผ ์ฝ๊ฒ ๋ง๋ค๊ณ ์ฌ์ฉํ ์ ์์ต๋๋ค. Deep Neuroevolution ์ ์ฌ์ฉ๋ ๊ฒ๊ณผ ๊ฐ์ด ๊ณ ๋์ ๋ณ๋ ฌํ๊ฐ ํ์ํ ๋ชจ๋ธ๋ ์ด๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค.
- ๊ธฐ์กด์ C++ ์ฝ๋๋ฒ ์ด์ค: ๋ฐฑ์๋ ์๋ฒ์ ์น ํ์ด์ง ์๋น์ค๋ถํฐ ์ฌ์ง ํธ์ง ์ํํธ์จ์ด์ 3D ๊ทธ๋ํฝ ๋ ๋๋ง์ ์ด๋ฅด๊ธฐ๊น์ง ์ด๋ ํ ์์ ์ด๋ผ๋ ์ํํ๋ ๊ธฐ์กด C++ ์ ํ๋ฆฌ์ผ์ด์ ์์ ์๋ก์, ๋จธ์ ๋ฌ๋ ๋ฐฉ๋ฒ๋ก ์ ์์คํ ์ ํตํฉํ๊ณ ์ถ์ ์ ์์ต๋๋ค. C++ ํ๋ก ํธ์๋๋ PyTorch (ํ์ด์ฌ) ๊ฒฝํ ๋ณธ์ฐ์ ๋์ ์ ์ฐ์ฑ๊ณผ ์ง๊ด์ฑ์ ์ ์งํ๋ฉด์, ํ์ด์ฌ๊ณผ C++๋ฅผ ์๋ค๋ก ๋ฐ์ธ๋ฉํ๋ ๋ฒ๊ฑฐ๋ก์ ์์ด C++๋ฅผ ์ฌ์ฉํ ์ ์๊ฒ ํด์ค๋๋ค.
C++ ํ๋ก ํธ์๋์ ๋ชฉ์ ์ ํ์ด์ฌ ํ๋ก ํธ์๋์ ๊ฒฝ์ํ๋ ๊ฒ์ด ์๋ ๋ณด์ํ๋ ๊ฒ์ ๋๋ค. ์ฐ๊ตฌ์์ ์์ง๋์ด ๋ชจ๋๊ฐ PyTorch์ ๋จ์์ฑ, ์ ์ฐ์ฑ ๋ฐ ์ง๊ด์ ์ธ API๋ฅผ ๋งค์ฐ ์ข์ํฉ๋๋ค. ์ฐ๋ฆฌ์ ๋ชฉํ๋ ์ฌ๋ฌ๋ถ์ด ์์ ์์๋ฅผ ๋น๋กฏํ ๋ชจ๋ ๊ฐ๋ฅํ ํ๊ฒฝ์์ ์ด ํต์ฌ ๋์์ธ ์์น์ ์ด์ฉํ ์ ์๋๋ก ํ๋ ๊ฒ์ ๋๋ค. ์ด๋ฌํ ์๋๋ฆฌ์ค ์ค ํ๋๊ฐ ์ฌ๋ฌ๋ถ์ ์ฌ๋ก์ ํด๋นํ๊ฑฐ๋, ๋จ์ํ ๊ด์ฌ์ด ์๊ฑฐ๋ ๊ถ๊ธํ๋ค๋ฉด ์๋ ๋ด์ฉ์ ํตํด C++ ํ๋ก ํธ์๋์ ๋ํด ์์ธํ ์ดํด๋ณด์ธ์.
Tip
C++ ํ๋ก ํธ์๋๋ ํ์ด์ฌ ํ๋ก ํธ์๋์ ์ต๋ํ ์ ์ฌํ API๋ฅผ
์ ๊ณตํ๊ณ ์ ํฉ๋๋ค. ๋ง์ผ ํ์ด์ฌ ํ๋ก ํธ์๋์ ์ต์ํ ์ฌ๋์ด "C++ ํ๋ก ํธ์๋๋ก X๋ฅผ ์ด๋ป๊ฒ ํด์ผ ํ๋๊ฐ?" ์๋ฌธ์ ๊ฐ๋๋ค๋ฉด, ๋ง์ ๊ฒฝ์ฐ์ ํ์ด์ฌ์์์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์ฝ๋๋ฅผ ์์ฑํด ํ์ด์ฌ์์์ ๋์ผํ ํจ์์ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ ์ ์์ ๊ฒ์ ๋๋ค. (๋ค๋ง, ์จ์ ์ ๋๋ธ ์ฝ๋ก ์ผ๋ก ๋ฐ๊พธ๋ ๊ฒ์ ์ ์ํ์ธ์.)
๋จผ์ ์ต์ํ์ C++ ์ ํ๋ฆฌ์ผ์ด์ ์ ์์ฑํด ์ฐ๋ฆฌ์ ์ค์ ๊ณผ ๋น๋ ํ๊ฒฝ์ด ๋์ผํ์ง ํ์ธํ๊ฒ ์ต๋๋ค. ๋จผ์ , C++ ํ๋ก ํธ์๋๋ฅผ ์ฌ์ฉํ๋ ๋ฐ ํ์ํ ๋ชจ๋ ๊ด๋ จ ํค๋, ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ฐ CMake ๋น๋ ํ์ผ์ ํจํค์งํ๋ LibTorch ๋ฐฐํฌํ์ ์ฌ๋ณธ์ด ํ์ํฉ๋๋ค. ๋ฆฌ๋ ์ค, ๋งฅOS, ์๋์ฐ์ฉ LibTorch ๋ฐฐํฌํ์ PyTorch website ์์ ๋ค์ด๋ก๋ํ ์ ์์ต๋๋ค. ์ด ํํ ๋ฆฌ์ผ์ ๋๋จธ์ง ๋ถ๋ถ์ ๊ธฐ๋ณธ ์ฐ๋ถํฌ ๋ฆฌ๋ ์ค ํ๊ฒฝ์ ๊ฐ์ ํ์ง๋ง ๋งฅOS๋ ์๋์ฐ๋ฅผ ์ฌ์ฉํ์ ๋ ๊ด์ฐฎ์ต๋๋ค.
Tip
PyTorch C++ ๋ฐฐํฌํ ์ค์น ์ ์ค๋ช ์ ๋ค์์ ๊ณผ์ ์ด ๋ ์์ธํ ์๋ด๋์ด ์์ต๋๋ค.
Tip
์๋์ฐ์์๋ ๋๋ฒ๊ทธ ๋ฐ ๋ฆด๋ฆฌ์ค ๋น๋๊ฐ ABI์ ํธํ๋์ง ์์ต๋๋ค. ํ๋ก์ ํธ๋ฅผ
๋๋ฒ๊ทธ ๋ชจ๋๋ก ๋น๋ํ๋ ค๋ฉด LibTorch์ ๋๋ฒ๊ทธ ๋ฒ์ ์ ์ฌ์ฉํด๋ณด์ธ์.
์๋์ cmake --build .
์ ์ฌ๋ฐ๋ฅธ ์ค์ ์ ์ง์ ํ๋ ๊ฒ๋ ์์ง
๋ง์ธ์.
๊ฐ์ฅ ๋จผ์ ํ ๊ฒ์ PyTorch ์น์ฌ์ดํธ์์ ๊ฒ์๋ ๋งํฌ๋ฅผ ํตํด LibTorch ๋ฐฐํฌํ์ ๋ก์ปฌ์ ๋ค์ด๋ก๋ํ๋ ๊ฒ์ ๋๋ค. ์ผ๋ฐ์ Ubuntu Linux ํ๊ฒฝ์ ๊ฒฝ์ฐ ๋ค์ ๋ช ๋ น์ด๋ฅผ ์คํํฉ๋๋ค.
# CUDA 9.0 ๋ฑ์ ๋ํ ์ง์์ด ํ์ํ ๊ฒฝ์ฐ ์๋ URL์์ "cpu"๋ฅผ "cu90"๋ก ๋ฐ๊พธ์ธ์.
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
๋ค์์ผ๋ก torch/torch.h
๋ฅผ ํธ์ถํ๋ dcgan.cpp
๋ผ๋ ์ด๋ฆ์ C++
ํ์ผ ํ๋๋ฅผ ์์ฑํฉ์๋ค. ์ฐ์ ์ ์๋์ ๊ฐ์ด 3x3 ํญ๋ฑ ํ๋ ฌ์ ์ถ๋ ฅํ๊ธฐ๋ง ํ๋ฉด
๋ฉ๋๋ค.
#include <torch/torch.h>
#include <iostream>
int main() {
torch::Tensor tensor = torch::eye(3);
std::cout << tensor << std::endl;
}
์ด ์์ ์ ํ๋ฆฌ์ผ์ด์
๊ณผ ์ดํ ์์ฑํ ํ์ต์ฉ ์คํฌ๋ฆฝํธ๋ฅผ ๋น๋ํ๊ธฐ ์ํด ์ฐ๋ฆฌ๋ ์๋์ CMakeLists.txt
๋ฅผ
์ฌ์ฉํ ๊ฒ์
๋๋ค:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)
find_package(Torch REQUIRED)
add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)
Note
CMake๋ LibTorch์ ๊ถ์ฅ๋๋ ๋น๋ ์์คํ ์ด์ง๋ง ํ์ ์๊ตฌ ์ฌํญ์ ์๋๋๋ค. Visual Studio ํ๋ก์ ํธ ํ์ผ, QMake, ์ผ๋ฐ Make ํ์ผ ๋ฑ ๋ค๋ฅธ ๋น๋ ํ๊ฒฝ์ ์ฌ์ฉํด๋ ๋ฉ๋๋ค. ํ์ง๋ง ์ด์ ๋ํ ์ฆ๊ฐ์ ์ธ ์ง์์ ์ ๊ณตํ์ง ์์ต๋๋ค.
์ CMake ํ์ผ 4๋ฒ์งธ ์ค์ find_package(Torch REQUIRED)
๋
CMake๊ฐ LibTorch ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋น๋ ์ค์ ์ ์ฐพ๋๋ก ์๋ดํฉ๋๋ค.
CMake๊ฐ ํด๋น ํ์ผ์ ์์น ๋ฅผ ์ฐพ์ ์ ์๋๋ก ํ๋ ค๋ฉด cmake
ํธ์ถ ์
CMAKE_PREFIX_PATH
๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค. ์ด์ ์์ dcgan
์ ํ๋ฆฌ์ผ์ด์
์
๋ํด ๋๋ ํฐ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ํต์ผํ๋๋ก ํ๊ฒ ์ต๋๋ค.
dcgan/
CMakeLists.txt
dcgan.cpp
๋ํ ์์ผ๋ก ์์ถ ํด์ ๋ LibTorch ๋ฐฐํฌํ์ ๊ฒฝ๋ก๋ฅผ /path/to/libtorch
๋ก ๋ถ๋ฅด๋๋ก ํ๊ฒ ์ต๋๋ค. ์ด๋ ๋ฐ๋์ ์ ๋ ๊ฒฝ๋ก์ฌ์ผ ํฉ๋๋ค. ํนํ
CMAKE_PREFIX_PATH
๋ฅผ ../../libtorch
์ ๊ฐ์ด ์ค์ ํ๋ฉด ์์์น ๋ชปํ
์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค. ๊ทธ๋ณด๋ค๋ $PWD/../../libtorch
์ ๊ฐ์ด ํด๋น
์ ๋ ๊ฒฝ๋ก๋ฅผ ์
๋ ฅํ์ธ์. ์ด์ ์ ํ๋ฆฌ์ผ์ด์
์ ๋น๋ํ ์ค๋น๊ฐ ๋์์ต๋๋ค.
root@fa350df05ecf:/home# mkdir build
root@fa350df05ecf:/home# cd build
root@fa350df05ecf:/home/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /path/to/libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /home/build
root@fa350df05ecf:/home/build# cmake --build . --config Release
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
์์์ ์ฐ๋ฆฌ๋ ๋จผ์ dcgan
๋๋ ํ ๋ฆฌ ์์ build
ํด๋๋ฅผ ๋ง๋ค๊ณ
์ด ํด๋์ ๋ค์ด๊ฐ์ ํ์ํ ๋น๋(Make) ํ์ผ์ ์์ฑํ๋ cmake
๋ช
๋ น์ด๋ฅผ
์คํํ ํ cmake --build . --config Release
๋ฅผ ์คํํ์ฌ ํ๋ก์ ํธ๋ฅผ
์ฑ๊ณต์ ์ผ๋ก ์ปดํ์ผํ์ต๋๋ค. ์ด์ ์ฐ๋ฆฌ์ ์์ ๋ฐ์ด๋๋ฆฌ๋ฅผ ์คํํ๊ณ ๊ธฐ๋ณธ
ํ๋ก์ ํธ ์ค์ ์ ๋ํ ์ด ์น์
์ ์๋ฃํ ์ค๋น๊ฐ ๋์ต๋๋ค.
root@fa350df05ecf:/home/build# ./dcgan
1 0 0
0 1 0
0 0 1
[ Variable[CPUFloatType]{3,3} ]
์ ๊ฐ ๋ณด๊ธฐ์ ํญ๋ฑ ํ๋ ฌ์ธ ๊ฒ ๊ฐ๊ตฐ์!
์ด์ ๊ธฐ๋ณธ์ ์ธ ํ๊ฒฝ์ ์ค์ ํ์ผ๋, ์ด๋ฒ ํํ ๋ฆฌ์ผ์์ ํจ์ฌ ๋ ํฅ๋ฏธ๋ก์ด ๋ถ๋ถ์ ์ดํด๋ด ์๋ค. ๋จผ์ C++ ํ๋ก ํธ์๋์์ ๋ชจ๋์ ์ ์ํ๊ณ ์ํธ ์์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ํด ๋ ผ์ํ๊ฒ ์ต๋๋ค. ๊ธฐ๋ณธ์ ์ธ ์๊ท๋ชจ ์์ ๋ชจ๋๋ถํฐ ์์ํ์ฌ C++ ํ๋ก ํธ์๋๊ฐ ์ ๊ณตํ๋ ๋ค์ํ ๋ด์ฅ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์์ฑ๋ ์๋ GAN์ ๊ตฌํํ๊ฒ ์ต๋๋ค.
ํ์ด์ฌ ์ธํฐํ์ด์ค์ ๋ง์ฐฌ๊ฐ์ง๋ก, C++ ํ๋ก ํธ์๋์ ๊ธฐ๋ฐ์ ๋ ์ ๊ฒฝ๋ง๋
๋ชจ๋ ์ด๋ผ ๋ถ๋ฆฌ๋ ์ฌ์ฌ์ฉ ๊ฐ๋ฅํ ๋น๋ฉ ๋ธ๋ก์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ํ์ด์ฌ์
๋ค๋ฅธ ๋ชจ๋ ๋ชจ๋์ด ํ์๋๋ torch.nn.Module
๋ผ๋ ๊ธฐ๋ณธ ๋ชจ๋ ํด๋์ค๊ฐ
์๋ฏ์ด C++์๋ torch::nn::Module
ํด๋์ค๊ฐ ์์ต๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋์๋ ์บก์ํ๋ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ๋ forward()
๋ฉ์๋๋ฅผ ๋น๋กฏํด ๋งค๊ฐ๋ณ์, ๋ฒํผ ๋ฐ ํ์ ๋ชจ๋ ์ธ ๊ฐ์ง ํ์ ๊ฐ์ฒด๊ฐ
ํฌํจ๋ฉ๋๋ค.
๋งค๊ฐ๋ณ์์ ๋ฒํผ๋ ํ ์์ ํํ๋ก ์ํ๋ฅผ ์ ์ฅํฉ๋๋ค. ๋งค๊ฐ๋ณ์๋ ๊ทธ๋๋์ธํธ๋ฅผ ๊ธฐ๋กํ์ง๋ง ๋ฒํผ๋ ๊ธฐ๋กํ์ง ์์ต๋๋ค. ๋งค๊ฐ๋ณ์๋ ์ผ๋ฐ์ ์ผ๋ก ์ ๊ฒฝ๋ง์ ํ์ต ๊ฐ๋ฅํ ๊ฐ์ค์น์ ๋๋ค. ๋ฒํผ์ ์๋ก๋ ๋ฐฐ์น ์ ๊ทํ๋ฅผ ์ํ ํ๊ท ๋ฐ ๋ถ์ฐ์ด ์์ต๋๋ค. ํน์ ๋ ผ๋ฆฌ ๋ฐ ์ํ ๋ธ๋ก์ ์ฌ์ฌ์ฉํ๊ธฐ ์ํด, PyTorch API๋ ๋ชจ๋๋ค์ด ์ค์ฒฉ๋๋ ๊ฒ์ ํ์ฉํฉ๋๋ค. ์ค์ฒฉ๋ ๋ชจ๋์ ํ์ ๋ชจ๋ ์ด๋ผ๊ณ ํฉ๋๋ค.
๋งค๊ฐ๋ณ์, ๋ฒํผ ๋ฐ ํ์ ๋ชจ๋์ ๋ช
์์ ์ผ๋ก ๋ฑ๋ก(register)์ ํด์ผ ํฉ๋๋ค.
๋ฑ๋ก์ด ๋๋ฉด parameters()
๋ buffers()
๊ฐ์ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ (์ค์ฒฉ์
ํฌํจํ) ์ ์ฒด ๋ชจ๋ ๊ณ์ธต ๊ตฌ์กฐ์์ ๋ชจ๋ ๋งค๊ฐ๋ณ์ ๋ฌถ์์ ๊ฒ์ํ ์ ์์ต๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก, to(...)
์ ๊ฐ์ ๋ฉ์๋๋ ๋ชจ๋ ๊ณ์ธต ๊ตฌ์กฐ ์ ์ฒด์ ๋ํ ๋ฉ์๋์
๋๋ค.
์๋ฅผ ๋ค์ด, to(torch::kCUDA)
๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์์ ๋ฒํผ๋ฅผ CPU์์ CUDA ๋ฉ๋ชจ๋ฆฌ๋ก
์ด๋์ํต๋๋ค.
์ด ๋ด์ฉ์ ์ฝ๋๋ก ๊ตฌํํ๊ธฐ ์ํด, ํ์ด์ฌ ์ธํฐํ์ด์ค๋ก ์์ฑ๋ ๊ฐ๋จํ ๋ชจ๋ ํ๋๋ฅผ ์๊ฐํด ๋ด ์๋ค.
import torch
class Net(torch.nn.Module):
def __init__(self, N, M):
super(Net, self).__init__()
self.W = torch.nn.Parameter(torch.randn(N, M))
self.b = torch.nn.Parameter(torch.randn(M))
def forward(self, input):
return torch.addmm(self.b, input, self.W)
์ด๋ฅผ C++๋ก ์์ฑํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
#include <torch/torch.h>
struct Net : torch::nn::Module {
Net(int64_t N, int64_t M) {
W = register_parameter("W", torch::randn({N, M}));
b = register_parameter("b", torch::randn(M));
}
torch::Tensor forward(torch::Tensor input) {
return torch::addmm(b, input, W);
}
torch::Tensor W, b;
};
ํ์ด์ฌ์์์ ๋ง์ฐฌ๊ฐ์ง๋ก ๋ชจ๋ ๊ธฐ๋ณธ ํด๋์ค์์ ํ์ํ Net
์ด๋ผ๋ ํด๋์ค๋ฅผ
์ ์ํฉ๋๋ค. (์ฌ์ด ์ค๋ช
์ ์ํด class
๋์ struct
์ ์ฌ์ฉํ์ต๋๋ค.)
ํ์ด์ฌ์์ torch.randn์ ์ฌ์ฉํ๋ ๊ฒ์ฒ๋ผ ์์ฑ์์์๋ torch::randn
์
์ฌ์ฉํด ํ
์๋ฅผ ๋ง๋ญ๋๋ค. ํ ๊ฐ์ง ํฅ๋ฏธ๋ก์ด ์ฐจ์ด์ ์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ฑ๋กํ๋
๋ฐฉ๋ฒ์
๋๋ค. ํ์ด์ฌ์์๋ ํ
์๋ฅผ torch.nn
์ผ๋ก ๊ฐ์ธ๋ ๊ฒ๊ณผ ๋ฌ๋ฆฌ,
C++์์๋ register_parameter
๋ฉ์๋๋ฅผ ํตํด ํ
์๋ฅผ ์ ๋ฌํด์ผ
ํฉ๋๋ค. ์ด๋ฌํ ์ฐจ์ด์ ์์ธ์ ํ์ด์ฌ API์ ๊ฒฝ์ฐ, ์ด๋ค ์์ฑ(attirbute)์ด
torch.nn.Parameter
ํ์
์ธ์ง ๊ฐ์งํด ๊ทธ๋ฌํ ํ
์๋ฅผ ์๋์ผ๋ก ๋ฑ๋กํ ์ ์๊ธฐ
๋๋ฌธ์ ๋ํ๋ฉ๋๋ค. C++์์๋ ๋ฆฌํ๋ ์
(reflection)์ด ๋งค์ฐ ์ ํ์ ์ด๋ฏ๋ก ๋ณด๋ค
์ ํต์ ์ธ (๊ทธ๋ฆฌํ์ฌ ๋ ๋ง๋ฒ์ ์ธ) ๋ฐฉ์์ด ์ ๊ณต๋ฉ๋๋ค.
๋งค๊ฐ๋ณ์ ๋ฑ๋ก๊ณผ ๋ง์ฐฌ๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ์๋ธ๋ชจ๋์ ๋ฑ๋กํ ์ ์์ต๋๋ค. ํ์ด์ฌ์์ ์๋ธ๋ชจ๋์ ์ด๋ค ๋ชจ๋์ ์์ฑ์ผ๋ก ์ง์ ๋ ๋ ์๋์ผ๋ก ๊ฐ์ง๋๊ณ ๋ฑ๋ก๋ฉ๋๋ค.
class Net(torch.nn.Module):
def __init__(self, N, M):
super(Net, self).__init__()
# Registered as a submodule behind the scenes
self.linear = torch.nn.Linear(N, M)
self.another_bias = torch.nn.Parameter(torch.rand(M))
def forward(self, input):
return self.linear(input) + self.another_bias
์๋ฅผ ๋ค์ด, parameters()
๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋ฉด ๋ชจ๋ ๊ณ์ธต์ ๋ชจ๋ ๋งค๊ฐ๋ณ์์
์ฌ๊ท์ ์ผ๋ก ์ก์ธ์คํ ์ ์์ต๋๋ค.
>>> net = Net(4, 5)
>>> print(list(net.parameters()))
[Parameter containing:
tensor([0.0808, 0.8613, 0.2017, 0.5206, 0.5353], requires_grad=True), Parameter containing:
tensor([[-0.3740, -0.0976, -0.4786, -0.4928],
[-0.1434, 0.4713, 0.1735, -0.3293],
[-0.3467, -0.3858, 0.1980, 0.1986],
[-0.1975, 0.4278, -0.1831, -0.2709],
[ 0.3730, 0.4307, 0.3236, -0.0629]], requires_grad=True), Parameter containing:
tensor([ 0.2038, 0.4638, -0.2023, 0.1230, -0.0516], requires_grad=True)]
C++์์ torch::nn::Linear
๋ฑ์ ๋ชจ๋์ ์๋ธ๋ชจ๋๋ก ๋ฑ๋กํ๋ ค๋ฉด ์ด๋ฆ์์
์ ์ถํ ์ ์๋ฏ์ด register_module()
๋ฉ์๋๋ฅผ ์ฌ์ฉํฉ๋๋ค.
struct Net : torch::nn::Module {
Net(int64_t N, int64_t M)
: linear(register_module("linear", torch::nn::Linear(N, M))) {
another_bias = register_parameter("b", torch::randn(M));
}
torch::Tensor forward(torch::Tensor input) {
return linear(input) + another_bias;
}
torch::nn::Linear linear;
torch::Tensor another_bias;
};
Tip
torch::nn
์ ๋ํ ์ด ๋ฌธ์
์์ torch::nn::Linear
, torch::nn::Dropout
, torch::nn::Conv2d
๋ฑ ์ฌ์ฉ ๊ฐ๋ฅํ ์ ์ฒด ๋นํธ์ธ ๋ชจ๋ ๋ชฉ๋ก์ ํ์ธํ ์
์์ต๋๋ค.
์ ์ฝ๋์์ ํ ๊ฐ์ง ๋ฏธ๋ฌํ ์ฌ์ค์ ์๋ธ๋ชจ๋์ ์์ฑ์์ ์ด๋์
๋ผ์ด์
๋ชฉ๋ก์ ์์ฑ๋๊ณ ๋งค๊ฐ๋ณ์๋ ์์ฑ์์ ๋ฐ๋(body)์ ์์ฑ๋์๋ค๋
๊ฒ์
๋๋ค. ์ฌ๊ธฐ์๋ ์ถฉ๋ถํ ์ด์ ๊ฐ ์์ผ๋ฉฐ ์๋ C++ ํ๋ก ํธ์๋์
์ค๋์ญ ๋ชจ๋ธ ์น์
์์ ๋ ๋ค๋ฃฐ ์์ ์
๋๋ค. ๊ทธ๋ ์ง๋ง ์ต์ข
๊ฒฐ๋ก ์
ํ์ด์ฌ์์์ฒ๋ผ ๋ชจ๋ ํธ๋ฆฌ์ ๋งค๊ฐ๋ณ์์ ์ฌ๊ท์ ์ผ๋ก ์ก์ธ์คํ ์
์๋ค๋ ๊ฒ์
๋๋ค. parameters()
๋ฅผ ํธ์ถํ๋ฉด ์ํ๊ฐ ๊ฐ๋ฅํ
std::vector<torch::Tensor>
๊ฐ ๋ฐํ๋ฉ๋๋ค.
int main() {
Net net(4, 5);
for (const auto& p : net.parameters()) {
std::cout << p << std::endl;
}
}
์ด๋ฅผ ์คํํ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
root@fa350df05ecf:/home/build# ./dcgan
0.0345
1.4456
-0.6313
-0.3585
-0.4008
[ Variable[CPUFloatType]{5} ]
-0.1647 0.2891 0.0527 -0.0354
0.3084 0.2025 0.0343 0.1824
-0.4630 -0.2862 0.2500 -0.0420
0.3679 -0.1482 -0.0460 0.1967
0.2132 -0.1992 0.4257 0.0739
[ Variable[CPUFloatType]{5,4} ]
0.01 *
3.6861
-10.1166
-45.0333
7.9983
-20.0705
[ Variable[CPUFloatType]{5} ]
ํ์ด์ฌ์์์ ๊ฐ์ด ์ธ ๊ฐ์ ๋งค๊ฐ๋ณ์๊ฐ ์ถ๋ ฅ๋์ต๋๋ค. ์ด ๋งค๊ฐ๋ณ์๋ค์ ์ด๋ฆ์
ํ์ธํ ์ ์๋๋ก C++ API๋ named_parameters()
๋ฉ์๋๋ฅผ ์ ๊ณตํ๋ฉฐ, ์ด๋
ํ์ด์ฌ์์์ ๊ฐ์ด Orderdict
๋ฅผ ๋ฐํํฉ๋๋ค.
Net net(4, 5);
for (const auto& pair : net.named_parameters()) {
std::cout << pair.key() << ": " << pair.value() << std::endl;
}
๋ง์ฐฌ๊ฐ์ง๋ก ์ฝ๋๋ฅผ ์คํํ๋ฉด ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ต๋๋ค.
root@fa350df05ecf:/home/build# make && ./dcgan 11:13:48
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
b: -0.1863
-0.8611
-0.1228
1.3269
0.9858
[ Variable[CPUFloatType]{5} ]
linear.weight: 0.0339 0.2484 0.2035 -0.2103
-0.0715 -0.2975 -0.4350 -0.1878
-0.3616 0.1050 -0.4982 0.0335
-0.1605 0.4963 0.4099 -0.2883
0.1818 -0.3447 -0.1501 -0.0215
[ Variable[CPUFloatType]{5,4} ]
linear.bias: -0.0250
0.0408
0.3756
-0.2149
-0.3636
[ Variable[CPUFloatType]{5} ]
Note
torch::nn::Module
์ ๋ํ ๋ฌธ์ ๋
๋ชจ๋ ๊ณ์ธต ๊ตฌ์กฐ์ ๋ํ ๋ฉ์๋ ๋ชฉ๋ก ์ ์ฒด๊ฐ ํฌํจ๋์ด
์์ต๋๋ค.
๋คํธ์ํฌ๋ฅผ C++๋ก ์คํํ๊ธฐ ์ํด์๋, ์ฐ๋ฆฌ๊ฐ ์ ์ํ forward()
๋ฉ์๋๋ฅผ
ํธ์ถํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
int main() {
Net net(4, 5);
std::cout << net.forward(torch::ones({2, 4})) << std::endl;
}
์ถ๋ ฅ์ ๋๋ต ์๋์ ๊ฐ์ ๊ฒ์ ๋๋ค
root@fa350df05ecf:/home/build# ./dcgan
0.8559 1.1572 2.1069 -0.1247 0.8060
0.8559 1.1572 2.1069 -0.1247 0.8060
[ Variable[CPUFloatType]{2,5} ]
์ด์ ์ฐ๋ฆฌ๋ C++์์ ๋ชจ๋์ ์ ์ํ๊ณ , ๋งค๊ฐ๋ณ์๋ฅผ ๋ฑ๋กํ๊ณ , ํ์ ๋ชจ๋์
๋ฑ๋กํ๊ณ , parameters()
๋ฑ์ ๋ฉ์๋๋ฅผ ํตํด ๋ชจ๋ ๊ณ์ธต์ ํ์ํ๊ณ ,
๋ชจ๋์ forward()
๋ฉ์๋๋ฅผ ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์ ์ต๋๋ค. C++ API์๋
๋ค๋ฅธ ๋ฉ์๋, ํด๋์ค, ๊ทธ๋ฆฌ๊ณ ์ฃผ์ ๊ฐ ๋ง์ง๋ง ์ ์ฒด ๋ชฉ๋ก์ ๋ฌธ์ ๋ฅผ
์ฐธ์กฐํ์๊ธฐ ๋ฐ๋๋๋ค. ์ ์ ํ์ DCGAN ๋ชจ๋ธ๊ณผ ์๋ ํฌ ์๋ ํ์ต
ํ์ดํ๋ผ์ธ์ ๊ตฌํํ๋ฉด์๋ ๋ช ๊ฐ์ง ๊ฐ๋
์ ๋ ๋ค๋ฃฐ ์์ ์
๋๋ค. ๊ทธ์ ์์
C++ ํ๋ก ํธ์๋์์ torch::nn::Module
์ ํ์ ํด๋์ค๋ค์ ๋ํด ์ ๊ณตํ๋
์ค๋์ญ ๋ชจ๋ธ ์ ๋ํด ๊ฐ๋จํ ์ค๋ช
ํ๊ฒ ์ต๋๋ค.
์ด ๋ ผ์์์ ์ค๋์ญ ๋ชจ๋ธ์ด๋ ๋ชจ๋์ ์ ์ฅํ๊ณ ์ ๋ฌํ๋ ๋ฐฉ์ (๋๊ฐ ํน์ ๋ฌด์์ด ํน์ ๋ชจ๋ ์ธ์คํด์ค๋ฅผ ์์ ํ๋์ง)์ ์ง์นญํฉ๋๋ค. ํ์ด์ฌ์์ ๊ฐ์ฒด๋ ํญ์ ํ์ ๋์ ์ผ๋ก ํ ๋น๋๋ฉฐ ๋ ํผ๋ฐ์ค ์๋งจํฑ์ ๊ฐ์ง๋๋ฐ, ์ด๋ ๋ค๋ฃจ๊ณ ์ดํดํ๊ธฐ๊ฐ ๋งค์ฐ ์ฝ์ต๋๋ค. ์ค์ ๋ก ํ์ด์ฌ์์๋ ๊ฐ์ฒด๊ฐ ์ด๋์ ์กด์ฌํ๊ณ ์ด๋ป๊ฒ ๋ ํผ๋ฐ์ค๋๋์ง ์ ๊ฒฝ ์ฐ์ง ์๊ณ ํ๋ ค๋ ์ผ์๋ง ์ง์คํ ์ ์์ต๋๋ค.
์ ๊ธ ์ธ์ด์ธ C++๋ ์ด ๋ถ๋ถ์์ ๋ ๋ง์ ์ต์
์ ์ ๊ณตํฉ๋๋ค. ์ด๋
C++ ํ๋ก ํธ์๋์ ๋ณต์ก์ฑ์ ์ฆ๊ฐ์ํค๋ฉฐ ๊ทธ ์ค๊ณ์ ์ธ์ฒด๊ณตํ์ ์์์๋
ํฐ ์ํฅ์ ์ค๋๋ค. ํนํ, C++ ํ๋ก ํธ์๋ ๋ชจ๋์์๋ ๋ฐธ๋ฅ ์๋งจํฑ
๋๋ ๋ ํผ๋ฐ์ค ์๋งจํฑ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ ์๊ฐ ์ง๊ธ๊น์ง์
์ฌ๋ก์์ ์ดํด๋ณธ ๊ฐ์ฅ ๋จ์ํ ๊ฒฝ์ฐ๋ก, ๋ชจ๋ ๊ฐ์ฒด๊ฐ ์คํ์ ํ ๋น๋๊ณ
ํจ์์ ์ ๋ฌ๋ ๋ ๋ ํผ๋ฐ์ค ํน์ ํฌ์ธํฐ๋ก ๋ณต์ฌ ๋ฐ ์ด๋(std:move
)
์ํค๊ฑฐ๋ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.
struct Net : torch::nn::Module { };
void a(Net net) { }
void b(Net& net) { }
void c(Net* net) { }
int main() {
Net net;
a(net);
a(std::move(net));
b(net);
c(&net);
}
ํ์(๋ ํผ๋ฐ์ค ์๋งจํฑ)์ ๊ฒฝ์ฐ, std::shared_ptr
๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๋ชจ๋ ๊ณณ์์ shared_ptr
๋ฅผ ์ฌ์ฉํ๋ค๋ ๊ฐ์ ํ์, ๋ ํผ๋ฐ์ค ์๋งจํฑ์
์ฅ์ ์ ํ์ด์ฌ์์์ ๊ฐ์ด ๋ชจ๋์ด ํจ์์ ์ ๋ฌ๋๊ณ ์ธ์๊ฐ ์ ์ธ๋๋ ๋ฐฉ์์
๋ํด ์๊ฐํ ๋ถ๋ด์ ๋์ด์ค๋ค๋ ๊ฒ์
๋๋ค.
struct Net : torch::nn::Module {};
void a(std::shared_ptr<Net> net) { }
int main() {
auto net = std::make_shared<Net>();
a(net);
}
๊ฒฝํ์ ์ผ๋ก, ๋์ ์ธ์ด๋ฅผ ์ฌ์ฉํ๋ ์ฐ๊ตฌ์๋ค์ ๋น๋ก ๋ฐธ๋ฅ ์๋งจํฑ์ด
๋ C++์ "๋ค์ดํฐ๋ธ"ํจ์๋ ๋ถ๊ตฌํ๊ณ ๋ ํผ๋ฐ์ค ์๋งจํฑ์ ํจ์ฌ
์ ํธํฉ๋๋ค. ๋ํ torch::nn::Module
์ ์ค๊ณ๋
์ฌ์ฉ์ ์นํ์ ์ธ ํ์ด์ฌ API๋ฅผ ์ ์ฌํ๊ฒ ๋ฐ๋ฅด๊ธฐ ์ํด shared ์ค๋์ญ์
์์กดํฉ๋๋ค. ์์ ์์๋ก ๋ค์๋ Net
์ ์ ์๋ฅผ ์ถ์ฝํด์ ๋ค์
์ดํด๋ด
์๋ค.
struct Net : torch::nn::Module {
Net(int64_t N, int64_t M)
: linear(register_module("linear", torch::nn::Linear(N, M)))
{ }
torch::nn::Linear linear;
};
ํ์ ๋ชจ๋์ธ linear
๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด ์ด๋ฅผ ํด๋์ค์ ์ง์ ์ ์ฅํ๊ณ ์
ํฉ๋๋ค. ๊ทธ๋ฌ๋ ๋์์ ๋ชจ๋์ ๊ธฐ์ด ํด๋์ค๊ฐ ์ด ํ์ ๋ชจ๋์ ๋ํด ์๊ณ ์ ๊ทผํ
์ ์๊ธฐ๋ฅผ ์ํฉ๋๋ค. ์ด๋ฅผ ์ํด์๋ ํด๋น ํ์ ๋ชจ๋์ ๋ํ ์ฐธ์กฐ๋ฅผ ์ ์ฅํด์ผ ํฉ๋๋ค.
์ด ์๊ฐ ์ด๋ฏธ ์ฐ๋ฆฌ๋ shared ์ค๋์ญ์ ํ์๋ก ํฉ๋๋ค. torch::nn::Module
ํด๋์ค์ ๊ตฌ์ ํด๋์ค์ธ Net
๋ชจ๋์์ ํ์ ๋ชจ๋์ ๋ํ ๋ ํผ๋ฐ์ค๊ฐ
ํ์ํฉ๋๋ค. ๋ฐ๋ผ์ ๊ธฐ์ด ํด๋์ค๋ ๋ชจ๋์ shared_ptr
๋ก ์ ์ฅํ๋ฉฐ ์ด์
๋ฐ๋ผ ๊ตฌ์ ํด๋์ค ๋ํ ๋ง์ฐฌ๊ฐ์ง์ผ ๊ฒ์
๋๋ค.
ํ์ง๋ง ์ ๊น! ์์ ์ฝ๋์๋ shared_ptr
์ ๋ํ ์ธ๊ธ์ด ์์ต๋๋ค! ์ ๊ทธ๋ฐ
๊ฒ์ผ๊น์? ์๋ํ๋ฉด std::shared_ptr<MyModule>
๋ ํ์ดํํ๊ธฐ์ ๋๋ฌด ๊ธธ๊ธฐ ๋๋ฌธ์
๋๋ค.
์ฐ๊ตฌ์๋ค์ ์์ฐ์ฑ์ ์ ์งํ๊ธฐ ์ํด, ์ฐ๋ฆฌ๋ ๋ ํผ๋ฐ์ค ์๋งจํฑ์ ์ ์งํ๋ฉด์ ๋ฐธ๋ฅ
์๋งจํฑ๋ง์ ์ฅ์ ์ธ shared_ptr
์ ๋ํ ์ธ๊ธ์ ์จ๊ธฐ๊ธฐ ์ํ ์ ๊ตํ ๊ณํ์
์ธ์ ์ต๋๋ค. ๊ทธ ์๋ ๋ฐฉ์์ ์ดํดํ๊ธฐ ์ํด ์ฝ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์๋ torch::nn::Linear
๋ชจ๋์ ๋จ์ํ๋ ์ ์๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. (์ ์ฒด ์ ์๋
์ฌ๊ธฐ ์์
ํ์ธํ ์ ์์ต๋๋ค.)
struct LinearImpl : torch::nn::Module {
LinearImpl(int64_t in, int64_t out);
Tensor forward(const Tensor& input);
Tensor weight, bias;
};
TORCH_MODULE(Linear);
์์ฝํ์๋ฉด ์ด ๋ชจ๋์ Linear
๊ฐ ์๋ LinearImpl
์ด๋ผ๊ณ ๋ถ๋ฆฝ๋๋ค. ๊ทธ๋ฆฌ๊ณ
TORCH_MODULE
๋ผ๋ ๋งคํฌ๋ก๊ฐ ์ค์ Linear
ํด๋์ค๋ฅผ ์ ์ํฉ๋๋ค. ์ด๋ ๊ฒ "์์ฑ๋"
ํด๋์ค๋ std::shared_ptr<LinearImpl>
๋ฅผ ๊ฐ์ธ๋ ๋ํผ(wrapper)์
๋๋ค.
๋จ์ํ typedef๊ฐ ์๋ ๋ํผ์ด๋ฏ๋ก ์์ฑ์๋ ์ฌ์ ํ ์์ํ๋ ๋๋ก ์๋ํฉ๋๋ค.
์ฆ, std::make_shared<LinearImpl>(3, 4)
๊ฐ ์๋ torch::nn::Linear(3, 4)
๋ผ๊ณ ์ธ ์ ์์ต๋๋ค. ์ด๋ ๊ฒ ๋งคํฌ๋ก์ ์ํด ์์ฑ๋ ํด๋์ค๋ holder ๋ชจ๋์ด๋ผ๊ณ
๋ถ๋ฆ
๋๋ค. (shared) ํฌ์ธํฐ์ ๋ง์ฐฌ๊ฐ์ง๋ก ํ์ดํ ์ฐ์ฐ์(์ฆ,
model->forward(...)
)๋ฅผ ์ฌ์ฉํด ๊ธฐ์ ๊ฐ์ฒด์ ์ก์ธ์คํฉ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก ํ์ด์ฌ API์ ๋งค์ฐ ์ ์ฌํ ์ค๋์ญ ๋ชจ๋ธ์ ์ป์์ต๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก ๋ ํผ๋ฐ์ค ์๋งจํฑ์ ๋ฐ๋ฅด์ง๋ง, std:shared_ptr
๋
std::make_shared
๋ฑ์ ํ์ดํํ ํ์๊ฐ ์์ต๋๋ค. ์ฐ๋ฆฌ์ Net
์์์์
๋ชจ๋ holder API๋ฅผ ์ฌ์ฉํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
struct NetImpl : torch::nn::Module {};
TORCH_MODULE(Net);
void a(Net net) { }
int main() {
Net net;
a(net);
}
์ฌ๊ธฐ์ ์ธ๊ธํ ๋งํ ๋ฏธ๋ฌํ ๋ฌธ์ ๊ฐ ํ๋ ์์ต๋๋ค. ๊ธฐ๋ณธ ์์ฑ์์ ์ํด ๋ง๋ค์ด์ง
std::shared_ptr
๋ "๋น์ด" ์์ต๋๋ค. ์ฆ, null ํฌ์ธํฐ์
๋๋ค. ๊ธฐ๋ณธ ์์ฑ์๋ก
๋ง๋ค์ด์ง Linear
์ด๋ Net
์ ๋ฌด์์ด์ด์ผ ํ ๊น์? ์, ์ด๊ฑด ์ด๋ ค์ด ๊ฒฐ์ ์
๋๋ค.
๋น (null) std::shared_ptr<LinearImpl>
๋ก ์ ํ ์ ์์ต๋๋ค. ํ์ง๋ง
Linear(3, 4)
๊ฐ std::make_shared<LinearImpl>(3, 4)
์ ๊ฐ๋ค๋ ๊ฒ์ ๊ธฐ์ตํฉ์๋ค.
์ฆ, Linear linear;
์ด null ํฌ์ธํฐ์ฌ์ผ ํ๋ค๊ณ ๊ฒฐ์ ํ๋ค๋ฉด
์์ฑ์์์ ์ธ์๋ฅผ ์ ํ ๋ฐ์ง ์๊ฑฐ๋ ๋ชจ๋ ์ธ์์ ๋ํด ๊ธฐ๋ณธ๊ฐ์ ์ฌ์ฉํ๋
๋ชจ๋์ ์์ฑํ ๋ฐฉ๋ฒ์ด ์์ด์ง๋๋ค. ์ด๋ฌํ ์ด์ ๋ก ํ์ฌ
API์์ ๊ธฐ๋ณธ ์์ฑ์์ ์ํด ๋ง๋ค์ด์ง ๋ชจ๋ holder(Linear()
๋ฑ)๋
๊ธฐ์ ๋ชจ๋(LinearImpl()
)์ ๊ธฐ๋ณธ ์์ฑ์๋ฅผ ํธ์ถํฉ๋๋ค. ๋ง์ฝ
๊ธฐ์ ๋ชจ๋์ ๊ธฐ๋ณธ ์์ฑ์๊ฐ ์์ผ๋ฉด ์ปดํ์ผ๋ฌ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
๋ฐ๋๋ก ๋น holder๋ฅผ ์์ฑํ๋ ค๋ฉด holder ์์ฑ์์ nullptr
๋ฅผ
์ ๋ฌํ๋ฉด ๋ฉ๋๋ค.
์ค์ ๋ก๋ ์์์์ ๊ฐ์ด ํ์ ๋ชจ๋์ ์ฌ์ฉํด ๋ชจ๋์ ์ด๋์ ๋ผ์ด์ (initializer) ๋ชฉ๋ก ์ ๋ฑ๋ก ๋ฐ ์์ฑํ๊ฑฐ๋,
struct Net : torch::nn::Module {
Net(int64_t N, int64_t M)
: linear(register_module("linear", torch::nn::Linear(N, M)))
{ }
torch::nn::Linear linear;
};
ํ์ด์ฌ ์ฌ์ฉ์๋ค์๊ฒ ๋ ์น์ํ ๋ฐฉ๋ฒ์ผ๋ก, ๋จผ์ null ํฌ์ธํฐ๋ก ํ๋๋ฅผ ์์ฑํ ์ดํ ์์ฑ์์์ ๊ฐ์ ์ง์ ํ ์ ์์ต๋๋ค.
struct Net : torch::nn::Module {
Net(int64_t N, int64_t M) {
linear = register_module("linear", torch::nn::Linear(N, M));
}
torch::nn::Linear linear{nullptr}; // construct an empty holder
};
๊ฒฐ๋ก ์ ์ผ๋ก ์ด๋ค ์ค๋์ญ ๋ชจ๋ธ, ์ด๋ค ์๋งจํฑ์ ์ฌ์ฉํ๋ฉด ์ข์๊น์? C++
ํ๋ก ํธ์๋ API๋ ๋ชจ๋ holder๊ฐ ์ ๊ณตํ๋ ์ค๋์ญ ๋ชจ๋ธ์ ๊ฐ์ฅ ์ ์ง์ํฉ๋๋ค.
์ด ๋ฉ์ปค๋์ฆ์ ์ ์ผํ ๋จ์ ์ ๋ชจ๋ ์ ์ธ ์๋์ boilerplate ํ ์ค์ด
์ถ๊ฐ๋๋ค๋ ๊ฒ์
๋๋ค. ์ฆ, ๊ฐ์ฅ ๋จ์ํ ๋ชจ๋ธ์ C++ ๋ชจ๋์ ๊ธฐ์ด๋ฅผ ๋ฐฐ์ธ ๋
๋์ค๋ ๋ฐธ๋ฅ ์๋งจํฑ ๋ชจ๋ธ์
๋๋ค. ์๊ณ ๊ฐ๋จํ ์คํฌ๋ฆฝํธ์ ๊ฒฝ์ฐ,
์ด๊ฒ๋ง์ผ๋ก ์ถฉ๋ถํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ธ์ ๊ฐ๋ ๊ธฐ์ ์ ์ด์ ๋ก ์ธํด
์ด ๊ธฐ๋ฅ์ด ํญ์ ์ง์๋์ง๋ ์๋๋ค๋ ์ฌ์ค์ ์๊ฒ ๋ ๊ฒ์
๋๋ค. ์๋ฅผ ๋ค์ด ์ง๋ ฌํ
API(torch::save
๋ฐ torch::load
)๋ ๋ชจ๋ holder(ํน์ ์ผ๋ฐ
shared_ptr
)๋ง์ ์ง์ํฉ๋๋ค. ๋ฐ๋ผ์ C++ ํ๋ก ํธ์๋๋ก ๋ชจ๋์
์ ์ํ ๋์๋ ๋ชจ๋ holder API ๋ฐฉ์์ด ๊ถ์ฅ๋๋ฉฐ, ์์ผ๋ก ๋ณธ ํํ ๋ฆฌ์ผ์์
์ด API๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
์ด์ ์ด ๊ธ์์ ํด๊ฒฐํ๋ ค๋ ๋จธ์ ๋ฌ๋ ํ์คํฌ๋ฅผ ์ํ ๋ชจ๋์ ์ ์ํ๋๋ฐ ํ์ํ ๋ฐฐ๊ฒฝ๊ณผ ๋์ ๋ถ ์ค๋ช ์ด ๋๋ฌ์ต๋๋ค. ๋ค์ ์๊ธฐํ์๋ฉด, ์ฐ๋ฆฌ์ ํ์คํฌ๋ MNIST ๋ฐ์ดํฐ์ ์ ์ซ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ์ ๋๋ค. ์ฐ๋ฆฌ๋ ์ด ํ์คํฌ๋ฅผ ํ๊ธฐ ์ํด ์ ๋์ ์์ฑ ์ ๊ฒฝ๋ง(GAN) ์ ์ฌ์ฉํ๊ณ ์ ํฉ๋๋ค. ๊ทธ ์ค์์๋ ์ฐ๋ฆฌ๋ DCGAN ์ํคํ ์ฒ ๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. DCGAN์ ๊ฐ์ฅ ์ด๊ธฐ์ ๋ฐํ๋๋ ์ ์ผ ๊ฐ๋จํ GAN์ด์ง๋ง ์ด ํ์คํฌ๋ฅผ ์ํด์๋ ์ถฉ๋ถํฉ๋๋ค.
Tip
์ด ํํ ๋ฆฌ์ผ์ ๋์จ ์์ค ์ฝ๋ ์ ์ฒด๋ ์ด ์ ์ฅ์ ์์ ํ์ธํ ์ ์์ต๋๋ค.
GAN์ ์์ฑ๊ธฐ(generator) ์ ํ๋ณ๊ธฐ(discriminator) ๋ผ๋
๋ ๊ฐ์ง ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค. ์์ฑ๊ธฐ๋ ๋
ธ์ด์ฆ ๋ถํฌ์์ ์ํ์ ์
๋ ฅ๋ฐ๊ณ ,
๊ฐ ๋
ธ์ด์ฆ ์ํ์ ๋ชฉํ ๋ถํฌ(์ด ๊ฒฝ์ฐ MNIST ๋ฐ์ดํฐ์
)์ ์ ์ฌํ ์ด๋ฏธ์ง๋ก
๋ณํํ๋ ๊ฒ์ด ๋ชฉํ์
๋๋ค. ํ๋ณ๊ธฐ๋ MNIST ๋ฐ์ดํฐ์
์ ์ง์ง
์ด๋ฏธ์ง๋ฅผ ์
๋ ฅ๋ฐ๊ฑฐ๋ ์์ฑ๊ธฐ๋ก๋ถํฐ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ์
๋ ฅ๋ฐ์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ด๋ค ์ด๋ฏธ์ง๊ฐ ์ผ๋ง๋ ์ง์ง๊ฐ์์ง (1
์ ๊ฐ๊น์ด ์ถ๋ ฅ)
ํน์ ๊ฐ์ง๊ฐ์ ์ง (0
์ ๊ฐ๊น์ด ์ถ๋ ฅ) ํ๋ณํฉ๋๋ค. ์์ฑ๊ธฐ๊ฐ
๋ง๋ ์ด๋ฏธ์ง๊ฐ ์ผ๋ง๋ ์ง์ง๊ฐ์ ์ง ํ๋ณ๊ธฐ๊ฐ ํผ๋๋ฐฑํ๊ณ ์ด ํผ๋๋ฐฑ์ ์์ฑ๊ธฐ
ํ์ต์ ์ฌ์ฉ๋ฉ๋๋ค. ํ๋ณ๊ธฐ๊ฐ ์ง์ง์ ๋ํ ์๋ชฉ์ด ์ผ๋ง๋ ์ข์ ์ง์
๋ํ ํผ๋๋ฐฑ์ ํ๋ณ๊ธฐ๋ฅผ ์ต์ ํํ๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค. ์ด๋ก ์ ์ผ๋ก,
์์ฑ๊ธฐ์ ํ๋ณ๊ธฐ ์ฌ์ด์ ์ฌ์ธํ ๊ท ํ์ ์ด ๋์ ๋์์ ๊ฐ์ ์ํต๋๋ค.
์ด๋ฅผ ํตํด ์์ฑ๊ธฐ๋ ๋ชฉํ ๋ถํฌ์ ๊ตฌ๋ณํ ์ ์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ณ ,
(๊ทธ๋์ฏค์ด๋ฉด) ์ ํ์ต๋์ด ์์ ํ๋ณ๊ธฐ์ ์๋ชฉ์ ์์ฌ ์ง์ง์ ๊ฐ์ง
์ด๋ฏธ์ง ๋ชจ๋์ ๋ํด 0.5
์ ํ๋ฅ ์ ์ถ๋ ฅํ ๊ฒ์
๋๋ค. ์ต์ข
๊ฒฐ๊ณผ๋ฌผ์ ๋
ธ์ด์ฆ๋ฅผ ์
๋ ฅ๋ฐ์ ์ค์ ์ซ์์ ์ด๋ฏธ์ง๋ฅผ ์ถ๋ ฅ์ผ๋ก ์์ฑํ๋
๊ธฐ๊ณ์
๋๋ค.
๋จผ์ ์ผ๋ จ์ ์ ์น๋ (transposed) 2D ํฉ์ฑ๊ณฑ, ๋ฐฐ์น ์ ๊ทํ ๋ฐ
ReLU ํ์ฑํ ์ ๋์ผ๋ก ๊ตฌ์ฑ๋ ์์ฑ๊ธฐ ๋ชจ๋์ ์ ์ํ๊ฒ ์ต๋๋ค.
๋ชจ๋์ forward()
๋ฉ์๋๋ฅผ ์ง์ ์ ์ํ์ฌ ๋ชจ๋ ๊ฐ ์
๋ ฅ์
(ํจ์ํ์ผ๋ก) ๋ช
์์ ์ผ๋ก ์ ๋ฌํฉ๋๋ค.
struct DCGANGeneratorImpl : nn::Module {
DCGANGeneratorImpl(int kNoiseSize)
: conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
.bias(false)),
batch_norm1(256),
conv2(nn::ConvTranspose2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.bias(false)),
batch_norm2(128),
conv3(nn::ConvTranspose2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.bias(false)),
batch_norm3(64),
conv4(nn::ConvTranspose2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.bias(false))
{
// register_module() is needed if we want to use the parameters() method later on
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv3", conv3);
register_module("conv4", conv4);
register_module("batch_norm1", batch_norm1);
register_module("batch_norm2", batch_norm2);
register_module("batch_norm3", batch_norm3);
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));
return x;
}
nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);
DCGANGenerator generator(kNoiseSize);
์ด์ DCGANGenerator
์ forward()
๋ฅผ ํธ์ถํด ๋
ธ์ด์ฆ ์ํ์ ์ด๋ฏธ์ง์ ๋งคํํ ์ ์์ต๋๋ค.
์ฌ๊ธฐ์ ์ฌ์ฉํ nn::ConvTranspose2d
๋ฐ nn::BatchNorm2d
๋ฑ์ ๋ชจ๋์
์์ ์ค๋ช
ํ ๊ตฌ์กฐ๋ฅผ ๋ฐ๋ฆ
๋๋ค. ์์ kNoiseSize
๋ ์
๋ ฅ ๋
ธ์ด์ฆ ๋ฒกํฐ์ ํฌ๊ธฐ๋ฅผ
๊ฒฐ์ ํ๋ฉฐ 100
์ผ๋ก ์ค์ ๋ฉ๋๋ค. ํ์ดํผํ๋ผ๋ฏธํฐ๋ ๋ฌผ๋ก ๋ํ์์๋ค์ ๋ง์ ๋
ธ๋ ฅ์
ํตํด ์ธํ
๋์ต๋๋ค.
Attention!
ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ํ๋๋ผ ๋ค์น ๋ํ์์์ ์์์ต๋๋ค. ๊ทธ๋ค์ ์๋ก์๋ก ๊ฐ์ฌ๋ฃ๋ฅผ ๋จน์ด๋๊น์.
Note
C++ ํ๋ก ํธ์๋์ Conv2d
์ ๊ฐ์ ๊ธฐ๋ณธ ์ ๊ณต ๋ชจ๋์ ์ต์
์ด ์ ๋ฌ๋๋ ๋ฐฉ๋ฒ์ ๋ํ
๊ฐ๋จํ ์ค๋ช
ํ์๋ฉด, ๋ชจ๋ ๋ชจ๋์ ๋ช ๊ฐ์ง ํ์ ์ต์
์ ๊ฐ๊ณ ์์ต๋๋ค. (์: BatchNorm2d
์
feature ๊ฐ์) ๋ง์ฝ BatchNorm2d(128)
, Dropout(0.5)
, Conv2d(8, 4, 2)
์
๊ฐ์ด ํ์ ์ต์
๋ง ์ค์ ํ๋ ค ํ๋ค๋ฉด ๋ชจ๋ ์์ฑ์์ ์ง์ ์ ๋ฌํ ์ ์์ต๋๋ค.
(์ฌ๊ธฐ์๋ ๊ฐ๊ฐ ์
๋ ฅ ์ฑ๋ ์, ์ถ๋ ฅ ์ฑ๋ ์ ๋ฐ ์ปค๋ ํฌ๊ธฐ๋ฅผ ์๋ฏธ)
๊ทธ๋ฌ๋ ๋ง์ฝ Conv2d
์ bias
์ ๊ฐ์ด ์ผ๋ฐ์ ์ผ๋ก ๊ธฐ๋ณธ๊ฐ์ ์ฌ์ฉํ๋
๋ค๋ฅธ ์ต์
์ ์์ ํด์ผ ํ๋ ๊ฒฝ์ฐ, options ๊ฐ์ฒด๋ฅผ ์์ฑํด ์ ๋ฌํด์ผ ํฉ๋๋ค.
C++ ํ๋ก ํธ์๋์ ๋ชจ๋์ ModuleOptions
์ด๋ผ๊ณ ํ๋ ์ฐ๊ด๋ ์ต์
struct๋ฅผ
๊ฐ์ง๊ณ ์์ต๋๋ค. ์ฌ๊ธฐ์ Module
์ ํด๋น ๋ชจ๋์ ์ด๋ฆ์ผ๋ก, ์๋ฅผ ๋ค์ด Linear
์ ๊ฒฝ์ฐ LinearOptions
์ ๊ฐ์ต๋๋ค. ์ฐ๋ฆฌ๋ ์์ Conv2d
๋ชจ๋์
๋ํด ์ด๋ฅผ ์ํํ ๊ฒ์
๋๋ค.
ํ๋ณ๊ธฐ๋ ๋ง์ฐฌ๊ฐ์ง๋ก ํฉ์ฑ๊ณฑ, ๋ฐฐ์น ์ ๊ทํ ๋ฐ ํ์ฑํ์ ์ฐ์์ ๋๋ค. ํ์ง๋ง ์ด๋ฒ์ ํฉ์ฑ๊ณฑ์ ์ ์น๋์ง ์์ ๊ธฐ๋ณธ ํฉ์ฑ๊ณฑ์ด๋ฉฐ, ์ผ๋ฐ์ ReLU ๋์ ์ ์ํ ๊ฐ์ด 0.2์ธ leaky ReLU๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๋ํ ์ต์ข ํ์ฑํ๋ ๊ฐ์ 0๊ณผ 1 ์ฌ์ด์ ๋ฒ์๋ก ์์ถํ๋ Sigmoid๊ฐ ๋ฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ์ด๋ ๊ฒ ์์ถ๋ ๊ฐ์ ํ๋ณ์๊ฐ ์ด๋ฏธ์ง์ ๋ํด ์ถ๋ ฅํ๋ ํ๋ฅ ๋ก ํด์ํ ์ ์์ต๋๋ค.
ํ๋ณ๊ธฐ๋ฅผ ๋ง๋ค๊ธฐ ์ํด Sequential ๋ชจ๋์ด๋ผ๋ ๋ค๋ฅธ ๊ฒ์ ์๋ํด ๋ณด๊ฒ ์ต๋๋ค. ํ์ด์ฌ์์์ ๊ฐ์ด, PyTorch๋ ๋ชจ๋ธ ์ ์๋ฅผ ์ํด ๋ ๊ฐ์ง API๋ฅผ ์ ๊ณตํฉ๋๋ค. (์์ฑ๊ธฐ ๋ชจ๋ ์์์ ๊ฐ์ด) ์ ๋ ฅ์ด ์ฐ์์ ์ธ ํจ์๋ฅผ ํตํด ์ ๋ฌ๋๋ ํจ์ํ API์ ์ ์ฒด ๋ชจ๋ธ์ ํ์ ๋ชจ๋๋ก ํฌํจํ๋ Sequential ๋ชจ๋์ ์์ฑํ๋ ๊ฐ์ฒด ์งํฅํ API์ ๋๋ค. Sequential ์ ์ฌ์ฉํ๋ฉด ํ๋ณ๊ธฐ๋ ๋๋ต ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
Tip
Sequential
๋ชจ๋์ ๋จ์ํ ํจ์ ํฉ์ฑ๋ง์ ์ํํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ํ์ ๋ชจ๋์ ์ถ๋ ฅ์
๋ ๋ฒ์งธ ํ์ ๋ชจ๋์ ์
๋ ฅ์ด ๋๊ณ ์ธ ๋ฒ์งธ ํ์ ๋ชจ๋์ ์ถ๋ ฅ์ ๋ค ๋ฒ์งธ ํ์ ๋ชจ๋์ ์
๋ ฅ์ด
๋๊ณ ์ดํ์๋ ๋ง์ฐฌ๊ฐ์ง์
๋๋ค.
์ด์ ์์ฑ๊ธฐ์ ํ๋ณ๊ธฐ ๋ชจ๋ธ์ ์ ์ํ์ผ๋ฏ๋ก ์ด๋ฌํ ๋ชจ๋ธ์ ํ์ต์ํฌ ๋ฐ์ดํฐ๊ฐ ํ์ํฉ๋๋ค. ํ์ด์ฌ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก C++ ํ๋ก ํธ์๋๋ ๊ฐ๋ ฅํ ๋ณ๋ ฌ ๋ฐ์ดํฐ ๋ก๋(data loader)๋ฅผ ์ ๊ณตํ๋ค. ์ด ๋ฐ์ดํฐ ๋ก๋๋ ์ฌ์ฉ์๊ฐ ์ง์ ์ ์ํ ์ ์๋ ๋ฐ์ดํฐ์ ์์ ๋ฐ์ดํฐ ๋ฐฐ์น๋ฅผ ์ฝ์ ์ ์์ผ๋ฉฐ ์ค์ ์ ์ํ ๋ง์ ์ต์ ์ ์ ๊ณตํฉ๋๋ค.
Note
ํ์ด์ฌ ๋ฐ์ดํฐ ๋ก๋๊ฐ ๋ฉํฐ ํ๋ก์ธ์ฑ์ ์ฌ์ฉํ๋ ๋ฐ๋ฉด, C++ ๋ฐ์ดํฐ ๋ก๋๋ ์ค์ ๋ก ๋ฉํฐ ์ค๋ ๋ฉ์ ์ฌ์ฉํด ์ด๋ ํ ์๋ก์ด ํ๋ก์ธ์ค๋ ์์ํ์ง ์์ต๋๋ค.
๋ฐ์ดํฐ ๋ก๋๋ torch::data::
๋ค์์คํ์ด์ค์ ํฌํจ๋ C++ ํ๋ก ํธ์๋์
data
API์ ์ผ๋ถ์
๋๋ค. ์ด API๋ ๋ค์๊ณผ ๊ฐ์ ๋ช ๊ฐ์ง ์ปดํฌ๋ํธ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
- ๋ฐ์ดํฐ ๋ก๋ ํด๋์ค
- ๋ฐ์ดํฐ์ ์ ์ ์ํ๊ธฐ ์ํ API
- ๋ณํ ์ ์ ์ํ๊ธฐ ์ํ API (๋ฐ์ดํฐ์ ์ ์ ์ฉ ๊ฐ๋ฅ)
- ์ํ๋ฌ ๋ฅผ ์ ์ํ๊ธฐ ์ํ API (๋ฐ์ดํฐ์ ์ ์ํ ์ธ๋ฑ์ค๋ฅผ ์์ฑ)
- ๊ธฐ์กด ๋ฐ์ดํฐ์ , ๋ณํ, ์ํ๋ฌ๋ค์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ
์ด ํํ ๋ฆฌ์ผ์์๋ C++ ํ๋ก ํธ์๋์ ํจ๊ป ์ ๊ณต๋๋ MNIST
๋ฐ์ดํฐ์
์
์ฌ์ฉํฉ๋๋ค. torch::data::datasets::MNIST
์ธ์คํด์ค๋ฅผ ๋ง๋ค์ด
๋ค์ ๋ ๊ฐ์ง ๋ณํ์ ์ ์ฉํด๋ด
์๋ค. ์ฒซ์งธ, ์ด๋ฏธ์ง๋ฅผ ์ ๊ทํํ์ฌ -1
๊ณผ
+1
์ฌ์ด์ ์๋๋ก ํฉ๋๋ค. (๊ธฐ์กด ๋ฒ์๋ 0
๊ณผ 1
์ฌ์ด)
๋์งธ, ํ
์ ๋ฐฐ์น(batch)๋ฅผ ์ฒซ ๋ฒ์งธ ์ฐจ์์ ๋ฐ๋ผ ๋จ์ผ ํ
์๋ก ์๋ ์ด๋ฅธ๋ฐ
Stack
collation ์ ์ ์ฉํฉ๋๋ค.
auto dataset = torch::data::datasets::MNIST("./mnist")
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
MNIST ๋ฐ์ดํฐ์
์ ํ์ต ๋ฐ์ด๋๋ฆฌ ์คํ ์์น๋ฅผ ๊ธฐ์ค์ผ๋ก ./mnist
๋๋ ํ ๋ฆฌ์ ์์นํด์ผ ํฉ๋๋ค. MNIST ๋ฐ์ดํฐ์
์ ์ด ์คํฌ๋ฆฝํธ ๋ฅผ
์ฌ์ฉํด ๋ค์ด๋ก๋ํ ์ ์์ต๋๋ค.
๋ค์์ผ๋ก, ๋ฐ์ดํฐ ๋ก๋๋ฅผ ๋ง๋ค๊ณ ์ด ๋ฐ์ดํฐ์
์ ์ ๋ฌํฉ๋๋ค. ์๋ก์ด ๋ฐ์ดํฐ
๋ก๋๋ฅผ ๋ง๋ค๊ธฐ ์ํด torch::data::make_data_loader
๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์ด ๋ก๋๋ ์ฌ๋ฐ๋ฅธ ํ์
(๋ฐ์ดํฐ์
ํ์
, ์ํ๋ฌ ํ์
๋ฐ ๊ธฐํ ๊ตฌํ ์ธ๋ถ์ฌํญ์
๋ฐ๋ผ ๊ฒฐ์ ๋จ)์ std::unique_ptr
๋ฅผ ๋ฐํํฉ๋๋ค.
auto data_loader = torch::data::make_data_loader(std::move(dataset));
๋ฐ์ดํฐ ๋ก๋์๋ ๋ง์ ์ต์
์ด ์ ๊ณต๋ฉ๋๋ค. ์ ์ฒด ๋ชฉ๋ก์ ์ฌ๊ธฐ
์์ ํ์ธํ ์ ์์ต๋๋ค.
์๋ฅผ ๋ค์ด ๋ฐ์ดํฐ ๋ก๋ฉ ์๋๋ฅผ ๋์ด๊ธฐ ์ํด ์์
์ ์๋ฅผ ๋๋ฆด ์
์์ต๋๋ค. ๊ธฐ๋ณธ๊ฐ์ 0์ด๋ฉฐ, ์ด๋ ์ฃผ ์ฐ๋ ๋๊ฐ ์ฌ์ฉ๋จ์ ์๋ฏธํฉ๋๋ค.
workers
๋ฅผ 2
๋ก ์ค์ ํ๋ฉด ๋ฐ์ดํฐ๋ฅผ ๋์์ ๋ก๋ํ๋ ์ฐ๋ ๋๊ฐ
๋ ๊ฐ ์์ฑ๋ฉ๋๋ค. ๋ํ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๊ธฐ๋ณธ๊ฐ 1
์์ 64
(kBatchSize
๊ฐ)
์ ๊ฐ์ด ๋ ์ ๋นํ ๊ฐ์ผ๋ก ๋๋ ค์ผ ํฉ๋๋ค. ๊ทธ๋ฌ๋ฉด
DataLoaderOptions
๊ฐ์ฒด๋ฅผ ๋ง๋ค์ด ์ ์ ํ ์์ฑ์ ์ค์ ํด ๋ณด๊ฒ ์ต๋๋ค.
auto data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
์ด์ ๋ฐ์ดํฐ ๋ฐฐ์น๋ฅผ ๋ก๋ํ๋ ๋ฃจํ๋ฅผ ์์ฑํ ์ ์์ต๋๋ค. ์ง๊ธ์ ์ฝ์์๋ง ์ถ๋ ฅํ ๊ฒ์ ๋๋ค.
for (torch::data::Example<>& batch : *data_loader) {
std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
for (int64_t i = 0; i < batch.data.size(0); ++i) {
std::cout << batch.target[i].item<int64_t>() << " ";
}
std::cout << std::endl;
}
์ด ๊ฒฝ์ฐ ๋ฐ์ดํฐ ๋ก๋๊ฐ ๋ฐํํ๋ ํ์
์ torch::data::Example
์
๋๋ค.
์ด ํ์
์ ๋ฐ์ดํฐ๋ฅผ ์ํ data
ํ๋์ ๋ ์ด๋ธ์ ์ํ target
ํ๋๊ฐ
์๋ ๊ฐ๋จํ struct์
๋๋ค. ์์ Stack
collation์ ์ ์ฉํ๊ธฐ ๋๋ฌธ์,
๋ฐ์ดํฐ ๋ก๋๋ ์ด example์ ํ๋๋ง ๋ฐํํฉ๋๋ค. ๋ฐ์ดํฐ ๋ก๋์ collation์
์ ์ฉํ์ง ์์ผ๋ฉด, std::vector<torch::data::Example<>>
๋ฅผ yieldํ๋ฉฐ,
๊ฐ ๋ฐฐ์น์ example์๋ ํ๋์ element๊ฐ ์์ ๊ฒ์
๋๋ค.
์ด ์ฝ๋๋ฅผ ๋ค์ ๋น๋ํ๊ณ ์คํํ๋ฉด ๋๋ต ๋ค์๊ณผ ๊ฐ์ ๋ด์ฉ์ ์ป์ ๊ฒ์ ๋๋ค.
root@fa350df05ecf:/home/build# make
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
root@fa350df05ecf:/home/build# make
[100%] Built target dcgan
root@fa350df05ecf:/home/build# ./dcgan
Batch size: 64 | Labels: 5 2 6 7 2 1 6 7 0 1 6 2 3 6 9 1 8 4 0 6 5 3 3 0 4 6 6 6 4 0 8 6 0 6 9 2 4 0 2 8 6 3 3 2 9 2 0 1 4 2 3 4 8 2 9 9 3 5 8 0 0 7 9 9
Batch size: 64 | Labels: 2 2 4 7 1 2 8 8 6 9 0 2 2 9 3 6 1 3 8 0 4 4 8 8 8 9 2 6 4 7 1 5 0 9 7 5 4 3 5 4 1 2 8 0 7 1 9 6 1 6 5 3 4 4 1 2 3 2 3 5 0 1 6 2
Batch size: 64 | Labels: 4 5 4 2 1 4 8 3 8 3 6 1 5 4 3 6 2 2 5 1 3 1 5 0 8 2 1 5 3 2 4 4 5 9 7 2 8 9 2 0 6 7 4 3 8 3 5 8 8 3 0 5 8 0 8 7 8 5 5 6 1 7 8 0
Batch size: 64 | Labels: 3 3 7 1 4 1 6 1 0 3 6 4 0 2 5 4 0 4 2 8 1 9 6 5 1 6 3 2 8 9 2 3 8 7 4 5 9 6 0 8 3 0 0 6 4 8 2 5 4 1 8 3 7 8 0 0 8 9 6 7 2 1 4 7
Batch size: 64 | Labels: 3 0 5 5 9 8 3 9 8 9 5 9 5 0 4 1 2 7 7 2 0 0 5 4 8 7 7 6 1 0 7 9 3 0 6 3 2 6 2 7 6 3 3 4 0 5 8 8 9 1 9 2 1 9 4 4 9 2 4 6 2 9 4 0
Batch size: 64 | Labels: 9 6 7 5 3 5 9 0 8 6 6 7 8 2 1 9 8 8 1 1 8 2 0 7 1 4 1 6 7 5 1 7 7 4 0 3 2 9 0 6 6 3 4 4 8 1 2 8 6 9 2 0 3 1 2 8 5 6 4 8 5 8 6 2
Batch size: 64 | Labels: 9 3 0 3 6 5 1 8 6 0 1 9 9 1 6 1 7 7 4 4 4 7 8 8 6 7 8 2 6 0 4 6 8 2 5 3 9 8 4 0 9 9 3 7 0 5 8 2 4 5 6 2 8 2 5 3 7 1 9 1 8 2 2 7
Batch size: 64 | Labels: 9 1 9 2 7 2 6 0 8 6 8 7 7 4 8 6 1 1 6 8 5 7 9 1 3 2 0 5 1 7 3 1 6 1 0 8 6 0 8 1 0 5 4 9 3 8 5 8 4 8 0 1 2 6 2 4 2 7 7 3 7 4 5 3
Batch size: 64 | Labels: 8 8 3 1 8 6 4 2 9 5 8 0 2 8 6 6 7 0 9 8 3 8 7 1 6 6 2 7 7 4 5 5 2 1 7 9 5 4 9 1 0 3 1 9 3 9 8 8 5 3 7 5 3 6 8 9 4 2 0 1 2 5 4 7
Batch size: 64 | Labels: 9 2 7 0 8 4 4 2 7 5 0 0 6 2 0 5 9 5 9 8 8 9 3 5 7 5 4 7 3 0 5 7 6 5 7 1 6 2 8 7 6 3 2 6 5 6 1 2 7 7 0 0 5 9 0 0 9 1 7 8 3 2 9 4
Batch size: 64 | Labels: 7 6 5 7 7 5 2 2 4 9 9 4 8 7 4 8 9 4 5 7 1 2 6 9 8 5 1 2 3 6 7 8 1 1 3 9 8 7 9 5 0 8 5 1 8 7 2 6 5 1 2 0 9 7 4 0 9 0 4 6 0 0 8 6
...
์ฆ, MNIST ๋ฐ์ดํฐ์ ์์ ๋ฐ์ดํฐ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ๋ก๋ํ ์ ์์ต๋๋ค.
์ด์ ์์ ์ ์๊ณ ๋ฆฌ์ฆ ๋ถ๋ถ์ ๋ง๋ฌด๋ฆฌํ๊ณ ์์ฑ๊ธฐ์ ํ๋ณ๊ธฐ ์ฌ์ด์์ ์ผ์ด๋๋ ์ฌ์ธํ ์์ฉ์ ๊ตฌํํด ๋ณด๊ฒ ์ต๋๋ค. ๋จผ์ ์์ฑ๊ธฐ์ ํ๋ณ๊ธฐ ๊ฐ๊ฐ์ ์ํด ์ด ๋ ๊ฐ์ optimizer๋ฅผ ์์ฑํ๊ฒ ์ต๋๋ค. ์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ๋ optimizer๋ Adam ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํฉ๋๋ค.
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.5)));
Note
์ด ๊ธ ์์ฑ ๋น์, C++ ํ๋ก ํธ์๋๊ฐ Adagrad, Adam, LBFGS, RMSprop ๋ฐ SGD๋ฅผ ๊ตฌํํ๋ ์ตํฐ๋ง์ด์ ๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ต์ ๋ฆฌ์คํธ๋ docs ์ ์์ต๋๋ค.
๋ค์์ผ๋ก, ์ฐ๋ฆฌ์ ํ์ต ๋ฃจํ๋ฅผ ์์ ํด์ผ ํฉ๋๋ค. ๋งค ์ํญ๋ง๋ค ๋ฐ์ดํฐ ๋ก๋๋ฅผ ๋ฐ๋ณต ์คํํ๋ ๋ฐ๊นฅ ๋ฃจํ๋ฅผ ์ถ๊ฐํด ๋ค์์ GAN ํ์ต ์ฝ๋๋ฅผ ์์ฑํฉ๋๋ค.
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data;
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
torch::Tensor real_output = discriminator->forward(real_images);
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
d_loss_real.backward();
// Train discriminator with fake images.
torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
torch::Tensor fake_images = generator->forward(noise);
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
d_loss_fake.backward();
torch::Tensor d_loss = d_loss_real + d_loss_fake;
discriminator_optimizer.step();
// Train generator.
generator->zero_grad();
fake_labels.fill_(1);
fake_output = discriminator->forward(fake_images);
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
g_loss.backward();
generator_optimizer.step();
std::printf(
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
epoch,
kNumberOfEpochs,
++batch_index,
batches_per_epoch,
d_loss.item<float>(),
g_loss.item<float>());
}
}
์ ์ฝ๋๋ ๋จผ์ ์ง์ง (real) ์ด๋ฏธ์ง์ ๋ํด ํ๋ณ๊ธฐ๋ฅผ ํ๊ฐํ๋๋ฐ, ์ด ๋
ํ๋ณ๊ธฐ๋ ๋์ ํ๋ฅ ์ ์ถ๋ ฅํด์ผ ํฉ๋๋ค. ์ด๋ฅผ ์ํด
torch::empty(batch.data.size(0)).uniform_(0.8, 1.0)
๋ฅผ ๋ชฉํ ํ๋ฅ
๊ฐ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
Note
ํ๋ณ๊ธฐ๋ฅผ ๋ณด๋ค ๊ฒฌ๊ณ ํ๊ฒ ํ์ตํ๊ธฐ ์ํด ๋ชจ๋ ๊ณณ์์ 1.0์ด ์๋ 0.8๊ณผ 1.0 ์ฌ์ด์ ๊ท ์ผ ๋ถํฌ์์ ์์์ ๊ฐ์ ์ ํํฉ๋๋ค. ์ด ํธ๋ฆญ์ label smoothing ์ด๋ผ๊ณ ํฉ๋๋ค.
ํ๋ณ๊ธฐ๋ฅผ ํ๊ฐํ๊ธฐ์ ์์ ๋งค๊ฐ๋ณ์์ ๊ทธ๋๋์ธํธ๋ฅผ 0์ผ๋ก ๋ง๋ญ๋๋ค.
์์ค์ ๊ณ์ฐํ ํ d_loss.backward()
๋ฅผ ํธ์ถํด ์ด๋ฅผ
๋คํธ์ํฌ์ ์ญ์ ํํฉ๋๋ค. ๊ฐ์ง (fake) ์ด๋ฏธ์ง๋ค์ ๋ํด์ ์ด ๊ณผ์ ์
๋ฐ๋ณตํฉ๋๋ค. ๋ฐ์ดํฐ์
์ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํ๋ ๋์ , ์์ฑ์์
๋ฌด์์ ๋
ธ์ด์ฆ๋ฅผ ์
๋ ฅํ์ฌ ์ฌ๊ธฐ์ ์ฌ์ฉํ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๋ง๋ญ๋๋ค.
๊ทธ๋ฆฌ๊ณ ๊ทธ ๊ฐ์ง ์ด๋ฏธ์ง๋ค์ ํ๋ณ๊ธฐ์ ์ ๋ฌํฉ๋๋ค. ์ด๋ฒ์๋
ํ๋ณ๊ธฐ๊ฐ ๋ฎ์ ํ๋ฅ , ์ด์์ ์ผ๋ก๋ ๋ชจ๋ 0์ ์ถ๋ ฅํ๊ธฐ๋ฅผ ๋ฐ๋๋๋ค.
์ง์ง ์ด๋ฏธ์ง์ ๊ฐ์ง ์ด๋ฏธ์ง ๋ฐฐ์น ๋ชจ๋์ ๋ํ ํ๋ณ๊ธฐ ์์ค์ ๊ณ์ฐํ
ํ์๋, ํ๋ณ๊ธฐ์ optimizer ๋งค๊ฐ๋ณ์ ์
๋ฐ์ดํธ๋ฅผ ํ ๋จ๊ณ์ฉ
์งํํ ์ ์์ต๋๋ค.
์์ฑ๊ธฐ๋ฅผ ํ์ต์ํค๊ธฐ ์ํด ์ฐ์ ๊ทธ๋๋์ธํธ๋ฅผ ๋ค์ ํ๋ฒ 0์ผ๋ก ์ค์ ํ๊ณ
๋ค์ ๊ฐ์ง ์ด๋ฏธ์ง๋ก ํ๋ณ๊ธฐ๋ฅผ ํ๊ฐํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฒ์๋ ํ๋ณ๊ธฐ๊ฐ
ํ๋ฅ 1์ ๋งค์ฐ ๊ทผ์ ํ๊ฒ ์ถ๋ ฅํ๊ฒ ํ์ฌ, ์์ฑ๊ธฐ๊ฐ ํ๋ณ๊ธฐ๋ฅผ
์์ฌ ์ค์ (๋ฐ์ดํฐ์
์ ์๋) ์ง์ง๋ผ๊ณ ์๊ฐํ๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์
์๋๋ก ํ๋ ค ํฉ๋๋ค. ์ด๋ฅผ ์ํด fake_labels
ํ
์๋ฅผ ๋ชจ๋
1๋ก ์ฑ์ฐ๊ฒ ์ต๋๋ค. ๋ง์ง๋ง์ผ๋ก ๋งค๊ฐ๋ณ์๋ฅผ ์
๋ฐ์ดํธํ๊ธฐ ์ํด
์์ฑ๊ธฐ์ optimzier ๋งค๊ฐ๋ณ์ ์
๋ฐ์ดํธ๋ฅผ ์งํํฉ๋๋ค.
์ด์ CPU๋ก ๋ชจ๋ธ์ ํ์ต์ํฌ ์ค๋น๊ฐ ๋์์ต๋๋ค. ์ํ๋ ์ํ ์ถ๋ ฅ์ ์บก์ฒํ ์ ์๋ ์ฝ๋๋ ์์ง ์์ง๋ง ์ ์ ํ์ ์ถ๊ฐํ๊ฒ ์ต๋๋ค. ์ง๊ธ์ ๋ชจ๋ธ์ด ๋ฌด์ธ๊ฐ ๋ฅผ ์ํํ๊ณ ์๋ค๋ ๊ฒ๋ง์ ๊ด์ฐฐํ๊ณ , ๋์ค์๋ ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ด ๋ฌด์ธ๊ฐ๊ฐ ์๋ฏธ ์๋์ง ์ฌ๋ถ๋ฅผ ํ์ธํ ๊ฒ์ ๋๋ค. ๋ค์ ๋น๋ํ๊ณ ์คํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ ๋ด์ฉ์ด ์ถ๋ ฅ๋ผ์ผ ํฉ๋๋ค.
root@3c0711f20896:/home/build# make && ./dcgan
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcga
[ 1/10][100/938] D_loss: 0.6876 | G_loss: 4.1304
[ 1/10][200/938] D_loss: 0.3776 | G_loss: 4.3101
[ 1/10][300/938] D_loss: 0.3652 | G_loss: 4.6626
[ 1/10][400/938] D_loss: 0.8057 | G_loss: 2.2795
[ 1/10][500/938] D_loss: 0.3531 | G_loss: 4.4452
[ 1/10][600/938] D_loss: 0.3501 | G_loss: 5.0811
[ 1/10][700/938] D_loss: 0.3581 | G_loss: 4.5623
[ 1/10][800/938] D_loss: 0.6423 | G_loss: 1.7385
[ 1/10][900/938] D_loss: 0.3592 | G_loss: 4.7333
[ 2/10][100/938] D_loss: 0.4660 | G_loss: 2.5242
[ 2/10][200/938] D_loss: 0.6364 | G_loss: 2.0886
[ 2/10][300/938] D_loss: 0.3717 | G_loss: 3.8103
[ 2/10][400/938] D_loss: 1.0201 | G_loss: 1.3544
[ 2/10][500/938] D_loss: 0.4522 | G_loss: 2.6545
...
์ด ์คํฌ๋ฆฝํธ๋ CPU์์ ์ ๋์ํ์ง๋ง, ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ด GPU์์ ํจ์ฌ ๋น ๋ฅด๋ค๋
๊ฒ์ ์ ์๋ ค์ง ์ฌ์ค์
๋๋ค. ์ด๋ป๊ฒ ํ์ต์ GPU๋ก ์ฎ๊ธธ ์ ์์ ์ง์ ๋ํด ๋น ๋ฅด๊ฒ ๋
ผ์ํด
๋ณด๊ฒ ์ต๋๋ค. ์ด๋ฅผ ์ํด ํด์ผ ํ ์ผ ๋ ๊ฐ์ง๋ก GPU ์ฅ์น(device) ์ฌ์์ ์ฐ๋ฆฌ๊ฐ ์ง์ ํ ๋นํ
ํ
์์ ์ ๋ฌํ๋ ๊ฒ๊ณผ, C++ ํ๋ก ํธ์๋์ ๋ชจ๋ ํ
์์ ๋ชจ๋์ด ๊ฐ๊ณ ์๋ to()
๋ฉ์๋๋ฅผ ์ฌ์ฉํด ๋ค๋ฅธ ๋ชจ๋ ํ
์๋ฅผ GPU์ ๋ช
์์ ์ผ๋ก ๋ณต์ฌํ๋ ๊ฒ์ด ์์ต๋๋ค.
๋ ๊ฐ์ง๋ฅผ ๋ชจ๋ ๋ฌ์ฑํ๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ผ๋ก ํ์ต ์คํฌ๋ฆฝํธ ์ต์์์
torch::Device
์ธ์คํด์ค๋ฅผ ๋ง๋ค์ด torch::zeros
์ ๊ฐ์
ํ
์ ํฉํ ๋ฆฌ ํจ์๋ to()
๋ฉ์๋์ ์ ๋ฌํ ์ ์์ต๋๋ค. ๋จผ์ CPU device๋ก
์ด๋ฅผ ๊ตฌํํด๋ณด๊ฒ ์ต๋๋ค.
// ํ์ต ์คํฌ๋ฆฝํธ ์ต์๋จ์ ์ด ์ฝ๋๋ฅผ ๋ฃ์ผ์ธ์.
torch::Device device(torch::kCPU);
์๋์ ๊ฐ์ ์๋ก์ด ํ ์ ํ ๋น์ ๊ฒฝ์ฐ,
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
๋ง์ง๋ง ์ธ์๋ก device
๋ฅผ ๋ฐ๋๋ก ์์ ํฉ๋๋ค.
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
MNIST ๋ฐ์ดํฐ์
์ ํ
์์ฒ๋ผ ์ฐ๋ฆฌ๊ฐ ์ง์ ์์ฑํ์ง ์๋ ํ
์์์๋
๋ช
์์ ์ผ๋ก to()
ํธ์ถ์ ์ฝ์
ํด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ ์๋ ์ฝ๋์ ๊ฒฝ์ฐ,
torch::Tensor real_images = batch.data;
๋ค์๊ณผ ๊ฐ์ด ๋ณํฉ๋๋ค.
torch::Tensor real_images = batch.data.to(device);
๋ํ, ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ๋ฐ๋ฅธ ์ฅ์น๋ก ์ฎ๊ฒจ์ผ ํฉ๋๋ค.
generator->to(device);
discriminator->to(device);
Note
๋ง์ผ ํ
์๊ฐ ์ด๋ฏธ to()
์ ์ ๋ฌ๋ ์ฅ์น ์์ ์๋ค๋ฉด ๊ทธ ํธ์ถ์ ์๋ฌด ์ผ๋ ํ์ง ์์ต๋๋ค. ์ฌ๋ณธ์ด ์์ฑ๋์ง๋ ์์ต๋๋ค.
์ด์ CPU์์ ์คํ๋๋ ์ด์ ์ ์ฝ๋๊ฐ ๋ณด๋ค ๋ช ์์ ์ผ๋ก ๋ฐ๋์์ต๋๋ค. ํ์ง๋ง ์ด์ ๋ ์ฅ์น๋ฅผ CUDA ์ฅ์น๋ก ๋ณ๊ฒฝํ๋ ๊ฒ ๋ํ ๋งค์ฐ ์ฝ์ต๋๋ค.
torch::Device device(torch::kCUDA)
์ด์ ๋ชจ๋ ํ
์๊ฐ GPU์ ์กด์ฌํ๋ฉฐ ์ด๋ ํ ๋ค์ด์คํธ๋ฆผ ์ฝ๋ ๋ณ๊ฒฝ ์์ด๋
๋ชจ๋ ์ฐ์ฐ์ ์ํด ๋น ๋ฅธ CUDA ์ปค๋์ ํธ์ถํฉ๋๋ค. ํน์ ์ธ๋ฑ์ค์ ์ฅ์น๋ฅผ
์ง์ ํ๋ ค๋ฉด Device
์์ฑ์์ ๋ ๋ฒ์งธ ์ธ์๋ก ์ ๋ฌํ๋ฉด ๋ฉ๋๋ค.
์๋ก ๋ค๋ฅธ ์ฅ์น์ ์๋ก ๋ค๋ฅธ ํ
์๊ฐ ์กด์ฌํ๊ธฐ๋ฅผ ์ํ๋ ๊ฒฝ์ฐ,
๋ณ๋์ ์ฅ์น ์ธ์คํด์ค(์: CUDA ์ฅ์น 0๊ณผ ๋ค๋ฅธ CUDA ์ฅ์น 1)๋ฅผ
์ ๋ฌํ ์๋ ์์ต๋๋ค. ๋ฟ๋ง ์๋๋ผ, ์ด๋ฌํ ์ค์ ์ ๋์ ์ผ๋ก ์ํํ ์๋
์์ด ๋ค์๊ณผ ๊ฐ์ด ํ์ต ์คํฌ๋ฆฝํธ์ ํด๋์ฑ์ ๋์ด๋ ๋ฐ ์ข
์ข
์ ์ฉํ๊ฒ ์ฌ์ฉ๋ฉ๋๋ค.
torch::Device device = torch::kCPU;
if (torch::cuda::is_available()) {
std::cout << "CUDA is available! Training on GPU." << std::endl;
device = torch::kCUDA;
}
๋์๊ฐ ์๋์ ๊ฐ์ ์ฝ๋๋ ๊ฐ๋ฅํฉ๋๋ค.
torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
๋ง์ง๋ง์ผ๋ก ํ์ต ์คํฌ๋ฆฝํธ์ ์ถ๊ฐํด์ผ ํ ๋ด์ฉ์ ๋ชจ๋ธ ๋งค๊ฐ๋ณ์ ๋ฐ ์ตํฐ๋ง์ด์ ์ ์ํ, ๊ทธ๋ฆฌ๊ณ ์์ฑ๋ ๋ช ๊ฐ์ ์ด๋ฏธ์ง ์ํ์ ์ฃผ๊ธฐ์ ์ผ๋ก ์ ์ฅํ๋ ๊ฒ์ ๋๋ค. ํ์ต ๊ณผ์ ๋์ค์ ์ปดํจํฐ๊ฐ ๋ค์ด๋๋ฉด ์ด๋ ๊ฒ ์ ์ฅ๋ ์ํ๋ก๋ถํฐ ํ์ต ์ํ๋ฅผ ๋ณต์ํ ์ ์์ต๋๋ค. ์ด๋ ์ฅ์๊ฐ ์ง์๋๋ ํ์ต์ ์ํด ํ์๋ก ์๊ตฌ๋ฉ๋๋ค. ๋คํํ๋ C++ ํ๋ก ํธ์๋๋ ๊ฐ๋ณ ํ ์๋ฟ๋ง ์๋๋ผ ๋ชจ๋ธ ๋ฐ ์ตํฐ๋ง์ด์ ์ํ๋ฅผ ์ง๋ ฌํํ๊ณ ์ญ์ง๋ ฌํํ ์ ์๋ API๋ฅผ ์ ๊ณตํฉ๋๋ค.
์ด๋ฅผ ์ํ ํต์ฌ API๋ torch::save(thing,filename)
์
torch::load(thing,filename)
๋ก, ์ฌ๊ธฐ์ thing
์
torch::nn::Module
์ ํ์ ํด๋์ค ํน์ ์ฐ๋ฆฌ์ ํ์ต ์คํฌ๋ฆฝํธ์ Adam
๊ฐ์ฒด์ ๊ฐ์ ์ตํฐ๋ง์ด์ ์ธ์คํด์ค๊ฐ ๋ ์ ์์ต๋๋ค. ๋ชจ๋ธ ๋ฐ ์ตํฐ๋ง์ด์ ์ํ๋ฅผ
ํน์ ์ฃผ๊ธฐ๋ง๋ค ์ ์ฅํ๋๋ก ํ์ต ๋ฃจํ๋ฅผ ์์ ํด๋ณด๊ฒ ์ต๋๋ค.
if (batch_index % kCheckpointEvery == 0) {
// ๋ชจ๋ธ ๋ฐ ์ตํฐ๋ง์ด์ ์ํ๋ฅผ ์ ์ฅํฉ๋๋ค.
torch::save(generator, "generator-checkpoint.pt");
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::save(discriminator, "discriminator-checkpoint.pt");
torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
// ์์ฑ๊ธฐ๋ฅผ ์ํ๋งํ๊ณ ์ด๋ฏธ์ง๋ฅผ ์ ์ฅํฉ๋๋ค.
torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}, device));
torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
์ฌ๊ธฐ์ 100
๋ฐฐ์น๋ง๋ค ์ํ๋ฅผ ์ ์ฅํ๋ ค๋ฉด kCheckpointEvery
๋ฅผ 100
๊ณผ ๊ฐ์ ์ ์๋ก ์ค์ ํ ์ ์์ผ๋ฉฐ, checkpoint_counter
๋ ์ํ๋ฅผ ์ ์ฅํ ๋๋ง๋ค
์ฆ๊ฐํ๋ ์นด์ดํฐ์
๋๋ค.
ํ์ต ์ํ๋ฅผ ๋ณต์ํ๊ธฐ ์ํด ๋ชจ๋ธ ๋ฐ ์ตํฐ๋ง์ด์ ๋ฅผ ๋ชจ๋ ์์ฑํ ํ ํ์ต ๋ฃจํ ์์ ๋ค์ ์ฝ๋๋ฅผ ์ถ๊ฐํ ์ ์์ต๋๋ค.
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
if (kRestoreFromCheckpoint) {
torch::load(generator, "generator-checkpoint.pt");
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::load(discriminator, "discriminator-checkpoint.pt");
torch::load(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}
int64_t checkpoint_counter = 0;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
ํ์ต ์คํฌ๋ฆฝํธ๊ฐ ์์ฑ๋์ด CPU์์๋ GPU์์๋ GAN์ ํ๋ จ์ํฌ ์ค๋น๊ฐ
๋์ต๋๋ค. ํ์ต ๊ณผ์ ์ ์ค๊ฐ ์ถ๋ ฅ์ ๊ฒ์ฌํ๊ธฐ ์ํด
"dcgan-sample-xxx.pt"
์ ์ฃผ๊ธฐ์ ์ผ๋ก ์ด๋ฏธ์ง ์ํ์ ์ ์ฅํ๋ ์ฝ๋๋ฅผ
์ถ๊ฐํ์ผ๋, ํ
์๋ค์ ๋ถ๋ฌ์ matplotlib๋ก ์๊ฐํํ๋ ๊ฐ๋จํ ํ์ด์ฌ
์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํด๋ณด๊ฒ ์ต๋๋ค.
import argparse
import matplotlib.pyplot as plt
import torch
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()
module = torch.jit.load(options.sample_file)
images = list(module.parameters())[0]
for index in range(options.dimension * options.dimension):
image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
array = image.numpy()
axis = plt.subplot(options.dimension, options.dimension, 1 + index)
plt.imshow(array, cmap="gray")
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)
plt.savefig(options.out_file)
print("Saved ", options.out_file)
์ด์ ๋ชจ๋ธ์ ์ฝ 30 ์ํญ ์ ๋ ํ์ต์ํต์๋ค.
root@3c0711f20896:/home/build# make && ./dcgan 10:17:57
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
CUDA is available! Training on GPU.
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
-> checkpoint 1
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
-> checkpoint 2
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.36760
-> checkpoint 3
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
-> checkpoint 4
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
-> checkpoint 5
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
...
-> checkpoint 120
[30/30][938/938] D_loss: 0.3610 | G_loss: 3.8084
๊ทธ๋ฆฌ๊ณ ์ด๋ฏธ์ง๋ค์ ํ๋กฏ์ ์๊ฐํํฉ๋๋ค.
root@3c0711f20896:/home/build# python display.py -i dcgan-sample-100.pt
Saved out.png
๊ทธ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ ๊ฒ์ ๋๋ค.
์ซ์๋ค์! ๋ง์ธ! ์ด์ ์ฌ๋ฌ๋ถ ์ฐจ๋ก์ ๋๋ค. ์ซ์๊ฐ ๋ณด๋ค ๋์ ๋ณด์ด๋๋ก ๋ชจ๋ธ์ ๊ฐ์ ํ ์ ์๋์?
์ด ํํ ๋ฆฌ์ผ์ ํตํด PyTorch C++ ํ๋ก ํธ์๋์ ๋ํ ์ด๋ ์ ๋ ์ดํด๋๊ฐ ์๊ธฐ์ จ๊ธฐ ๋ฐ๋๋๋ค. ํ์ฐ์ ์ผ๋ก PyTorch ๊ฐ์ ๋จธ์ ๋ฌ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ๋งค์ฐ ๋ค์ํ๊ณ ๊ด๋ฒ์ํ API๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. ๋ฐ๋ผ์, ์ฌ๊ธฐ์ ๋ ผ์ํ๊ธฐ์ ์๊ฐ๊ณผ ๊ณต๊ฐ์ด ๋ถ์กฑํ๋ ๊ฐ๋ ๋ค์ด ๋ง์ต๋๋ค. ๊ทธ๋ฌ๋ ์ง์ API๋ฅผ ์ฌ์ฉํด๋ณด๊ณ , ๋ฌธ์, ๊ทธ ์ค์์๋ ํนํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ API ์น์ ์ ์ฐธ์กฐํด๋ณด๋ ๊ฒ์ ๊ถ์ฅ๋๋ฆฝ๋๋ค. ๋ํ, C++ ํ๋ก ํธ์๋๊ฐ ํ์ด์ฌ ํ๋ก ํธ์๋์ ๋์์ธ๊ณผ ์๋งจํฑ์ ๋ฐ๋ฅธ๋ค๋ ์ฌ์ค์ ์ ๊ธฐ์ตํ๋ฉด ๋ณด๋ค ๋น ๋ฅด๊ฒ ํ์ตํ ์ ์์ ๊ฒ์ ๋๋ค.
Tip
๋ณธ ํํ ๋ฆฌ์ผ์ ๋ํ ์ ์ฒด ์์ค์ฝ๋๋ ์ด ์ ์ฅ์ ์ ์ ๊ณต๋์ด ์์ต๋๋ค.
์ธ์ ๋ ๊ทธ๋ ๋ฏ์ด ์ด๋ค ๋ฌธ์ ๊ฐ ์๊ธฐ๊ฑฐ๋ ์ง๋ฌธ์ด ์์ผ๋ฉด ์ ํฌ ํฌ๋ผ ์ ์ด์ฉํ๊ฑฐ๋ Github ์ด์ ๋ก ์ฐ๋ฝ์ฃผ์ธ์.