Using the PyTorch C++ Frontend

Translated by: 유용환

The PyTorch C++ frontend is a pure C++ interface to the PyTorch machine learning framework. While the primary interface to PyTorch is naturally Python, the Python API sits atop a substantial C++ codebase that provides foundational data structures and functionality such as tensors and automatic differentiation. The C++ frontend exposes a pure C++11 API that extends this underlying C++ codebase with the tools required for machine learning training and inference. This includes a built-in collection of common components for neural network modeling, an API to extend this collection with custom modules, a library of popular optimization algorithms such as stochastic gradient descent, a parallel data loader with an API to define and load datasets, serialization routines and more.

This tutorial walks you through an end-to-end example of training a model with the C++ frontend. Concretely, we will train a DCGAN, a kind of generative model, to generate images of MNIST digits. While conceptually a simple example, it should be enough to give you a quick overview of the PyTorch C++ frontend and whet your appetite for training more complex models. We will begin with some motivating words on why you would want to use the C++ frontend in the first place, and then dive straight into defining and training our model.

Tip

Watch the CppCon 2018 lightning talk for a short and entertaining presentation on the C++ frontend.

Tip

This note provides a sweeping overview of the C++ frontend's components and design philosophy.

Tip

Documentation for the PyTorch C++ ecosystem is available at https://pytorch.org/cppdocs. There you can find high-level descriptions as well as API-level documentation.

Motivation

Before we embark on our exciting journey of GANs and MNIST digits, let's take a step back and discuss why you would want to use the C++ frontend instead of the Python one. We (the PyTorch team) created the C++ frontend to enable research in environments in which Python cannot be used, or is simply not the right tool for the job. Examples of such environments include:

  • Low Latency Systems: You may want to do reinforcement learning research in a pure C++ game engine with high frames-per-second and low latency requirements. A pure C++ library is a much better fit for such an environment than a Python library. Python may not be tractable at all because of the slowness of its interpreter.
  • Highly Multithreaded Environments: Due to the Global Interpreter Lock (GIL), Python cannot run Python code on more than one system thread at a time. Multiprocessing is an alternative, but it is less scalable and has significant shortcomings. C++ has no such constraint, and threads are easy to create and use. Models requiring heavy parallelization, like those used in Deep Neuroevolution, can benefit from this.
  • Existing C++ Codebases: You may be the owner of an existing C++ application doing anything from serving web pages in a backend server to rendering 3D graphics in photo editing software, and wish to integrate machine learning methods into your system. The C++ frontend allows you to remain in C++ and spare yourself the hassle of binding back and forth between Python and C++, while retaining much of the flexibility and intuitiveness of the traditional PyTorch (Python) experience.

The C++ frontend is not intended to compete with the Python frontend. It is meant to complement it. Researchers and engineers alike love PyTorch for its simplicity, flexibility and intuitive API. Our goal is to make sure you can take advantage of these core design principles in every possible environment, including the ones described above. If one of these scenarios describes your use case, or if you are simply interested or curious, follow along as we explore the C++ frontend in detail below.

Tip

The C++ frontend tries to provide an API as close as possible to that of the Python frontend. If you are experienced with the Python frontend and ever ask yourself "how do I do X with the C++ frontend?", write your code the way you would in Python. More often than not, the same functions and methods are available in C++ as in Python (just remember to replace dots with double colons).

Writing a Basic Application

Let's begin by writing a minimal C++ application to verify that we are on the same page regarding our setup and build environment. First, you will need a copy of the LibTorch distribution, which packages all the relevant headers, libraries and CMake build files required to use the C++ frontend. The LibTorch distribution is available for download on the PyTorch website for Linux, macOS and Windows. The rest of this tutorial assumes a basic Ubuntu Linux environment, but you are free to follow along on macOS or Windows too.

Tip

The note on Installing C++ Distributions of PyTorch describes the following steps in more detail.

Tip

On Windows, debug and release builds are not ABI-compatible. If you plan to build your project in debug mode, please use the debug version of LibTorch. Also, make sure you specify the correct configuration in the cmake --build . line below.

The first step is to download the LibTorch distribution locally, via the link retrieved from the PyTorch website. For a vanilla Ubuntu Linux environment, this means running:

# If you need e.g. CUDA 9.0 support, please replace "cpu" with "cu90" in the URL below.
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

๋‹ค์Œ์œผ๋กœ torch/torch.h ๋ฅผ ํ˜ธ์ถœํ•˜๋Š” dcgan.cpp ๋ผ๋Š” ์ด๋ฆ„์˜ C++ ํŒŒ์ผ ํ•˜๋‚˜๋ฅผ ์ž‘์„ฑํ•ฉ์‹œ๋‹ค. ์šฐ์„ ์€ ์•„๋ž˜์™€ ๊ฐ™์ด 3x3 ํ•ญ๋“ฑ ํ–‰๋ ฌ์„ ์ถœ๋ ฅํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}

To build this tiny application, as well as our full-fledged training script later on, we will use this CMakeLists.txt file:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

find_package(Torch REQUIRED)

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)

Note

While CMake is the recommended build system for LibTorch, it is not a hard requirement. You can also use Visual Studio project files, QMake, plain Makefiles or any other build environment you feel comfortable with. However, we do not provide out-of-the-box support for these.

Note line 4 of the CMake file above: find_package(Torch REQUIRED). This instructs CMake to find the build configuration for the LibTorch library. For CMake to know where to find these files, we must set CMAKE_PREFIX_PATH when invoking cmake. Before we do this, let's agree on the following directory structure for our dcgan application:

dcgan/
  CMakeLists.txt
  dcgan.cpp

Further, I will refer to the path of the unzipped LibTorch distribution as /path/to/libtorch. Note that this must be an absolute path. In particular, setting CMAKE_PREFIX_PATH to something like ../../libtorch will break in unexpected ways. Instead, write $PWD/../../libtorch to get the corresponding absolute path. Now we are ready to build our application:

root@fa350df05ecf:/home# mkdir build
root@fa350df05ecf:/home# cd build
root@fa350df05ecf:/home/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /path/to/libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /home/build
root@fa350df05ecf:/home/build# cmake --build . --config Release
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan

Above, we first created a build folder inside our dcgan directory, entered this folder, ran the cmake command to generate the necessary build (Make) files and finally compiled the project successfully by running cmake --build . --config Release. We are now all set to execute our minimal binary and complete this section on basic project configuration:

root@fa350df05ecf:/home/build# ./dcgan
1  0  0
0  1  0
0  0  1
[ Variable[CPUFloatType]{3,3} ]

Looks like an identity matrix to me!

Defining the Neural Network Models

Now that we have our basic environment configured, we can dive into the much more interesting parts of this tutorial. First, we will discuss how to define and interact with modules in the C++ frontend. We will begin with basic, small-scale example modules and then implement a full-fledged GAN using the extensive library of built-in modules provided by the C++ frontend.

Module API Basics

In line with the Python interface, neural networks based on the C++ frontend are composed of reusable building blocks called modules. There is a base module class from which all other modules are derived. In Python, this class is torch.nn.Module, and in C++ it is torch::nn::Module. Besides a forward() method that implements the algorithm the module encapsulates, a module usually contains any of three kinds of sub-objects: parameters, buffers and submodules.

Parameters and buffers store state in the form of tensors. Parameters record gradients, while buffers do not. Parameters are usually the trainable weights of your neural network. Examples of buffers include the means and variances used for batch normalization. To re-use particular blocks of logic and state, the PyTorch API allows modules to be nested. A nested module is termed a submodule.

Parameters, buffers and submodules must be explicitly registered. Once registered, methods like parameters() or buffers() can be used to retrieve a container of all parameters or buffers in the entire (nested) module hierarchy. Similarly, methods like to(...) operate on the entire module hierarchy. For example, to(torch::kCUDA) moves all parameters and buffers from CPU to CUDA memory.
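The registration idea described above can be illustrated without LibTorch at all. The following is a minimal sketch, assuming a toy base class: ToyTensor, ToyModule, ToyLinear, ToyNet and parameter_names() are hypothetical names invented for this illustration and are not part of the LibTorch API. It only shows why explicit registration lets a base class walk a nested hierarchy.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor: just a flat list of values.
using ToyTensor = std::vector<float>;

// Hypothetical, heavily simplified module base class (NOT torch::nn::Module).
class ToyModule {
 public:
  virtual ~ToyModule() = default;

  // Remember the parameter under a name and hand a shared handle back.
  std::shared_ptr<ToyTensor> register_parameter(const std::string& name,
                                                ToyTensor value) {
    auto parameter = std::make_shared<ToyTensor>(std::move(value));
    parameters_.emplace_back(name, parameter);
    return parameter;
  }

  // Remember the submodule under a name and hand the same handle back.
  template <typename M>
  std::shared_ptr<M> register_module(const std::string& name,
                                     std::shared_ptr<M> module) {
    children_.emplace_back(name, module);
    return module;
  }

  // Own parameters first, then the parameters of every registered submodule,
  // with names prefixed by the submodule name.
  std::vector<std::string> parameter_names(const std::string& prefix = "") const {
    std::vector<std::string> names;
    for (const auto& entry : parameters_) names.push_back(prefix + entry.first);
    for (const auto& child : children_) {
      auto nested = child.second->parameter_names(prefix + child.first + ".");
      names.insert(names.end(), nested.begin(), nested.end());
    }
    return names;
  }

 private:
  std::vector<std::pair<std::string, std::shared_ptr<ToyTensor>>> parameters_;
  std::vector<std::pair<std::string, std::shared_ptr<ToyModule>>> children_;
};

struct ToyLinear : ToyModule {
  ToyLinear() {
    weight = register_parameter("weight", {0.1f, 0.2f});
    bias = register_parameter("bias", {0.0f});
  }
  std::shared_ptr<ToyTensor> weight, bias;
};

struct ToyNet : ToyModule {
  ToyNet() {
    another_bias = register_parameter("b", {1.0f});
    linear = register_module("linear", std::make_shared<ToyLinear>());
  }
  std::shared_ptr<ToyLinear> linear;
  std::shared_ptr<ToyTensor> another_bias;
};
```

Calling parameter_names() on a ToyNet visits the module's own parameter and then recurses into the registered submodule. A member that is never registered would be invisible to this traversal, which is the reason registration has to be explicit in C++.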

Defining a Module and Registering Parameters

To put these words into code, let's consider this simple module written with the Python interface:

import torch

class Net(torch.nn.Module):
  def __init__(self, N, M):
    super(Net, self).__init__()
    self.W = torch.nn.Parameter(torch.randn(N, M))
    self.b = torch.nn.Parameter(torch.randn(M))

  def forward(self, input):
    return torch.addmm(self.b, input, self.W)

In C++, it would look like this:

#include <torch/torch.h>

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M) {
    W = register_parameter("W", torch::randn({N, M}));
    b = register_parameter("b", torch::randn(M));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::addmm(b, input, W);
  }
  torch::Tensor W, b;
};

ํŒŒ์ด์ฌ์—์„œ์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ๋ชจ๋“ˆ ๊ธฐ๋ณธ ํด๋ž˜์Šค์—์„œ ํŒŒ์ƒํ•œ Net ์ด๋ผ๋Š” ํด๋ž˜์Šค๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. (์‰ฌ์šด ์„ค๋ช…์„ ์œ„ํ•ด class ๋Œ€์‹  struct ์„ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.) ํŒŒ์ด์ฌ์—์„œ torch.randn์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ ์ƒ์„ฑ์ž์—์„œ๋Š” torch::randn ์„ ์‚ฌ์šฉํ•ด ํ…์„œ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค. ํ•œ ๊ฐ€์ง€ ํฅ๋ฏธ๋กœ์šด ์ฐจ์ด์ ์€ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๋“ฑ๋กํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ํŒŒ์ด์ฌ์—์„œ๋Š” ํ…์„œ๋ฅผ torch.nn ์œผ๋กœ ๊ฐ์‹ธ๋Š” ๊ฒƒ๊ณผ ๋‹ฌ๋ฆฌ, C++์—์„œ๋Š” register_parameter ๋ฉ”์„œ๋“œ๋ฅผ ํ†ตํ•ด ํ…์„œ๋ฅผ ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ฐจ์ด์˜ ์›์ธ์€ ํŒŒ์ด์ฌ API์˜ ๊ฒฝ์šฐ, ์–ด๋–ค ์†์„ฑ(attirbute)์ด torch.nn.Parameter ํƒ€์ž…์ธ์ง€ ๊ฐ์ง€ํ•ด ๊ทธ๋Ÿฌํ•œ ํ…์„œ๋ฅผ ์ž๋™์œผ๋กœ ๋“ฑ๋กํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋‚˜ํƒ€๋‚ฉ๋‹ˆ๋‹ค. C++์—์„œ๋Š” ๋ฆฌํ”Œ๋ ‰์…˜(reflection)์ด ๋งค์šฐ ์ œํ•œ์ ์ด๋ฏ€๋กœ ๋ณด๋‹ค ์ „ํ†ต์ ์ธ (๊ทธ๋ฆฌํ•˜์—ฌ ๋œ ๋งˆ๋ฒ•์ ์ธ) ๋ฐฉ์‹์ด ์ œ๊ณต๋ฉ๋‹ˆ๋‹ค.

์„œ๋ธŒ๋ชจ๋“ˆ ๋“ฑ๋ก ๋ฐ ๋ชจ๋“ˆ ๊ณ„์ธต ๊ตฌ์กฐ ํƒ์ƒ‰

๋งค๊ฐœ๋ณ€์ˆ˜ ๋“ฑ๋ก๊ณผ ๋งˆ์ฐฌ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์œผ๋กœ ์„œ๋ธŒ๋ชจ๋“ˆ์„ ๋“ฑ๋กํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํŒŒ์ด์ฌ์—์„œ ์„œ๋ธŒ๋ชจ๋“ˆ์€ ์–ด๋–ค ๋ชจ๋“ˆ์˜ ์†์„ฑ์œผ๋กœ ์ง€์ •๋  ๋•Œ ์ž๋™์œผ๋กœ ๊ฐ์ง€๋˜๊ณ  ๋“ฑ๋ก๋ฉ๋‹ˆ๋‹ค.

class Net(torch.nn.Module):
  def __init__(self, N, M):
    super(Net, self).__init__()
    # Registered as a submodule behind the scenes
    self.linear = torch.nn.Linear(N, M)
    self.another_bias = torch.nn.Parameter(torch.rand(M))

  def forward(self, input):
    return self.linear(input) + self.another_bias

This allows us, for example, to use the parameters() method to recursively access all parameters in our module hierarchy:

>>> net = Net(4, 5)
>>> print(list(net.parameters()))
[Parameter containing:
tensor([0.0808, 0.8613, 0.2017, 0.5206, 0.5353], requires_grad=True), Parameter containing:
tensor([[-0.3740, -0.0976, -0.4786, -0.4928],
        [-0.1434,  0.4713,  0.1735, -0.3293],
        [-0.3467, -0.3858,  0.1980,  0.1986],
        [-0.1975,  0.4278, -0.1831, -0.2709],
        [ 0.3730,  0.4307,  0.3236, -0.0629]], requires_grad=True), Parameter containing:
tensor([ 0.2038,  0.4638, -0.2023,  0.1230, -0.0516], requires_grad=True)]

To register a module like torch::nn::Linear as a submodule in C++, use the aptly named register_module() method:

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
      : linear(register_module("linear", torch::nn::Linear(N, M))) {
    another_bias = register_parameter("b", torch::randn(M));
  }
  torch::Tensor forward(torch::Tensor input) {
    return linear(input) + another_bias;
  }
  torch::nn::Linear linear;
  torch::Tensor another_bias;
};

Tip

The documentation for the torch::nn namespace lists all available built-in modules, such as torch::nn::Linear, torch::nn::Dropout or torch::nn::Conv2d.

One subtlety in the code above is that the submodule was created in the constructor's initializer list, while the parameter was created inside the constructor body. There is a good reason for this, which we will touch upon in the section on the C++ frontend's ownership model further below. The end result, however, is that we can recursively access our module tree's parameters just like in Python. Calling parameters() returns a std::vector<torch::Tensor>, which we can iterate over:

int main() {
  Net net(4, 5);
  for (const auto& p : net.parameters()) {
    std::cout << p << std::endl;
  }
}

which prints:

root@fa350df05ecf:/home/build# ./dcgan
0.0345
1.4456
-0.6313
-0.3585
-0.4008
[ Variable[CPUFloatType]{5} ]
-0.1647  0.2891  0.0527 -0.0354
0.3084  0.2025  0.0343  0.1824
-0.4630 -0.2862  0.2500 -0.0420
0.3679 -0.1482 -0.0460  0.1967
0.2132 -0.1992  0.4257  0.0739
[ Variable[CPUFloatType]{5,4} ]
0.01 *
3.6861
-10.1166
-45.0333
7.9983
-20.0705
[ Variable[CPUFloatType]{5} ]

ํŒŒ์ด์ฌ์—์„œ์™€ ๊ฐ™์ด ์„ธ ๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ์ถœ๋ ฅ๋์Šต๋‹ˆ๋‹ค. ์ด ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค์˜ ์ด๋ฆ„์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋„๋ก C++ API๋Š” named_parameters() ๋ฉ”์„œ๋“œ๋ฅผ ์ œ๊ณตํ•˜๋ฉฐ, ์ด๋Š” ํŒŒ์ด์ฌ์—์„œ์™€ ๊ฐ™์ด Orderdict ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

Net net(4, 5);
for (const auto& pair : net.named_parameters()) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

which we can again execute to see the output:

root@fa350df05ecf:/home/build# make && ./dcgan
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
b: -0.1863
-0.8611
-0.1228
1.3269
0.9858
[ Variable[CPUFloatType]{5} ]
linear.weight:  0.0339  0.2484  0.2035 -0.2103
-0.0715 -0.2975 -0.4350 -0.1878
-0.3616  0.1050 -0.4982  0.0335
-0.1605  0.4963  0.4099 -0.2883
0.1818 -0.3447 -0.1501 -0.0215
[ Variable[CPUFloatType]{5,4} ]
linear.bias: -0.0250
0.0408
0.3756
-0.2149
-0.3636
[ Variable[CPUFloatType]{5} ]

Note

The documentation for torch::nn::Module contains the full list of methods that operate on the module hierarchy.

Running the Network in Forward Mode

To execute the network in C++, we simply call the forward() method we defined ourselves:

int main() {
  Net net(4, 5);
  std::cout << net.forward(torch::ones({2, 4})) << std::endl;
}

which prints something like:

root@fa350df05ecf:/home/build# ./dcgan
0.8559  1.1572  2.1069 -0.1247  0.8060
0.8559  1.1572  2.1069 -0.1247  0.8060
[ Variable[CPUFloatType]{2,5} ]

Module Ownership

At this point, we know how to define a module in C++, register parameters, register submodules, traverse the module hierarchy via methods like parameters() and finally run the module's forward() method. There are many more methods, classes and topics in the C++ API, and I will refer you to the documentation for the full list. We will also touch upon a few more concepts as we implement the DCGAN model and the end-to-end training pipeline in just a moment. Before we do so, let me briefly describe the ownership model the C++ frontend provides for subclasses of torch::nn::Module.

For this discussion, the ownership model refers to the way modules are stored and passed around, which determines who or what owns a particular module instance. In Python, objects are always allocated dynamically (on the heap) and have reference semantics. This is very easy to work with and straightforward to understand. In fact, in Python you can largely forget about where objects live and how they are referenced, and focus on getting things done.

C++, being a lower-level language, provides more options in this realm. This increases complexity and heavily influences the design and ergonomics of the C++ frontend. In particular, modules in the C++ frontend can use either value semantics or reference semantics. The first case is the simplest and was shown in the examples so far: module objects are allocated on the stack and, when passed to a function, can be copied, moved (with std::move) or taken by reference or by pointer:

struct Net : torch::nn::Module { };

void a(Net net) { }
void b(Net& net) { }
void c(Net* net) { }

int main() {
  Net net;
  a(net);
  a(std::move(net));
  b(net);
  c(&net);
}

For the second case (reference semantics), we can use std::shared_ptr. The advantage of reference semantics is that, like in Python, it reduces the cognitive overhead of thinking about how modules must be passed to functions and how arguments must be declared (assuming you use shared_ptr everywhere).

struct Net : torch::nn::Module {};

void a(std::shared_ptr<Net> net) { }

int main() {
  auto net = std::make_shared<Net>();
  a(net);
}
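The practical difference between the two styles shows up as soon as a function mutates its argument. The sketch below uses a hypothetical Counter type in place of a real module (it does not use LibTorch); the function names are invented for this illustration.

```cpp
#include <memory>

// Hypothetical module-like type with one piece of mutable state.
struct Counter {
  int steps = 0;
};

// Value semantics: the function works on its own copy of the argument.
void train_by_value(Counter counter) { counter.steps += 1; }

// Reference semantics via a plain reference: the caller's object is modified.
void train_by_reference(Counter& counter) { counter.steps += 1; }

// Reference semantics via shared ownership: every copy of the pointer
// refers to the same underlying object.
void train_shared(std::shared_ptr<Counter> counter) { counter->steps += 1; }
```

After train_by_value the caller's Counter is unchanged, whereas train_by_reference and train_shared both update the caller's object. With shared_ptr the function signature no longer has to spell out that choice, which is the convenience the paragraph above refers to.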

In our experience, researchers coming from dynamic languages greatly prefer reference semantics over value semantics, even though the latter is more "native" to C++. It is also important to note that the design of torch::nn::Module, in order to stay close to the ergonomics of the Python API, relies on shared ownership. For example, take our earlier (here shortened) definition of Net:

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
    : linear(register_module("linear", torch::nn::Linear(N, M)))
  { }
  torch::nn::Linear linear;
};

To use the linear submodule, we want to store it directly in our class. However, we also want the module base class to know about and have access to this submodule. For this, it must store a reference to the submodule. At this point, we have already arrived at the need for shared ownership: both the torch::nn::Module class and the concrete Net class require a reference to the submodule. For this reason, the base class stores modules as shared_ptrs, and therefore the concrete class must too.

But wait! There is no mention of shared_ptr in the code above! Why is that? Because std::shared_ptr<MyModule> is a lot to type. To keep researchers productive, we came up with an elaborate scheme that hides the mention of shared_ptr (a benefit usually reserved for value semantics) while retaining reference semantics. To understand how this works, let's look at a simplified definition of the torch::nn::Linear module in the core library (the full definition is linked here):

struct LinearImpl : torch::nn::Module {
  LinearImpl(int64_t in, int64_t out);

  Tensor forward(const Tensor& input);

  Tensor weight, bias;
};

TORCH_MODULE(Linear);

In brief: the module is not called Linear, but LinearImpl. A macro, TORCH_MODULE, then defines the actual Linear class. This "generated" class is effectively a wrapper around a std::shared_ptr<LinearImpl>. It is a wrapper rather than a simple typedef so that, among other things, constructors still work as expected, i.e. you can still write torch::nn::Linear(3, 4) instead of std::make_shared<LinearImpl>(3, 4). We call the class created by the macro the module holder. As with (shared) pointers, you access the underlying object using the arrow operator (like model->forward(...)). The end result is an ownership model that resembles that of the Python API quite closely. Reference semantics become the default, but without the extra typing of std::shared_ptr or std::make_shared. For our Net, using the module holder API looks like this:

struct NetImpl : torch::nn::Module {};
TORCH_MODULE(Net);

void a(Net net) { }

int main() {
  Net net;
  a(net);
}

There is one subtle issue that deserves mention here. A default-constructed std::shared_ptr is "empty", i.e. it contains a null pointer. What should a default-constructed Linear or Net be? It is a tricky choice. We could say it should be an empty (null) std::shared_ptr<LinearImpl>. However, recall that Linear(3, 4) is the same as std::make_shared<LinearImpl>(3, 4). This means that if we had decided that Linear linear; should be a null pointer, there would be no way to construct a module that takes no constructor arguments, or that defaults all of them. For this reason, in the current API, a default-constructed module holder (like Linear()) invokes the default constructor of the underlying module (LinearImpl()). If the underlying module does not have a default constructor, you get a compiler error. To construct an empty holder instead, pass nullptr to the holder's constructor.
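To make the holder behavior concrete, here is a small sketch of such a wrapper written with only the standard library. ToyHolder, ToyLinearImpl and is_empty() are hypothetical names for this illustration; the real class generated by TORCH_MODULE is more elaborate, and this is not its implementation.

```cpp
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical implementation class, playing the role of LinearImpl.
struct ToyLinearImpl {
  ToyLinearImpl() = default;
  ToyLinearImpl(int in, int out) : in_features(in), out_features(out) {}
  int in_features = 1;
  int out_features = 1;
};

// Hypothetical holder: a thin wrapper around std::shared_ptr<Impl>.
template <typename Impl>
class ToyHolder {
 public:
  // Default construction builds a default-constructed Impl, not a null pointer.
  ToyHolder() : impl_(std::make_shared<Impl>()) {}

  // Passing nullptr explicitly creates an empty holder.
  ToyHolder(std::nullptr_t) : impl_(nullptr) {}

  // Any other arguments are forwarded to the Impl constructor. The enable_if
  // keeps this overload from hijacking copy construction of the holder itself.
  template <typename Head, typename... Tail,
            typename = typename std::enable_if<
                !std::is_same<typename std::decay<Head>::type, ToyHolder>::value>::type>
  explicit ToyHolder(Head&& head, Tail&&... tail)
      : impl_(std::make_shared<Impl>(std::forward<Head>(head),
                                     std::forward<Tail>(tail)...)) {}

  // Pointer-like access to the underlying object.
  Impl* operator->() const { return impl_.get(); }
  bool is_empty() const { return impl_ == nullptr; }

 private:
  std::shared_ptr<Impl> impl_;
};

using ToyLinear = ToyHolder<ToyLinearImpl>;
```

With this sketch, ToyLinear a; owns a default-constructed ToyLinearImpl, ToyLinear b(3, 4); forwards its arguments, and ToyLinear c(nullptr); is empty. Copying a holder copies only the pointer, so both copies refer to the same underlying object, which is the reference semantics the text describes.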

In practice, this means you can either use submodules as shown earlier, where the module is registered and constructed in the initializer list:

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
    : linear(register_module("linear", torch::nn::Linear(N, M)))
  { }
  torch::nn::Linear linear;
};

ํŒŒ์ด์ฌ ์‚ฌ์šฉ์ž๋“ค์—๊ฒŒ ๋” ์นœ์ˆ™ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ, ๋จผ์ € null ํฌ์ธํ„ฐ๋กœ ํ™€๋”๋ฅผ ์ƒ์„ฑํ•œ ์ดํ›„ ์ƒ์„ฑ์ž์—์„œ ๊ฐ’์„ ์ง€์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M) {
    linear = register_module("linear", torch::nn::Linear(N, M));
  }
  torch::nn::Linear linear{nullptr}; // construct an empty holder
};

In conclusion: which ownership model, and which semantics, should you use? The C++ frontend's API best supports the ownership model provided by module holders. The only disadvantage of this mechanism is one extra line of boilerplate below the module declaration. That said, the simplest model is still the value semantics model shown in the introduction to C++ modules. For small, simple scripts, you may get away with it. But sooner or later you will find that, for technical reasons, it is not always supported. For example, the serialization API (torch::save and torch::load) only supports module holders (or plain shared_ptr). As such, the module holder API is the recommended way of defining modules with the C++ frontend, and we will use it for the rest of this tutorial.

Defining the DCGAN Modules

We now have the necessary background to define the modules for the machine learning task we want to solve in this post. To recap: our task is to generate images of digits from the MNIST dataset. We want to use a generative adversarial network (GAN) to solve this task. In particular, we will use the DCGAN architecture, one of the earliest and simplest of its kind, but entirely sufficient for this task.

Tip

The full source code presented in this tutorial is available in this repository.

What Was a GAN Again?

A GAN consists of two distinct neural network models: a generator and a discriminator. The generator receives samples from a noise distribution, and its aim is to transform each noise sample into an image that resembles those of a target distribution, in our case the MNIST dataset. The discriminator in turn receives either real images from the MNIST dataset or fake images from the generator. It is asked to emit a probability judging how real (closer to 1) or fake (closer to 0) a particular image is. Feedback from the discriminator on how real the generator's images look is used to train the generator. Feedback on how good an eye for authenticity the discriminator has is used to optimize the discriminator. In theory, a delicate balance between the generator and the discriminator makes them improve in tandem. The generator ends up producing images indistinguishable from the target distribution, fooling the (by then) well-trained discriminator into emitting a probability of 0.5 for both real and fake images. For us, the end result is a machine that receives noise as input and generates realistic images of digits as its output.
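The tug-of-war described above is commonly written as a minimax objective. The formulation below is the standard one from the GAN literature rather than something stated in this tutorial, and the symbols (the generator G, the discriminator D, the data distribution p_data, the noise distribution p_z and the generator's induced distribution p_g) are notation introduced here for illustration.

```latex
\min_{G} \max_{D} \; V(D, G)
  = \mathbb{E}_{x \sim p_{\mathrm{data}}}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_{z}}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]

% For a fixed generator, the optimal discriminator is
D^{*}(x) = \frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_{g}(x)}

% so when the generator matches the data distribution, p_g = p_data,
D^{*}(x) = \tfrac{1}{2}
```

The last line is the 0.5 probability mentioned in the text: once generated and real images follow the same distribution, the best the discriminator can do is guess.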

์ƒ์„ฑ๊ธฐ (Generator) ๋ชจ๋“ˆ

๋จผ์ € ์ผ๋ จ์˜ ์ „์น˜๋œ (transposed) 2D ํ•ฉ์„ฑ๊ณฑ, ๋ฐฐ์น˜ ์ •๊ทœํ™” ๋ฐ ReLU ํ™œ์„ฑํ™” ์œ ๋‹›์œผ๋กœ ๊ตฌ์„ฑ๋œ ์ƒ์„ฑ๊ธฐ ๋ชจ๋“ˆ์„ ์ •์˜ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ชจ๋“ˆ์˜ forward() ๋ฉ”์„œ๋“œ๋ฅผ ์ง์ ‘ ์ •์˜ํ•˜์—ฌ ๋ชจ๋“ˆ ๊ฐ„ ์ž…๋ ฅ์„ (ํ•จ์ˆ˜ํ˜•์œผ๋กœ) ๋ช…์‹œ์ ์œผ๋กœ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

// Assumes `#include <torch/torch.h>` and `using namespace torch;` earlier in the file.
struct DCGANGeneratorImpl : nn::Module {
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
  {
    // register_module() is needed if we want to use the parameters() method later on
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);
    register_module("batch_norm1", batch_norm1);
    register_module("batch_norm2", batch_norm2);
    register_module("batch_norm3", batch_norm3);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(batch_norm1(conv1(x)));
    x = torch::relu(batch_norm2(conv2(x)));
    x = torch::relu(batch_norm3(conv3(x)));
    x = torch::tanh(conv4(x));
    return x;
  }

  nn::ConvTranspose2d conv1, conv2, conv3, conv4;
  nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

DCGANGenerator generator(kNoiseSize);

We can now invoke forward() on the DCGANGenerator to map a noise sample to an image.

The particular modules chosen, like nn::ConvTranspose2d and nn::BatchNorm2d, follow the structure outlined earlier. The kNoiseSize constant determines the size of the input noise vector and is set to 100. The hyperparameters were, of course, found through a great deal of grad student effort.
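It can help to check how these layer settings turn a noise vector into an MNIST-sized image. The sketch below applies the usual transposed-convolution size rule, out = (in - 1) * stride - 2 * padding + kernel (valid for dilation 1 and no output padding), to the four layers of the generator. It assumes the noise sample is fed in with a 1x1 spatial extent, which this section does not state explicitly; the struct and function names are hypothetical helpers, not LibTorch API.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical description of one transposed convolution layer.
struct TransposedConvSpec {
  std::int64_t kernel;
  std::int64_t stride;
  std::int64_t padding;
};

// Output size of a transposed convolution with dilation 1 and no output padding.
std::int64_t transposed_conv_output_size(std::int64_t in,
                                         const TransposedConvSpec& layer) {
  return (in - 1) * layer.stride - 2 * layer.padding + layer.kernel;
}

// Spatial size after each generator layer, starting from an assumed 1x1 input.
std::vector<std::int64_t> generator_spatial_sizes() {
  const std::vector<TransposedConvSpec> layers = {
      {4, 1, 0},  // conv1: kernel 4, stride and padding left at their defaults
      {3, 2, 1},  // conv2
      {4, 2, 1},  // conv3
      {4, 2, 1},  // conv4
  };
  std::vector<std::int64_t> sizes = {1};
  for (const auto& layer : layers) {
    sizes.push_back(transposed_conv_output_size(sizes.back(), layer));
  }
  return sizes;
}
```

Under these assumptions the spatial size grows as 1, 4, 7, 14, 28, so the final layer produces a single-channel 28x28 image, the same size as an MNIST digit.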

Attention!

No grad students were harmed in the discovery of the hyperparameters. They feed each other dog food, after all.

Note

C++ ํ”„๋ก ํŠธ์—”๋“œ์˜ Conv2d ์™€ ๊ฐ™์€ ๊ธฐ๋ณธ ์ œ๊ณต ๋ชจ๋“ˆ์— ์˜ต์…˜์ด ์ „๋‹ฌ๋˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ๊ฐ„๋‹จํžˆ ์„ค๋ช…ํ•˜์ž๋ฉด, ๋ชจ๋“  ๋ชจ๋“ˆ์€ ๋ช‡ ๊ฐ€์ง€ ํ•„์ˆ˜ ์˜ต์…˜์„ ๊ฐ–๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. (์˜ˆ: BatchNorm2d ์˜ feature ๊ฐœ์ˆ˜) ๋งŒ์•ฝ BatchNorm2d(128), Dropout(0.5), Conv2d(8, 4, 2) ์™€ ๊ฐ™์ด ํ•„์ˆ˜ ์˜ต์…˜๋งŒ ์„ค์ •ํ•˜๋ ค ํ•œ๋‹ค๋ฉด ๋ชจ๋“ˆ ์ƒ์„ฑ์ž์— ์ง์ ‘ ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. (์—ฌ๊ธฐ์„œ๋Š” ๊ฐ๊ฐ ์ž…๋ ฅ ์ฑ„๋„ ์ˆ˜, ์ถœ๋ ฅ ์ฑ„๋„ ์ˆ˜ ๋ฐ ์ปค๋„ ํฌ๊ธฐ๋ฅผ ์˜๋ฏธ) ๊ทธ๋Ÿฌ๋‚˜ ๋งŒ์•ฝ Conv2d ์˜ bias ์™€ ๊ฐ™์ด ์ผ๋ฐ˜์ ์œผ๋กœ ๊ธฐ๋ณธ๊ฐ’์„ ์‚ฌ์šฉํ•˜๋Š” ๋‹ค๋ฅธ ์˜ต์…˜์„ ์ˆ˜์ •ํ•ด์•ผ ํ•˜๋Š” ๊ฒฝ์šฐ, options ๊ฐ์ฒด๋ฅผ ์ƒ์„ฑํ•ด ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. C++ ํ”„๋ก ํŠธ์—”๋“œ์˜ ๋ชจ๋“ˆ์€ ModuleOptions ์ด๋ผ๊ณ  ํ•˜๋Š” ์—ฐ๊ด€๋œ ์˜ต์…˜ struct๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ Module ์€ ํ•ด๋‹น ๋ชจ๋“ˆ์˜ ์ด๋ฆ„์œผ๋กœ, ์˜ˆ๋ฅผ ๋“ค์–ด Linear ์˜ ๊ฒฝ์šฐ LinearOptions ์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์šฐ๋ฆฌ๋Š” ์œ„์˜ Conv2d ๋ชจ๋“ˆ์— ๋Œ€ํ•ด ์ด๋ฅผ ์ˆ˜ํ–‰ํ•œ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

The Discriminator Module

The discriminator is likewise a sequence of convolutions, batch normalizations, and activations. This time, however, the convolutions are regular rather than transposed, and we use a leaky ReLU with an alpha value of 0.2 instead of a plain ReLU. The final activation is a sigmoid, which squashes values into the range between 0 and 1. We can then interpret these squashed values as the probabilities the discriminator assigns to images being real.
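As a small illustration of the two activations named above, here is a scalar sketch in plain C++. The function names are hypothetical helpers for this illustration; the library versions operate on whole tensors.

```cpp
#include <cassert>
#include <cmath>

// Leaky ReLU: passes non-negative inputs through unchanged and scales
// negative inputs by a small slope (0.2 for the discriminator described here).
inline double leaky_relu(double x, double negative_slope = 0.2) {
  return x >= 0.0 ? x : negative_slope * x;
}

// Sigmoid: squashes any real input into the open interval (0, 1).
inline double sigmoid(double x) {
  return 1.0 / (1.0 + std::exp(-x));
}
```

Unlike a plain ReLU, the leaky variant keeps a non-zero gradient for negative inputs, which is one common motivation for using it in GAN discriminators.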

To build the discriminator, we will try something different: a Sequential module. As in Python, PyTorch provides two APIs for defining models: a functional one, in which inputs are passed through successive functions (as in the generator module example), and a more object-oriented one, in which we build a Sequential module containing the entire model as submodules. Using Sequential, the discriminator looks roughly like this:

nn::Sequential discriminator(
  // Layer 1
  nn::Conv2d(
      nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 2
  nn::Conv2d(
      nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
  nn::BatchNorm2d(128),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 3
  nn::Conv2d(
      nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
  nn::BatchNorm2d(256),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 4
  nn::Conv2d(
      nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
  nn::Sigmoid());
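To see why the last layer can use a kernel of size 3 with no padding, it helps to trace the spatial size through the four convolutions. The sketch below is plain C++ with no LibTorch dependency. The helper names are hypothetical, dilation is assumed to be 1, and the 28x28 input is the MNIST image size used in this tutorial.

```cpp
#include <cassert>
#include <cstdint>

// Output length of a convolution along one spatial dimension (dilation 1):
// floor((input + 2 * padding - kernel) / stride) + 1.
inline std::int64_t conv_output_size(std::int64_t input, std::int64_t kernel,
                                     std::int64_t stride, std::int64_t padding) {
  assert(stride > 0 && input + 2 * padding >= kernel);
  return (input + 2 * padding - kernel) / stride + 1;
}

// Chains the four discriminator layers for a square input of the given size.
inline std::int64_t discriminator_output_size(std::int64_t input) {
  std::int64_t size = conv_output_size(input, 4, 2, 1);  // layer 1
  size = conv_output_size(size, 4, 2, 1);                // layer 2
  size = conv_output_size(size, 4, 2, 1);                // layer 3
  return conv_output_size(size, 3, 1, 0);                // layer 4
}
```

For a 28x28 image the sizes go 28, 14, 7, 3, 1, so the discriminator produces a single value per image.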

Tip

A Sequential module simply performs function composition: the output of each submodule becomes the input of the next, so the first submodule feeds the second, the second feeds the third, and so on.
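The idea of function composition can be sketched without LibTorch. ToySequential below is a hypothetical class written for this illustration; it works on plain doubles rather than tensors and is not how nn::Sequential is implemented.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// A toy "sequential" container: holds unary functions and applies them in
// order, feeding each result into the next function.
class ToySequential {
 public:
  explicit ToySequential(std::vector<std::function<double(double)>> layers)
      : layers_(std::move(layers)) {}

  double forward(double x) const {
    for (const auto& layer : layers_) {
      x = layer(x);
    }
    return x;
  }

 private:
  std::vector<std::function<double(double)>> layers_;
};
```

With a layer that adds 1 followed by a layer that doubles, forward(3.0) computes (3 + 1) * 2 and returns 8.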

Loading Data

Now that we have defined the generator and discriminator models, we need data to train them with. Like the Python frontend, the C++ frontend comes with a powerful parallel data loader. This data loader can read batches of data from a dataset, which you can define yourself, and provides many configuration options.

Note

While the Python data loader uses multi-processing, the C++ data loader is truly multi-threaded and does not launch any new processes.

The data loader is part of the C++ frontend's data API, contained in the torch::data:: namespace. This API consists of a few components:

  • the data loader class,
  • an API for defining datasets,
  • an API for defining transforms, which can be applied to datasets,
  • an API for defining samplers, which produce the indices used to index a dataset,
  • a library of existing datasets, transforms, and samplers.

For this tutorial, we use the MNIST dataset that ships with the C++ frontend. Let us instantiate a torch::data::datasets::MNIST and apply two transformations. First, we normalize the images so that they lie between -1 and +1 (from an original range of 0 to 1). Second, we apply the Stack collation, which takes a batch of tensors and stacks them into a single tensor along the first dimension:

auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());
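As a rough illustration of what Normalize<>(0.5, 0.5) computes per element, the following standalone sketch applies (x - mean) / stddev to a list of values. The function normalize is a hypothetical helper written for this illustration and is not part of the torch::data API.

```cpp
#include <cassert>
#include <vector>

// Applies (x - mean) / stddev to every value. With mean = 0.5 and
// stddev = 0.5, inputs in [0, 1] land in [-1, 1].
inline std::vector<double> normalize(const std::vector<double>& values,
                                     double mean, double stddev) {
  assert(stddev != 0.0);
  std::vector<double> result;
  result.reserve(values.size());
  for (double value : values) {
    result.push_back((value - mean) / stddev);
  }
  return result;
}
```

For the inputs 0, 0.5, and 1 with mean 0.5 and standard deviation 0.5, the results are -1, 0, and 1, which matches the output range of the generator's final tanh.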

The MNIST dataset should be located in the ./mnist directory relative to wherever you execute the training binary from. You can use this script to download the MNIST dataset.

Next, we create a data loader and pass it this dataset. To make a new data loader, we use torch::data::make_data_loader, which returns a std::unique_ptr of the correct type (which depends on the type of the dataset, the type of the sampler, and some other implementation details):

auto data_loader = torch::data::make_data_loader(std::move(dataset));

The data loader comes with many options. You can inspect the full set here. For example, to speed up data loading, we can increase the number of workers. The default is zero, which means the main thread is used. Setting workers to 2 spawns two threads that load data concurrently. We should also increase the batch size from its default of 1 to something more reasonable, such as 64 (the value of kBatchSize). So let us create a DataLoaderOptions object and set the appropriate properties:

auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

We can now write a loop to load batches of data, which for now we only print to the console:

for (torch::data::Example<>& batch : *data_loader) {
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}

The type returned by the data loader in this case is a torch::data::Example. This type is a simple struct with a data field for the data and a target field for the label. Because we applied the Stack collation earlier, the data loader returns only a single such example per batch. Had we not applied the collation, the data loader would instead yield std::vector<torch::data::Example<>>, with one element per example in the batch.
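The difference between a vector of examples and a single stacked example can be sketched in plain C++. ToyExample, ToyBatch, and stack are hypothetical names used only for this illustration; the real Stack transform concatenates tensors along a new first dimension rather than copying into flat vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical single example: flattened pixel data plus an integer label.
struct ToyExample {
  std::vector<float> data;
  std::int64_t target;
};

// Hypothetical collated batch: all examples stored contiguously in row-major
// order, so the first dimension indexes the example within the batch.
struct ToyBatch {
  std::vector<float> data;            // batch_size * features values
  std::vector<std::int64_t> targets;  // batch_size labels
  std::size_t features = 0;
};

// Stacks individual examples into one batch; all examples must have the
// same number of features.
inline ToyBatch stack(const std::vector<ToyExample>& examples) {
  ToyBatch batch;
  if (examples.empty()) {
    return batch;
  }
  batch.features = examples.front().data.size();
  for (const ToyExample& example : examples) {
    assert(example.data.size() == batch.features);
    batch.data.insert(batch.data.end(), example.data.begin(), example.data.end());
    batch.targets.push_back(example.target);
  }
  return batch;
}
```

After stacking, a loop such as the one above can read the batch size from the first dimension and index labels directly, instead of iterating over a vector of separate examples.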

If you rebuild and run this code, you should see something like this:

root@fa350df05ecf:/home/build# make
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
root@fa350df05ecf:/home/build# make
[100%] Built target dcgan
root@fa350df05ecf:/home/build# ./dcgan
Batch size: 64 | Labels: 5 2 6 7 2 1 6 7 0 1 6 2 3 6 9 1 8 4 0 6 5 3 3 0 4 6 6 6 4 0 8 6 0 6 9 2 4 0 2 8 6 3 3 2 9 2 0 1 4 2 3 4 8 2 9 9 3 5 8 0 0 7 9 9
Batch size: 64 | Labels: 2 2 4 7 1 2 8 8 6 9 0 2 2 9 3 6 1 3 8 0 4 4 8 8 8 9 2 6 4 7 1 5 0 9 7 5 4 3 5 4 1 2 8 0 7 1 9 6 1 6 5 3 4 4 1 2 3 2 3 5 0 1 6 2
Batch size: 64 | Labels: 4 5 4 2 1 4 8 3 8 3 6 1 5 4 3 6 2 2 5 1 3 1 5 0 8 2 1 5 3 2 4 4 5 9 7 2 8 9 2 0 6 7 4 3 8 3 5 8 8 3 0 5 8 0 8 7 8 5 5 6 1 7 8 0
Batch size: 64 | Labels: 3 3 7 1 4 1 6 1 0 3 6 4 0 2 5 4 0 4 2 8 1 9 6 5 1 6 3 2 8 9 2 3 8 7 4 5 9 6 0 8 3 0 0 6 4 8 2 5 4 1 8 3 7 8 0 0 8 9 6 7 2 1 4 7
Batch size: 64 | Labels: 3 0 5 5 9 8 3 9 8 9 5 9 5 0 4 1 2 7 7 2 0 0 5 4 8 7 7 6 1 0 7 9 3 0 6 3 2 6 2 7 6 3 3 4 0 5 8 8 9 1 9 2 1 9 4 4 9 2 4 6 2 9 4 0
Batch size: 64 | Labels: 9 6 7 5 3 5 9 0 8 6 6 7 8 2 1 9 8 8 1 1 8 2 0 7 1 4 1 6 7 5 1 7 7 4 0 3 2 9 0 6 6 3 4 4 8 1 2 8 6 9 2 0 3 1 2 8 5 6 4 8 5 8 6 2
Batch size: 64 | Labels: 9 3 0 3 6 5 1 8 6 0 1 9 9 1 6 1 7 7 4 4 4 7 8 8 6 7 8 2 6 0 4 6 8 2 5 3 9 8 4 0 9 9 3 7 0 5 8 2 4 5 6 2 8 2 5 3 7 1 9 1 8 2 2 7
Batch size: 64 | Labels: 9 1 9 2 7 2 6 0 8 6 8 7 7 4 8 6 1 1 6 8 5 7 9 1 3 2 0 5 1 7 3 1 6 1 0 8 6 0 8 1 0 5 4 9 3 8 5 8 4 8 0 1 2 6 2 4 2 7 7 3 7 4 5 3
Batch size: 64 | Labels: 8 8 3 1 8 6 4 2 9 5 8 0 2 8 6 6 7 0 9 8 3 8 7 1 6 6 2 7 7 4 5 5 2 1 7 9 5 4 9 1 0 3 1 9 3 9 8 8 5 3 7 5 3 6 8 9 4 2 0 1 2 5 4 7
Batch size: 64 | Labels: 9 2 7 0 8 4 4 2 7 5 0 0 6 2 0 5 9 5 9 8 8 9 3 5 7 5 4 7 3 0 5 7 6 5 7 1 6 2 8 7 6 3 2 6 5 6 1 2 7 7 0 0 5 9 0 0 9 1 7 8 3 2 9 4
Batch size: 64 | Labels: 7 6 5 7 7 5 2 2 4 9 9 4 8 7 4 8 9 4 5 7 1 2 6 9 8 5 1 2 3 6 7 8 1 1 3 9 8 7 9 5 0 8 5 1 8 7 2 6 5 1 2 0 9 7 4 0 9 0 4 6 0 0 8 6
...

This means we are able to load data from the MNIST dataset successfully.

Writing the Training Loop

Let us now finish the algorithmic part of our example and implement the delicate dance between the generator and the discriminator. First, we create two optimizers, one for the generator and one for the discriminator. The optimizers we use implement the Adam algorithm:

torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.5)));
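For intuition about what each optimizer step does with these options, here is a minimal scalar sketch of the Adam update rule using the same learning rate and betas as the generator optimizer. ScalarAdam is a hypothetical helper for illustration only. The epsilon value of 1e-8 is an assumption of this sketch, and the real torch::optim::Adam operates on whole tensors and supports further options.

```cpp
#include <cassert>
#include <cmath>

// State and hyperparameters for a single scalar parameter updated with Adam.
struct ScalarAdam {
  double learning_rate = 2e-4;
  double beta1 = 0.5;
  double beta2 = 0.5;
  double epsilon = 1e-8;  // assumed value for this sketch
  double first_moment = 0.0;
  double second_moment = 0.0;
  long step_count = 0;

  // Applies one Adam update to `parameter` given its gradient and returns
  // the updated value.
  double step(double parameter, double gradient) {
    ++step_count;
    first_moment = beta1 * first_moment + (1.0 - beta1) * gradient;
    second_moment = beta2 * second_moment + (1.0 - beta2) * gradient * gradient;
    // Bias correction compensates for the zero-initialized moment estimates.
    const double m_hat = first_moment / (1.0 - std::pow(beta1, step_count));
    const double v_hat = second_moment / (1.0 - std::pow(beta2, step_count));
    return parameter - learning_rate * m_hat / (std::sqrt(v_hat) + epsilon);
  }
};
```

On the very first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly one learning rate regardless of the gradient's magnitude.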

Note

As of this writing, the C++ frontend provides optimizers implementing Adagrad, Adam, LBFGS, RMSprop, and SGD. The docs have the up-to-date list.

Next, we need to update our training loop. We add an outer loop that exhausts the data loader once per epoch and then write the GAN training code:

for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {
    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();

    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();

    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();

    // Train generator.
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();

    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}

The code above first evaluates the discriminator on real images, for which it should assign a high probability. For this, we use torch::empty(batch.data.size(0)).uniform_(0.8, 1.0) as the target probabilities.

Note

We pick random values uniformly distributed between 0.8 and 1.0 instead of 1.0 everywhere to make the discriminator training more robust. This trick is called label smoothing.
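A standalone sketch of this kind of target generation, using only the C++ standard library, might look as follows. The function name and the use of std::mt19937 are assumptions made for this illustration; the tutorial itself relies on the tensor method uniform_.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Draws `count` target values uniformly from [low, high) instead of using a
// hard 1.0 for every real image.
inline std::vector<double> smoothed_real_labels(std::size_t count,
                                                std::mt19937& rng,
                                                double low = 0.8,
                                                double high = 1.0) {
  assert(low < high);
  std::uniform_real_distribution<double> distribution(low, high);
  std::vector<double> labels(count);
  for (double& label : labels) {
    label = distribution(rng);
  }
  return labels;
}
```

Every generated target stays within the requested range, so the discriminator is never pushed toward an output of exactly 1.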

Before evaluating the discriminator, we zero out the gradients of its parameters. After computing the loss, we back-propagate it through the network by calling backward() on it (d_loss_real.backward() in the code above) to compute new gradients. We then repeat this for the fake images. Instead of using images from the dataset, we let the generator create fake images by feeding it a batch of random noise, and we forward those fake images to the discriminator. This time, we want the discriminator to emit low probabilities, ideally all zeros. Once we have computed the discriminator loss for both the batch of real images and the batch of fake images, we can advance the discriminator's optimizer by one step to update its parameters.

To train the generator, we again first zero its gradients and then re-evaluate the discriminator on the fake images. This time, however, we want the discriminator to assign probabilities very close to one, which would indicate that the generator can produce images that fool the discriminator into thinking they are real (from the dataset). For this, we fill the fake_labels tensor with all ones. Finally, we step the generator's optimizer to update its parameters as well.
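The loss used in all three steps is binary cross-entropy. As a minimal sketch of the formula, assuming mean reduction and a small clamp to keep the logarithms finite (both are assumptions of this illustration, and the function below is a hypothetical helper rather than the library implementation):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Mean binary cross-entropy: -mean(y * log(p) + (1 - y) * log(1 - p)).
// Probabilities are clamped away from 0 and 1 so the logarithms stay finite.
inline double binary_cross_entropy(const std::vector<double>& probabilities,
                                   const std::vector<double>& targets) {
  assert(!probabilities.empty() && probabilities.size() == targets.size());
  const double kEpsilon = 1e-12;
  double total = 0.0;
  for (std::size_t i = 0; i < probabilities.size(); ++i) {
    const double p =
        std::fmin(std::fmax(probabilities[i], kEpsilon), 1.0 - kEpsilon);
    total -= targets[i] * std::log(p) + (1.0 - targets[i]) * std::log(1.0 - p);
  }
  return total / static_cast<double>(probabilities.size());
}
```

With a target of 1, the loss reduces to -log(p), so it shrinks as the discriminator's output for a fake image approaches 1. That is why filling fake_labels with ones steers the generator toward images the discriminator accepts as real.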

We should now be ready to train our model on the CPU. We do not yet have any code to capture state or sample outputs, but we will add this in a moment. For now, let us just observe that the model is doing something; later, we will verify from the generated images whether that something is meaningful. Rebuilding and running should print something like:

root@3c0711f20896:/home/build# make && ./dcgan
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
[ 1/10][100/938] D_loss: 0.6876 | G_loss: 4.1304
[ 1/10][200/938] D_loss: 0.3776 | G_loss: 4.3101
[ 1/10][300/938] D_loss: 0.3652 | G_loss: 4.6626
[ 1/10][400/938] D_loss: 0.8057 | G_loss: 2.2795
[ 1/10][500/938] D_loss: 0.3531 | G_loss: 4.4452
[ 1/10][600/938] D_loss: 0.3501 | G_loss: 5.0811
[ 1/10][700/938] D_loss: 0.3581 | G_loss: 4.5623
[ 1/10][800/938] D_loss: 0.6423 | G_loss: 1.7385
[ 1/10][900/938] D_loss: 0.3592 | G_loss: 4.7333
[ 2/10][100/938] D_loss: 0.4660 | G_loss: 2.5242
[ 2/10][200/938] D_loss: 0.6364 | G_loss: 2.0886
[ 2/10][300/938] D_loss: 0.3717 | G_loss: 3.8103
[ 2/10][400/938] D_loss: 1.0201 | G_loss: 1.3544
[ 2/10][500/938] D_loss: 0.4522 | G_loss: 2.6545
...
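The training loop prints batches_per_epoch, which the snippet above does not define. One plausible way to obtain it is a ceiling division of the dataset size by the batch size, sketched here in plain C++. The helper name is hypothetical, and the figure of 60,000 MNIST training images is an assumption of this illustration.

```cpp
#include <cassert>
#include <cstdint>

// Number of batches needed to cover `dataset_size` examples when the last,
// possibly smaller, batch is kept (ceiling division).
inline std::int64_t batches_per_epoch(std::int64_t dataset_size,
                                      std::int64_t batch_size) {
  assert(dataset_size >= 0 && batch_size > 0);
  return (dataset_size + batch_size - 1) / batch_size;
}
```

For 60,000 images and a batch size of 64 this gives 938, which agrees with the denominator shown in the log.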

Moving to the GPU

While this script runs fine on the CPU, it is well known that convolutions are much faster on a GPU. Let us quickly discuss how to move training onto the GPU. We need to do two things: pass a GPU device specification to the tensors we allocate ourselves, and explicitly copy all other tensors onto the GPU via the to() method that every tensor and module in the C++ frontend has. The simplest way to achieve both is to create a torch::Device instance at the top level of the training script and pass it to tensor factory functions such as torch::zeros as well as to the to() method. We can start by doing this with a CPU device:

// Place this somewhere at the top of your training script.
torch::Device device(torch::kCPU);

New tensor allocations such as

torch::Tensor fake_labels = torch::zeros(batch.data.size(0));

should be updated to take the device as the last argument:

torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);

For tensors whose creation is not in our hands, such as those coming from the MNIST dataset, we must insert explicit to() calls. This means that

torch::Tensor real_images = batch.data;

becomes

torch::Tensor real_images = batch.data.to(device);

Our model parameters should also be moved to the correct device:

generator->to(device);
discriminator->to(device);

Note

If a tensor already lives on the device supplied to to(), the call is a no-op. No extra copy is made.

At this point, we have only made our previous CPU-resident code more explicit. However, it is now also very easy to change the device to a CUDA device:

torch::Device device(torch::kCUDA);

Now all tensors live on the GPU and call into fast CUDA kernels for all operations, without any change to downstream code. To specify a particular device index, pass it as the second argument to the Device constructor. If we wanted different tensors to live on different devices, we could pass separate device instances (for example, one on CUDA device 0 and another on CUDA device 1). We can even make this configuration dynamic, which is often useful for making training scripts more portable:

torch::Device device = torch::kCPU;
if (torch::cuda::is_available()) {
  std::cout << "CUDA is available! Training on GPU." << std::endl;
  device = torch::kCUDA;
}

or even

torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

Checkpointing and Recovering the Training State

The last augmentation to our training script is to periodically save the state of the model parameters, the state of the optimizers, and a few generated image samples. If the computer crashes in the middle of training, the first two let us restore the training state. For long-running training sessions, this is essential. Fortunately, the C++ frontend provides an API to serialize and deserialize both model and optimizer state, as well as individual tensors.

The core API for this is torch::save(thing,filename) and torch::load(thing,filename), where thing can be a torch::nn::Module subclass or an optimizer instance such as the Adam objects in our training script. Let us update the training loop to checkpoint the model and optimizer state at a certain interval:

if (batch_index % kCheckpointEvery == 0) {
  // Checkpoint the model and optimizer state.
  torch::save(generator, "generator-checkpoint.pt");
  torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::save(discriminator, "discriminator-checkpoint.pt");
  torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  // Sample the generator and save the images.
  torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}, device));
  torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}

Here, kCheckpointEvery is an integer set to something like 100 to checkpoint every 100 batches, and checkpoint_counter is a counter bumped every time we make a checkpoint.
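To get a feel for how often the checkpoint condition fires and what the sample files are called, here is a small standalone sketch. Both helpers are hypothetical. The sketch assumes the check runs after batch_index has been incremented, so that batch_index ranges from 1 to the number of batches in each epoch, and it builds the filename with std::ostringstream as a stand-in for torch::str.

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>

// Counts how many times `batch_index % checkpoint_every == 0` fires when
// batch_index runs from 1 to batches_per_epoch in each epoch.
inline std::int64_t count_checkpoints(std::int64_t epochs,
                                      std::int64_t batches_per_epoch,
                                      std::int64_t checkpoint_every) {
  assert(checkpoint_every > 0);
  return epochs * (batches_per_epoch / checkpoint_every);
}

// Builds a sample filename from the running checkpoint counter.
inline std::string sample_filename(std::int64_t checkpoint_counter) {
  std::ostringstream stream;
  stream << "dcgan-sample-" << checkpoint_counter << ".pt";
  return stream.str();
}
```

For example, 30 epochs of 938 batches with a checkpoint every 200 batches (an interval assumed here for illustration) yields 120 checkpoints, and a counter value of 100 produces the name dcgan-sample-100.pt.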

To restore the training state, you can add lines like these after all models and optimizers are created, but before the training loop:

torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.5)));

if (kRestoreFromCheckpoint) {
  torch::load(generator, "generator-checkpoint.pt");
  torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::load(discriminator, "discriminator-checkpoint.pt");
  torch::load(
      discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}

int64_t checkpoint_counter = 0;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {

Inspecting Generated Images

Our training script is now complete, and we are ready to train our GAN, whether on CPU or GPU. To inspect the intermediate output of the training procedure, for which we added code to periodically save image samples to the "dcgan-sample-xxx.pt" file, we can write a tiny Python script that loads the tensors and displays them with matplotlib:

import argparse

import matplotlib.pyplot as plt
import torch


parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()

module = torch.jit.load(options.sample_file)
images = list(module.parameters())[0]

for index in range(options.dimension * options.dimension):
  image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
  array = image.numpy()
  axis = plt.subplot(options.dimension, options.dimension, 1 + index)
  plt.imshow(array, cmap="gray")
  axis.get_xaxis().set_visible(False)
  axis.get_yaxis().set_visible(False)

plt.savefig(options.out_file)
print("Saved ", options.out_file)
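The value transformations along this path can be summarized in a small standalone sketch: the training code maps generator outputs from [-1, 1] to [0, 1] before saving, and the script above scales them to 8-bit pixels. The helper names are hypothetical, and the truncating cast is assumed to mirror the conversion to torch.uint8.

```cpp
#include <cassert>
#include <cstdint>

// Maps a tanh-range value in [-1, 1] to [0, 1], as done before saving samples.
inline double to_unit_range(double x) {
  return (x + 1.0) / 2.0;
}

// Scales a [0, 1] value to an 8-bit pixel, truncating like an integer cast.
inline std::uint8_t to_pixel(double unit_value) {
  assert(unit_value >= 0.0 && unit_value <= 1.0);
  return static_cast<std::uint8_t>(unit_value * 255.0);
}
```

A generator output of -1 becomes a black pixel (0), an output of 1 becomes a white pixel (255), and an output of 0 lands at 127 because 127.5 is truncated.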

Let us now train our model for around 30 epochs:

root@3c0711f20896:/home/build# make && ./dcgan
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
CUDA is available! Training on GPU.
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
-> checkpoint 1
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
-> checkpoint 2
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.36760
-> checkpoint 3
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
-> checkpoint 4
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
-> checkpoint 5
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
...
-> checkpoint 120
[30/30][938/938] D_loss: 0.3610 | G_loss: 3.8084

And display the images in a plot:

root@3c0711f20896:/home/build# python display.py -i dcgan-sample-100.pt
Saved out.png

The result should look something like this:

digits

Digits! Hooray! Now the ball is in your court: can you improve the model to make the digits look even better?

Conclusion

This tutorial has hopefully given you a digestible overview of the PyTorch C++ frontend. A machine learning library like PyTorch necessarily has a very broad and extensive API, so there are many concepts we did not have the time or space to discuss here. We encourage you to try out the API yourself and to consult the documentation, in particular the Library API section. Also, remember that the C++ frontend follows the design and semantics of the Python frontend, which should help you learn it faster.

Tip

The full source code presented in this tutorial is available in this repository.

As always, if you run into any problems or have questions, you can use our forum or GitHub issues to get in touch.