Skip to content

Commit

Permalink
BRATS plug-in
Browse files Browse the repository at this point in the history
  • Loading branch information
gheinrich committed Oct 31, 2016
1 parent 609f31d commit 335efed
Show file tree
Hide file tree
Showing 10 changed files with 495 additions and 0 deletions.
1 change: 1 addition & 0 deletions plugins/data/brats/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include digitsDataPluginBrats *.html
2 changes: 2 additions & 0 deletions plugins/data/brats/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This DIGITS plug-in demonstrates how to load data from the BRATS dataset.
https://www.smir.ch/BRATS/Start2016
6 changes: 6 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .data import DataIngestion

__all__ = ['DataIngestion']
148 changes: 148 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os
import re

import numpy as np

from digits.utils import subclass, override, constants
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from digits.extensions.data.interface import DataIngestionInterface
from .forms import DatasetForm, InferenceForm
from . import utils


DATASET_TEMPLATE = "templates/dataset_template.html"
INFERENCE_TEMPLATE = "templates/inference_template.html"


@subclass
class DataIngestion(DataIngestionInterface):
"""
A data ingestion extension for the BRATS dataset
"""

def __init__(self, is_inference_db=False, **kwargs):
super(DataIngestion, self).__init__(**kwargs)

self.userdata['is_inference_db'] = is_inference_db

if 'files' not in self.userdata:
files = utils.find_files(self.dataset_folder,
self.group_id,
self.modality)
if not len(files):
raise ValueError("Failed to find data files in %s for "
"group %s and modality %s"
% (self.dataset_folder, self.group_id, self.modality))
self.userdata['files'] = files

# label palette (0->black (background), 1->white (foreground), others->black)
palette = [0, 0, 0, 255, 255, 255] + [0] * (254 * 3)
self.userdata[COLOR_PALETTE_ATTRIBUTE] = palette

self.userdata['class_labels'] = ['background', 'complete tumor']

@override
def encode_entry(self, entry):
if self.userdata['is_inference_db']:
# for inference, use image with maximum tumor area
filter_method = 'max'
else:
filter_method = self.userdata['filter_method']
print entry
feature, label = utils.encode_sample(entry, filter_method)

data = []

if feature.shape[0] > 0:
# split across axial dimension
features = np.split(feature, feature.shape[0])
labels = np.split(label, label.shape[0])

for image, label in zip(features, labels):
if self.userdata['channel_conversion'] == 'L':
feature = image
elif self.userdata['channel_conversion'] == 'RGB':
image = image[0]
feature = np.empty(shape=(3, image.shape[0], image.shape[1]),
dtype=image.dtype)
# just copy the same data over the three color channels
feature[:3] = [image, image, image]
data.append((feature, label))
return data

@staticmethod
@override
def get_category():
return "Images"

@staticmethod
@override
def get_id():
return "images-brats"

@staticmethod
@override
def get_dataset_form():
return DatasetForm()

