Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

[DONT MERGE] Brain Tumor Segmentation #1228

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugins/data/brats/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include digitsDataPluginBrats *.html
2 changes: 2 additions & 0 deletions plugins/data/brats/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This DIGITS plug-in demonstrates how to load data from the BRATS dataset.
https://www.smir.ch/BRATS/Start2016
6 changes: 6 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .data import DataIngestion

__all__ = ['DataIngestion']
150 changes: 150 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os
import re

import numpy as np

from digits.utils import subclass, override, constants
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from digits.extensions.data.interface import DataIngestionInterface
from .forms import DatasetForm, InferenceForm
from . import utils


DATASET_TEMPLATE = "templates/dataset_template.html"
INFERENCE_TEMPLATE = "templates/inference_template.html"


@subclass
class DataIngestion(DataIngestionInterface):
"""
A data ingestion extension for the BRATS dataset
"""

def __init__(self, is_inference_db=False, **kwargs):
super(DataIngestion, self).__init__(**kwargs)

self.userdata['is_inference_db'] = is_inference_db

if 'files' not in self.userdata:
files = utils.find_files(self.dataset_folder,
self.group_id,
self.modality)
if not len(files):
raise ValueError("Failed to find data files in %s for "
"group %s and modality %s"
% (self.dataset_folder, self.group_id, self.modality))
self.userdata['files'] = files

# label palette (0->black (background), 1->white (foreground), others->black)
palette = [0, 0, 0, 255, 255, 255] + [0] * (254 * 3)
self.userdata[COLOR_PALETTE_ATTRIBUTE] = palette

self.userdata['class_labels'] = ['background', 'complete tumor']

@override
def encode_entry(self, entry):
if self.userdata['is_inference_db']:
# for inference, use image with maximum tumor area
filter_method = 'max'
else:
filter_method = self.userdata['filter_method']
feature, label = utils.encode_sample(entry, filter_method)

data = []
if feature.size > 0:
if self.userdata['channel_conversion'] != 'none':
# extract 2D slices: split across axial dimension
features = np.split(feature, feature.shape[0])
labels = np.split(label, label.shape[0])

data = []
for image, label in zip(features, labels):
if self.userdata['channel_conversion'] == 'L':
feature = image
elif self.userdata['channel_conversion'] == 'RGB':
image = image[0]
feature = np.empty(shape=(3, image.shape[0], image.shape[1]),
dtype=image.dtype)
# just copy the same data over the three color channels
feature[:3] = [image, image, image]
data.append((feature, label))
else:
data.append((feature, label))
return data

@staticmethod
@override
def get_category():
return "Images"

@staticmethod
@override
def get_id():
return "images-brats"

@staticmethod
@override
def get_dataset_form():
return DatasetForm()

