Skip to content

Commit

Permalink
perf: 为技能识别加入5倍数据增强
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Apr 9, 2023
1 parent e1d3993 commit bdfebba
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions combat/skill_ready/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,25 @@
" self.data = [ self.get(i) for i in range(len(self))]\n",
" \n",
" def __len__(self):\n",
" return len(self.y) + len(self.n)\n",
" return self.len_y() + self.len_n()\n",
" \n",
" def len_y(self):\n",
" # 给 y 来个 5 倍数据增强\n",
" return len(self.y) * 5\n",
" \n",
" def len_n(self):\n",
" return len(self.n)\n",
"\n",
" def get(self, index):\n",
" if index < len(self.y):\n",
" if index < self.len_y():\n",
" if index % 1000 == 0:\n",
" print(f'load y: {index} / {len(self.y)}')\n",
" path = self.y[index]\n",
" print(f'load y: {index} / {self.len_y()}')\n",
" path = self.y[index % len(self.y)]\n",
" label = 1\n",
" else:\n",
" if index % 1000 == 0:\n",
" print(f'load n: {index - len(self.y)} / {len(self.n)}')\n",
" path = self.n[index - len(self.y)]\n",
" print(f'load n: {index - self.len_y()} / {self.len_n()}')\n",
" path = self.n[(index - self.len_y()) % len(self.n)]\n",
" label = 0\n",
" image = self.loader(path)\n",
" image = self.transform(image)\n",
Expand Down Expand Up @@ -324,7 +331,7 @@
"\n",
" for epoch in range(start_epoch, 1000):\n",
" train(epoch)\n",
" if epoch % test_interval != 0:\n",
" if epoch % test_interval != 0 or epoch < 50:\n",
" continue\n",
" \n",
" loss, acc = test()\n",
Expand Down

0 comments on commit bdfebba

Please sign in to comment.