@staticmethod
@override
def get_dataset_template(form):
"""
parameters:
- form: form returned by get_dataset_form(). This may be populated
with values if the job was cloned
return:
- (template, context) tuple
- template is a Jinja template to use for rendering dataset creation
options
- context is a dictionary of context variables to use for rendering
the form
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(os.path.join(extension_dir, DATASET_TEMPLATE), "r").read()
context = {'form': form}
return (template, context)

@override
def get_inference_form(self):
all_entries = self.userdata['files']
n_val_entries = int(len(all_entries)*self.userdata['pct_val']/100)
val_entries = self.userdata['files'][:n_val_entries]
form = InferenceForm()
for idx, entry in enumerate(val_entries):
match = re.match('.*pat(\d+)_.*', entry[0])
if match:
form.validation_record.choices.append((str(idx), 'Patient %s' % match.group(1)))
return form

@staticmethod
@override
def get_inference_template(form):
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(os.path.join(extension_dir, INFERENCE_TEMPLATE), "r").read()
context = {'form': form}
return (template, context)

@staticmethod
@override
def get_title():
return "Brain Tumor Segmentation"

@override
def itemize_entries(self, stage):
all_entries = self.userdata['files']
entries = []
if not self.userdata['is_inference_db']:
n_val_entries = int(len(all_entries)*self.pct_val/100)
if stage == constants.TRAIN_DB:
entries = all_entries[n_val_entries:]
elif stage == constants.VAL_DB:
entries = all_entries[:n_val_entries]
elif stage == constants.TEST_DB:
entries = [all_entries[int(self.validation_record)]]

return entries
127 changes: 127 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os

from digits import utils
from digits.utils import subclass
from flask.ext.wtf import Form
from wtforms import validators


@subclass
class DatasetForm(Form):
"""
A form used to create a Sunnybrook dataset
"""

def validate_folder_path(form, field):
if not field.data:
pass
else:
# make sure the filesystem path exists
if not os.path.exists(field.data) or not os.path.isdir(field.data):
raise validators.ValidationError(
'Folder does not exist or is not reachable')
else:
return True

dataset_folder = utils.forms.StringField(
u'Dataset folder',
validators=[
validators.DataRequired(),
validate_folder_path,
],
tooltip="Specify the path to a BRATS dataset."
)

group_id = utils.forms.SelectField(
'Group',
choices=[
('HGG', 'High-Grade Group'),
('LGG', 'Low-Grade Group'),
],
default='HGG',
tooltip="Select a group to train on."
)

modality = utils.forms.SelectField(
'Modality',
choices=[
('all', 'All'),
('Flair', 'FLAIR'),
('T1', 'T1'),
('T1c', 'T1c'),
('T2', 'T2'),
],
default='Flair',
tooltip="Select a modality to train on."
)

filter_method = utils.forms.SelectField(
'Filter',
choices=[
('all', 'All'),
('max', 'Max'),
('threshold', 'Threshold'),
],
default='all',
tooltip="Select a slice filter: 'All' retains all axial slices, "
"'Max' retains only the slice that exhibits max tumor area, "
"'Threshold' retains only slices that have more than "
"1000-pixel tumor area"
)

channel_conversion = utils.forms.SelectField(
'Channel conversion',
choices=[
('RGB', 'RGB'),
('L', 'Grayscale'),
],
default='L',
tooltip="Perform selected channel conversion."
)

pct_val = utils.forms.IntegerField(
u'% for validation',
default=10,
validators=[
validators.NumberRange(min=0, max=100)
],
tooltip="You can choose to set apart a certain percentage of images "
"from the training images for the validation set."
)


@subclass
class InferenceForm(Form):

def validate_file_path(form, field):
if not field.data:
pass
else:
# make sure the filesystem path exists
if not os.path.exists(field.data) and not os.path.isdir(field.data):
raise validators.ValidationError(
'File does not exist or is not reachable')
else:
return True
"""
A form used to perform inference on a text classification dataset
"""
test_image_file = utils.forms.StringField(
u'Image file',
validators=[
validate_file_path,
],
tooltip="Provide an image"
)

validation_record = utils.forms.SelectField(
'Record from validation set',
choices=[
('none', '- select record -'),
],
default='none',
tooltip="Test a record from the validation set."
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

<div class="form-group{{mark_errors([form.dataset_folder])}}">
{{ form.dataset_folder.label }}
{{ form.dataset_folder.tooltip }}
{{ form.dataset_folder(class='form-control autocomplete_path', placeholder='folder') }}
</div>

<div class="form-group{{mark_errors([form.group_id])}}">
{{ form.group_id.label }}
{{ form.group_id.tooltip }}
{{ form.group_id(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.modality])}}">
{{ form.modality.label }}
{{ form.modality.tooltip }}
{{ form.modality(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.channel_conversion])}}">
{{ form.channel_conversion.label }}
{{ form.channel_conversion.tooltip }}
{{ form.channel_conversion(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.filter_method])}}">
{{ form.filter_method.label }}
{{ form.filter_method.tooltip }}
{{ form.filter_method(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.pct_val])}}">
{{ form.pct_val.label }}
{{ form.pct_val.tooltip }}
{{ form.pct_val(class='form-control') }}
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

<div class="row">
<div class="col-sm-6">
<h3>Test a record from valiation set</h3>
<div class="form-group">
<div class="form-group{{mark_errors([form.validation_record])}}">
{{ form.validation_record.label }}
{{ form.validation_record.tooltip }}
{{ form.validation_record(class='form-control') }}
</div>
</div>
</div>
</div>
37 changes: 37 additions & 0 deletions plugins/data/brats/digitsDataPluginBrats/templates/template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

{{ form.data_stage(class='form-control') }}

<div class="form-group{{mark_errors([form.train_data_file])}}">
{{ form.train_data_file.label }}
{{ form.train_data_file.tooltip }}
{{ form.train_data_file(class='form-control autocomplete_path', placeholder='.csv file') }}
</div>

<div class="form-group{{mark_errors([form.val_data_file])}}">
{{ form.val_data_file.label }}
{{ form.val_data_file.tooltip }}
{{ form.val_data_file(class='form-control autocomplete_path', placeholder='.csv file') }}
</div>

<div class="form-group{{mark_errors([form.alphabet])}}">
{{ form.alphabet.label }}
{{ form.alphabet.tooltip }}
{{ form.alphabet(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.class_labels_file])}}">
{{ form.class_labels_file.label }}
{{ form.class_labels_file.tooltip }}
{{ form.class_labels_file(class='form-control autocomplete_path', placeholder='.txt file') }}
</div>

<div class="form-group{{mark_errors([form.max_chars_per_sample])}}">
{{ form.max_chars_per_sample.label }}
{{ form.max_chars_per_sample.tooltip }}
{{ form.max_chars_per_sample(class='form-control') }}
</div>
Loading

0 comments on commit 335efed

Please sign in to comment.