@staticmethod
@override
def get_dataset_template(form):
"""
parameters:
- form: form returned by get_dataset_form(). This may be populated
with values if the job was cloned
return:
- (template, context) tuple
- template is a Jinja template to use for rendering dataset creation
options
- context is a dictionary of context variables to use for rendering
the form
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(os.path.join(extension_dir, DATASET_TEMPLATE), "r").read()
context = {'form': form}
return (template, context)

@override
def get_inference_form(self):
all_entries = self.userdata['files']
n_val_entries = int(len(all_entries)*self.userdata['pct_val']/100)
val_entries = self.userdata['files'][:n_val_entries]
form = InferenceForm()
for idx, entry in enumerate(val_entries):
match = re.match('.*pat(\d+)_.*', entry[0])
if match:
form.validation_record.choices.append((str(idx), 'Patient %s' % match.group(1)))
return form

@staticmethod
@override
def get_inference_template(form):
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(os.path.join(extension_dir, INFERENCE_TEMPLATE), "r").read()
context = {'form': form}
return (template, context)

@staticmethod
@override
def get_title():
return "Brain Tumor Segmentation"

@override
def itemize_entries(self, stage):
all_entries = self.userdata['files']
entries = []
if not self.userdata['is_inference_db']:
n_val_entries = int(len(all_entries)*self.pct_val/100)
if stage == constants.TRAIN_DB:
entries = all_entries[n_val_entries:]
elif stage == constants.VAL_DB:
entries = all_entries[:n_val_entries]
elif stage == constants.TEST_DB:
entries = [all_entries[int(self.validation_record)]]

return entries
128 changes: 128 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os

from digits import utils
from digits.utils import subclass
from flask.ext.wtf import Form
from wtforms import validators


@subclass
class DatasetForm(Form):
"""
A form used to create a Sunnybrook dataset
"""

def validate_folder_path(form, field):
if not field.data:
pass
else:
# make sure the filesystem path exists
if not os.path.exists(field.data) or not os.path.isdir(field.data):
raise validators.ValidationError(
'Folder does not exist or is not reachable')
else:
return True

dataset_folder = utils.forms.StringField(
u'Dataset folder',
validators=[
validators.DataRequired(),
validate_folder_path,
],
tooltip="Specify the path to a BRATS dataset."
)

group_id = utils.forms.SelectField(
'Group',
choices=[
('HGG', 'High-Grade Group'),
('LGG', 'Low-Grade Group'),
],
default='HGG',
tooltip="Select a group to train on."
)

modality = utils.forms.SelectField(
'Modality',
choices=[
('all', 'All'),
('Flair', 'FLAIR'),
('T1', 'T1'),
('T1c', 'T1c'),
('T2', 'T2'),
],
default='Flair',
tooltip="Select a modality to train on."
)

filter_method = utils.forms.SelectField(
'Filter',
choices=[
('all', 'All'),
('max', 'Max'),
('threshold', 'Threshold'),
],
default='all',
tooltip="Select a slice filter: 'All' retains all axial slices, "
"'Max' retains only the slice that exhibits max tumor area, "
"'Threshold' retains only slices that have more than "
"1000-pixel tumor area"
)

channel_conversion = utils.forms.SelectField(
'Channel conversion',
choices=[
('none', 'None - 3D grayscale images'),
('RGB', 'RGB - slice into 2D color images'),
('L', 'Grayscale - slice into 2D grayscale images'),
],
default='L',
tooltip="Perform selected channel conversion."
)

pct_val = utils.forms.IntegerField(
u'% for validation',
default=10,
validators=[
validators.NumberRange(min=0, max=100)
],
tooltip="You can choose to set apart a certain percentage of images "
"from the training images for the validation set."
)


@subclass
class InferenceForm(Form):

def validate_file_path(form, field):
if not field.data:
pass
else:
# make sure the filesystem path exists
if not os.path.exists(field.data) and not os.path.isdir(field.data):
raise validators.ValidationError(
'File does not exist or is not reachable')
else:
return True
"""
A form used to perform inference on a text classification dataset
"""
test_image_file = utils.forms.StringField(
u'Image file',
validators=[
validate_file_path,
],
tooltip="Provide an image"
)

validation_record = utils.forms.SelectField(
'Record from validation set',
choices=[
('none', '- select record -'),
],
default='none',
tooltip="Test a record from the validation set."
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

<div class="form-group{{mark_errors([form.dataset_folder])}}">
{{ form.dataset_folder.label }}
{{ form.dataset_folder.tooltip }}
{{ form.dataset_folder(class='form-control autocomplete_path', placeholder='folder') }}
</div>

<div class="form-group{{mark_errors([form.group_id])}}">
{{ form.group_id.label }}
{{ form.group_id.tooltip }}
{{ form.group_id(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.modality])}}">
{{ form.modality.label }}
{{ form.modality.tooltip }}
{{ form.modality(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.channel_conversion])}}">
{{ form.channel_conversion.label }}
{{ form.channel_conversion.tooltip }}
{{ form.channel_conversion(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.filter_method])}}">
{{ form.filter_method.label }}
{{ form.filter_method.tooltip }}
{{ form.filter_method(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.pct_val])}}">
{{ form.pct_val.label }}
{{ form.pct_val.tooltip }}
{{ form.pct_val(class='form-control') }}
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

<div class="row">
<div class="col-sm-6">
<h3>Test a record from validation set</h3>
<div class="form-group">
<div class="form-group{{mark_errors([form.validation_record])}}">
{{ form.validation_record.label }}
{{ form.validation_record.tooltip }}
{{ form.validation_record(class='form-control') }}
</div>
</div>
</div>
</div>
37 changes: 37 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/templates/template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

{{ form.data_stage(class='form-control') }}

<div class="form-group{{mark_errors([form.train_data_file])}}">
{{ form.train_data_file.label }}
{{ form.train_data_file.tooltip }}
{{ form.train_data_file(class='form-control autocomplete_path', placeholder='.csv file') }}
</div>

<div class="form-group{{mark_errors([form.val_data_file])}}">
{{ form.val_data_file.label }}
{{ form.val_data_file.tooltip }}
{{ form.val_data_file(class='form-control autocomplete_path', placeholder='.csv file') }}
</div>

<div class="form-group{{mark_errors([form.alphabet])}}">
{{ form.alphabet.label }}
{{ form.alphabet.tooltip }}
{{ form.alphabet(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.class_labels_file])}}">
{{ form.class_labels_file.label }}
{{ form.class_labels_file.tooltip }}
{{ form.class_labels_file(class='form-control autocomplete_path', placeholder='.txt file') }}
</div>

<div class="form-group{{mark_errors([form.max_chars_per_sample])}}">
{{ form.max_chars_per_sample.label }}
{{ form.max_chars_per_sample.tooltip }}
{{ form.max_chars_per_sample(class='form-control') }}
</div>
Loading