Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Image Resizing and Colab Support #465

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions examples/binary segmentation (camvid).ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,21 @@
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"If dataset needs to be resized use these values MAIN_SIZE_X,MAIN_SIZE_Y = 512,512 else put MAIN_SIZE_X,MAIN_SIZE_Y = None,None\n",
"\"\"\"\n",
"MAIN_SIZE_X,MAIN_SIZE_Y = None,None\n",
"\n",
"\n",
"\n",
"# helper function for data visualization\n",
"def visualize(**images):\n",
" \"\"\"PLot images in one row.\"\"\"\n",
" n = len(images)\n",
" plt.figure(figsize=(16, 5))\n",
" plt.figure(figsize=(20, 12))\n",
" for i, (name, image) in enumerate(images.items()):\n",
" plt.subplot(1, n, i + 1)\n",
" plt.axis('off')\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.title(' '.join(name.split('_')).title())\n",
Expand Down Expand Up @@ -175,7 +183,11 @@
" # read data\n",
" image = cv2.imread(self.images_fps[i])\n",
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
" if MAIN_SIZE_X,MAIN_SIZE_Y == None,None:\n",
" image = cv2.resize(image,(MAIN_SIZE_Y,MAIN_SIZE_X))\n",
" mask = cv2.imread(self.masks_fps[i], 0)\n",
" if MAIN_SIZE_X,MAIN_SIZE_Y == None,None:\n",
" mask = cv2.resize(mask,(MAIN_SIZE_Y,MAIN_SIZE_X))\n",
" \n",
" # extract certain classes from mask (e.g. cars)\n",
" masks = [(mask == v) for v in self.class_values]\n",
Expand Down Expand Up @@ -216,7 +228,6 @@
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.indexes = np.arange(len(dataset))\n",
"\n",
" self.on_epoch_end()\n",
"\n",
" def __getitem__(self, i):\n",
Expand All @@ -230,8 +241,8 @@
" \n",
" # transpose list of lists\n",
" batch = [np.stack(samples, axis=0) for samples in zip(*data)]\n",
" \n",
" return batch\n",
" return tuple(batch)\n",
"\n",
" \n",
" def __len__(self):\n",
" \"\"\"Denotes the number of batches per epoch\"\"\"\n",
Expand Down Expand Up @@ -447,6 +458,10 @@
}
],
"source": [
"\"\"\"\n",
"While Using in Colab Use\n",
"%env SM_FRAMEWORK=tf.keras\n",
"\"\"\"\n",
"import segmentation_models as sm\n",
"\n",
"# segmentation_models could also use `tf.keras` if you do not have Keras installed\n",
Expand Down