diff --git a/client/package-lock.json b/client/package-lock.json index 4315062..2b8e826 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -3905,6 +3905,11 @@ "integrity": "sha1-sgOOhG3DO6pXlhKNCAS0VbjB4h0=", "dev": true }, + "debounce": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/debounce/-/debounce-1.2.0.tgz", + "integrity": "sha512-mYtLl1xfZLi1m4RtQYlZgJUNQjl4ZxVnHzIR8nLLgi4q1YT8o/WM+MK/f8yfcc9s5Ir5zRaPZyZU6xs1Syoocg==" + }, "debug": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/debug/-/debug-3.1.0.tgz", @@ -11771,6 +11776,14 @@ } } }, + "vue-google-charts": { + "version": "0.3.3", + "resolved": "https://registry.npmjs.org/vue-google-charts/-/vue-google-charts-0.3.3.tgz", + "integrity": "sha512-dvp5C0m3uAfB7PbAWPOqnYGJLUKyK0NPmnjOE8FDUhVySeErojOPZVd5zNOq4m9RhhwmNFP31bsxVYNb+0+f9w==", + "requires": { + "debounce": "^1.1.0" + } + }, "vue-hot-reload-api": { "version": "2.3.4", "resolved": "https://registry.npmjs.org/vue-hot-reload-api/-/vue-hot-reload-api-2.3.4.tgz", @@ -11822,9 +11835,9 @@ "dev": true }, "vuetify": { - "version": "2.3.16", - "resolved": "https://registry.npmjs.org/vuetify/-/vuetify-2.3.16.tgz", - "integrity": "sha512-LHPqY+Gmyb/75xJscO0a3CuB4ZdpqHLNaGMAbmfTyapI8Q02+hjABEZzitFU/XObD2KhrNWPJzmGZPhbshGUzg==" + "version": "2.3.22", + "resolved": "https://registry.npmjs.org/vuetify/-/vuetify-2.3.22.tgz", + "integrity": "sha512-ZNJA7DCCFucY+Zg4387Q/U2/6YqRxvsue7Atp6iJIKJqgFUzyxrDO9Mod8vsgL/knWpPgWVTOUjbOBy1OaJHbA==" }, "vuetify-loader": { "version": "1.6.0", diff --git a/client/package.json b/client/package.json index a422364..ede65ab 100644 --- a/client/package.json +++ b/client/package.json @@ -16,7 +16,8 @@ "vue": "^2.6.10", "vue-async-computed": "^3.5.1", "vue-router": "3.4.3", - "vuetify": "^2.3.7", + "vuetify": "2.3.22", + "vue-google-charts": "0.3.3", "vuex": "^3.0.1" }, "devDependencies": { diff --git a/client/src/components/DataImportExport.vue b/client/src/components/DataImportExport.vue index 10b6f20..a7186de 100644 --- a/client/src/components/DataImportExport.vue +++ b/client/src/components/DataImportExport.vue @@ -4,12 +4,15 @@ import { mapActions } from "vuex"; export default { name: "DataImportExport", components: {}, - inject: ["girderRest"], + inject: ["girderRest", "notificationBus"], data: () => ({ importEnabled: false, exportEnabled: false, importing: false, - importDialog: false + importDialog: false, + reevaluateDialog: false, + reevaluating: false, + learningMode: "randomForest" }), async created() { var { data: result } = await this.girderRest.get( @@ -18,6 +21,18 @@ export default { this.importEnabled = result.import; this.exportEnabled = result.export; }, + mounted() { + this.notificationBus.$on( + "message:miqa.learning_with_data", + this.learningFinished + ); + }, + beforeDestroy() { + this.notificationBus.$off( + "message:miqa.learning_with_data", + this.learningFinished + ); + }, methods: { ...mapActions(["loadSessions"]), async importData() { @@ -25,9 +40,14 @@ export default { try { var { data: result } = await this.girderRest.post("miqa/data/import"); this.importing = false; + let msg = ""; + if (result.error) { + msg = `Import failed: ${result.error}`; + } else { + msg = `Import finished with ${result.success} scans imported and ${result.failed} failed.`; + } this.$snackbar({ - text: `Import finished. 
- With ${result.success} scans succeeded and ${result.failed} failed.`, + text: msg, timeout: 6000 }); this.loadSessions(); @@ -47,6 +67,20 @@ export default { text: "Saved data to json file successfully.", positiveButton: "Ok" }); + }, + reevaluate() { + this.reevaluating = true; + this.girderRest.post(`/learning/retrain_with_data/${this.learningMode}`); + }, + learningFinished(a) { + this.reevaluating = false; + this.reevaluateDialog = false; + this.loadSessions(); + this.$prompt({ + title: "Re-evaluate", + text: "Re-evaluate successfully", + positiveButton: "Ok" + }); } } }; @@ -64,6 +98,13 @@ export default { Export + Retrain @@ -84,6 +125,42 @@ export default { + + + + Re-evaluate + + + This will update the learning model with values of all current + sessions and reevaluate current unmarked sessions + + + + + + + + + + Cancel + Re-evaluate + + + diff --git a/client/src/components/MetricsDisplay.vue b/client/src/components/MetricsDisplay.vue new file mode 100644 index 0000000..9380f3b --- /dev/null +++ b/client/src/components/MetricsDisplay.vue @@ -0,0 +1,98 @@ + + + diff --git a/client/src/components/NavigationTabs.vue b/client/src/components/NavigationTabs.vue index 8feae6d..22f29f7 100644 --- a/client/src/components/NavigationTabs.vue +++ b/client/src/components/NavigationTabs.vue @@ -30,6 +30,10 @@ export default { view_column Sessions + + bar_chart + Metrics + settings Settings diff --git a/client/src/main.js b/client/src/main.js index 74b497e..a91d86f 100644 --- a/client/src/main.js +++ b/client/src/main.js @@ -8,6 +8,7 @@ import App from "./App.vue"; import router from "./router"; import store from "./store"; import Girder, { RestClient, utils } from "@girder/components/src"; +import NotificationBus from "@girder/components/src/utils/notifications"; import { API_URL, STATIC_PATH } from "./constants"; import vMousetrap from "./vue-utilities/v-mousetrap"; @@ -34,6 +35,8 @@ Vue.use(snackbarService(vuetify)); Vue.use(promptService(vuetify)); girder.rest = new RestClient({ apiRoot: API_URL }); +const notificationBus = new NotificationBus(girder.rest); +notificationBus.connect(); import config from "itk/itkConfig"; config.itkModulesPath = STATIC_PATH + config.itkModulesPath; @@ -49,7 +52,7 @@ girder.rest.fetchUser().then(() => { router, store, render: h => h(App), - provide: { girderRest: girder.rest } + provide: { girderRest: girder.rest, notificationBus } }) .$mount("#app") .$snackbarAttach() diff --git a/client/src/router.js b/client/src/router.js index 2399b56..f0ab590 100644 --- a/client/src/router.js +++ b/client/src/router.js @@ -5,6 +5,7 @@ import girder from "./girder"; import Settings from "./views/Settings.vue"; import Dataset from "./views/Dataset.vue"; import Login from "./views/Login.vue"; +import Metrics from "./views/Metrics.vue"; Vue.use(Router); @@ -41,6 +42,12 @@ export default new Router({ component: Settings, beforeEnter: beforeEnterAdmin }, + { + path: "/metrics", + name: "metrics", + component: Metrics, + beforeEnter: beforeEnterAdmin + }, // Order matters { path: "/:datasetId?", diff --git a/client/src/store/index.js b/client/src/store/index.js index 89c414b..793dc4d 100644 --- a/client/src/store/index.js +++ b/client/src/store/index.js @@ -58,6 +58,10 @@ const store = new Vuex.Store({ return state.datasets[datasetId]; }; }, + allDatasets(state) { + console.log('allDatasets'); + return Object.keys(state.datasets).map(dsId => state.datasets[dsId]); + }, currentSession(state, getters) { if (getters.currentDataset) { const curSessionId = 
getters.currentDataset.session; diff --git a/client/src/utils/iqmMeta.js b/client/src/utils/iqmMeta.js new file mode 100644 index 0000000..3c3671a --- /dev/null +++ b/client/src/utils/iqmMeta.js @@ -0,0 +1,114 @@ +export default { + cjv: { + display: "CJV", + fullname: "Coefficient of Joint Variation", + description: + "Higher values are related to the presence of heavy head motion and large INU artifacts. So, lower values are preferred." + }, + cnr: { + display: "CNR", + fullname: "Contrast-to-Noise Ratio", + description: + "An extension of the SNR calculation to evaluate how separated the tissue distributions of GM and WM are. Higher values indicate better quality." + }, + efc: { + display: "EFC", + fullname: "Entropy-Focus Criterion", + description: + "Uses the Shannon entropy of voxel intensities as an indication of ghosting and blurring induced by head motion. Lower values are better." + }, + fber: { + display: "FBER", + fullname: "Foreground-Background Energy Ratio", + description: + "Calculated as the mean energy of image values within the head relative the mean energy of image values in the air mask. Consequently, higher values are better." + }, + fwhm: { + display: "FWHM", + fullname: "Full-Width Half-Maximum", + description: + "An estimation of the blurriness of the image using AFNI’s 3dFWHMx. Smaller is better. It is calculated for x-axis, y-axis, z-axis, and average value." + }, + icvs: { + display: "ICV", + fullname: "Intra-Cranial Volume", + description: + "Estimation of the icv of each tissue calculated on the FSL FAST’s segmentation. Normative values fall around 20%, 45% and 35% for cerebrospinal fluid (CSF), White Matter (WM), and Grey Matter (GM), respectively. Thus, it provides 3 values." + }, + inu: { + display: "INU", + fullname: "Intensity Non-Uniformity", + description: + "MRIQC measures the location and spread of the bias field extracted estimated by the inu correction. The smaller spreads located around 1.0 are better. It provides the median and range for INU, thus 2 values." + }, + qi: { + display: "QI", + fullname: "Quality Index", + description: + "1: measures the amount of artifactual intensities in the air surrounding the head above the nasio-cerebellar axis.
2: a calculation of the goodness-of-fit of a chi-square distribution on the air mask. The smaller the better for both numbers." + }, + rpve: { + display: "rPVE", + fullname: "Residual Partial Volume Effect", + description: + "A tissue-wise sum of partial volumes that fall in the range [5%-95%] of the total volume of a pixel. Smaller residual partial volume effects (rPVEs) are better. It provides the score for cerebro-spinal fluid(CSF), white matter(WM), grey matter(GM)." + }, + size: { + display: "size", + fullname: "", + description: "" + }, + snr: { + display: "SNR", + fullname: "Signal-to-Noise Ratio", + description: + "SNR is reported using air background as noise reference for cerebro-spinal fluid(CSF), white matter(WM), grey matter(GM), and total. Also, a simplified calculation using the within tissue variance is also provided for CSF, WM, GM, and total image. Higher the SNR, the better the image quality." + }, + snrd: { + display: "SNRd", + fullname: "Signal-to-Noise Ratio-dietrich", + description: + "SNR is reported using air background as noise reference for cerebro-spinal fluid(CSF), white matter(WM), grey matter(GM), and total. Also, a simplified calculation using the within tissue variance is also provided for CSF, WM, GM, and total image using SNR-dietrich. Higher the SNR, the better the image quality." + }, + spacing: { + display: "spacing", + fullname: "", + description: "" + }, + summary_bg: { + display: "SSTATS background", + fullname: "Summary Statistics Background", + description: + "Several summary statistics (mean, standard deviation, percentiles 5% and 95%, and kurtosis) are computed within the following regions of interest: background, CSF, WM, and GM. There are 8 scores for each region, thus 32 values are reported." + }, + summary_csf: { + display: "SSTATS CSF", + fullname: "Summary Statistics CSF", + description: + "Several summary statistics (mean, standard deviation, percentiles 5% and 95%, and kurtosis) are computed within the following regions of interest: background, CSF, WM, and GM. There are 8 scores for each region, thus 32 values are reported." + }, + summary_gm: { + display: "SSTATS GM", + fullname: "Summary Statistics GM", + description: + "Several summary statistics (mean, standard deviation, percentiles 5% and 95%, and kurtosis) are computed within the following regions of interest: background, CSF, WM, and GM. There are 8 scores for each region, thus 32 values are reported." + }, + summary_wm: { + display: "SSTATS WM", + fullname: "Summary Statistics WM", + description: + "Several summary statistics (mean, standard deviation, percentiles 5% and 95%, and kurtosis) are computed within the following regions of interest: background, CSF, WM, and GM. There are 8 scores for each region, thus 32 values are reported." + }, + tpm_overlap: { + display: "TPMs", + fullname: "Tissue Probability Maps", + description: + "Overlap of tissue probability maps estimated from the image and the corresponding maps from the ICBM nonlinear-asymmetric 2009c template. TPM is calculated for CSF, WM, and GM." + }, + wm2max: { + display: "WM2MAX", + fullname: "White Matter to Maximum Intensity Ratio", + description: + "The white-matter to maximum intensity ratio is the median intensity within the WM mask over the 95% percentile of the full intensity distribution, that captures the existence of long tails due to hyper-intensity of the carotid vessels and fat. Values should be around the interval [0.6, 0.8]." 
+ } +}; diff --git a/client/src/views/Dataset.vue b/client/src/views/Dataset.vue index 7fe732d..74071a2 100644 --- a/client/src/views/Dataset.vue +++ b/client/src/views/Dataset.vue @@ -16,11 +16,12 @@ import WindowControl from "@/components/WindowControl"; import ScreenshotDialog from "@/components/ScreenshotDialog"; import EmailDialog from "@/components/EmailDialog"; import KeyboardShortcutDialog from "@/components/KeyboardShortcutDialog"; +import MetricsDisplay from "@/components/MetricsDisplay"; import NavigationTabs from "@/components/NavigationTabs"; import { cleanDatasetName } from "@/utils/helper"; export default { - name: "dataset", + name: "Dataset", components: { NavbarTitle, UserButton, @@ -31,7 +32,8 @@ export default { ScreenshotDialog, EmailDialog, KeyboardShortcutDialog, - NavigationTabs + NavigationTabs, + MetricsDisplay }, inject: ["girderRest", "userLevel"], data: () => ({ @@ -47,7 +49,8 @@ export default { showNotePopup: false, keyboardShortcutDialog: false, scanning: false, - direction: "forward" + direction: "forward", + initializeLoading: false }), computed: { ...mapState([ @@ -102,7 +105,9 @@ export default { this.debouncedDatasetSliderChange, 30 ); + this.initializeLoading = true; await Promise.all([this.loadSessions(), this.loadSites()]); + this.initializeLoading = false; var datasetId = this.$route.params.datasetId; var dataset = this.getDataset(datasetId); if (dataset) { @@ -305,6 +310,12 @@ export default { handleMouseUp() { this.scanning = false; window.cancelAnimationFrame(this.nextAnimRequest); + }, + getMetricColor(dataset) { + if (!dataset || !dataset.meta.goodProb) { + return null; + } + return dataset.meta.goodProb > 0.5 ? "green" : "red"; } } }; @@ -316,9 +327,29 @@ export default { - - keyboard - + + + + email + + keyboard + +import { mapGetters, mapActions } from "vuex"; +import { GChart } from "vue-google-charts"; + +import GenericNavigationBar from "@/components/GenericNavigationBar"; +import iqmMeta from "../utils/iqmMeta"; + +export default { + name: "Metrics", + components: { + GenericNavigationBar, + GChart + }, + data() { + return { + chartsReady: false + }; + }, + computed: { + iqmMeta: () => iqmMeta, + ...mapGetters(["allDatasets"]), + allDatasetWithIQM() { + if (!this.allDatasets) { + return []; + } + return this.allDatasets.filter( + dataset => dataset.meta && dataset.meta.iqm + ); + }, + iqmMetricsTables() { + if (!this.allDatasetWithIQM.length) { + return []; + } + return this.allDatasetWithIQM[0].meta.iqm.map((metric, index) => { + var key = Object.keys(metric)[0]; + return { + key, + ...this.calculateDataTableArrayAndOptions( + key, + this.allDatasetWithIQM.map( + dataset => Object.values(dataset.meta.iqm[index])[0] + ) + ) + }; + }); + } + }, + async created() { + if (!this.allDatasets || !this.allDatasets.length) { + this.loadSessions(); + } + }, + methods: { + ...mapActions(["loadSessions"]), + calculateDataTableArrayAndOptions(key, values) { + var headers = ["Category"]; + var rows = []; + var options = { + title: iqmMeta[key].fullname, + width: "100%", + height: "300", + hAxis: { + baseline: 0, + baselineColor: "transparent", + minValue: -1, + maxValue: undefined, + ticks: [], + gridlines: { color: "transparent" } + }, + legend: "none", + tooltip: { trigger: "selection" } + }; + if (!Array.isArray(values[0])) { + headers.push(key, { type: "string", role: "tooltip" }); + // options.width = '400'; + options.numberOfSeries = 1; + options.hAxis.maxValue = 4; + options.hAxis.ticks.push({ v: 0, f: key }); + values.forEach(value => { + 
rows.push([ + Math.floor(Math.random() * 12) * 0.05, + value, + `${key}: ${value}` + ]); + }); + } else { + let numberOfSubs = values[0].length; + options.numberOfSeries = numberOfSubs; + // options.width = 200 + 100 * numberOfSubs; + options.hAxis.maxValue = numberOfSubs * 2; + values[0].forEach((subValueObject, i) => { + var sub = Object.keys(subValueObject)[0]; + headers.push(sub, { type: "string", role: "tooltip" }); + options.hAxis.ticks.push({ v: i * 2, f: `${key}_${sub}` }); + }); + values.forEach(subValueObjects => { + subValueObjects.forEach((subValueObject, i) => { + var sub = Object.keys(subValueObject)[0]; + var value = subValueObject[sub]; + var row = [ + Math.floor(Math.random() * 12) * 0.05 + i * 2, + ...new Array(numberOfSubs * 2).fill(null) + ]; + row[1 + i * 2] = value; + row[2 + i * 2] = `${key}_${sub}: ${value}`; + rows.push(row); + }); + }); + } + return { + values: [headers, ...rows], + options + }; + }, + getFlexSize(table) { + var { numberOfSeries } = table.options; + var size = Math.round(numberOfSeries * 1.5 + 2); + size = size > 12 ? 12 : size; + return size; + }, + onChartReady(chart, table) { + this.chartsReady = true; + chart.setAction({ + id: "sample", + text: "See dataset", + action: () => { + var { row } = chart.getSelection()[0]; + var { numberOfSeries } = table.options; + var datasetIndex = Math.floor(row / numberOfSeries); + var dataset = this.allDatasetWithIQM[datasetIndex]; + this.$router.push(dataset._id); + } + }); + } + } +}; + + + diff --git a/client/src/vue-utilities/snackbar-service/Snackbar.vue b/client/src/vue-utilities/snackbar-service/Snackbar.vue index 552cf94..b136964 100644 --- a/client/src/vue-utilities/snackbar-service/Snackbar.vue +++ b/client/src/vue-utilities/snackbar-service/Snackbar.vue @@ -15,7 +15,7 @@ export default { text: "", button: "", callback: null, - timeout: 0, + timeout: -1, options: { left: true } }), methods: { diff --git a/development.md b/development.md index f83e3e9..64b5449 100644 --- a/development.md +++ b/development.md @@ -3,19 +3,106 @@ MIQA has server, client two components. They are located under *server* and *client* directory respectively. ### Prerequisite -* Pyhton 3.5+ +* Python 3.5+ * Mongodb running at the default port * Node 8 ### Server +#### Pre-requisite for active learning: + +Both `celery` and RabbitMQ are required for the active learning component of the application, which uses Girder Worker. + +RabbitMQ is available for Mac OS via homebrew, or on debian Linux systems via `apt-get`. There is also a Docker image available. See the Girder Worker [docs](https://girder-worker.readthedocs.io/en/latest/getting-started.html) on getting started for details. Normally, `celery` is available as a python module, but do not install it directly, as the latest version is not compatible with Girder Worker. Rather, install the `girder-worker` module, which installs appropriate versions of dependencids. The install command is shown in the steps below. 
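+
+If you would rather not install RabbitMQ natively, one possible way to run the official Docker image for local development is sketched below (the container name is only an example; just the default AMQP port 5672 needs to be published):
+
+```
+docker run -d --name miqa-rabbitmq -p 5672:5672 rabbitmq:3
+```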
+ #### Setup + * `git clone https://github.com/OpenImaging/miqa.git` * `pip install -e miqa/server/` +* `pip install -e miqa/miqa_worker_task/` +* `pip install -e miqa/mriqc/` +* `pip install girder-worker` * `girder build` * `girder serve` + Now a running girder instance should be available at `localhost:8080/girder` +#### Extra steps for active learning + +First start the rabbitmq server: + +``` +/usr/local/sbin/rabbitmq-server +``` + +Assuming the installation was successful, this should print output similar to the following: + +``` +$ /usr/local/sbin/rabbitmq-server +Configuring logger redirection + + ## ## RabbitMQ 3.8.9 + ## ## + ########## Copyright (c) 2007-2020 VMware, Inc. or its affiliates. + ###### ## + ########## Licensed under the MPL 2.0. Website: https://rabbitmq.com + + Doc guides: https://rabbitmq.com/documentation.html + Support: https://rabbitmq.com/contact.html + Tutorials: https://rabbitmq.com/getstarted.html + Monitoring: https://rabbitmq.com/monitoring.html + + Logs: /usr/local/var/log/rabbitmq/rabbit@localhost.log + /usr/local/var/log/rabbitmq/rabbit@localhost_upgrade.log + + Config file(s): (none) + + Starting broker... + + completed with 6 plugins. +``` + +Set the MRIQC master path in your environment and start `celery`: + +``` +export MIQA_MRIQC_PATH=/Users/scott/miqa/mriqc_master_folder +celery worker -A girder_worker.app -l info +``` + +If there were no errors, then this should print something like: + +``` +$ celery worker -A girder_worker.app -l info +/_REDACTED_/site-packages/celery/backends/amqp.py:67: CPendingDeprecationWarning: + The AMQP result backend is scheduled for deprecation in version 4.0 and removal in version v5.0. Please use RPC backend or a persistent backend. + + alternative='Please use RPC backend or a persistent backend.') + + -------------- celery@_REDACTED_.local v4.4.7 (cliffs) +--- ***** ----- +-- ******* ---- Darwin-19.6.0-x86_64-i386-64bit 2021-01-14 17:41:41 +- *** --- * --- +- ** ---------- [config] +- ** ---------- .> app: girder_worker:0x10d6c27d0 +- ** ---------- .> transport: amqp://guest:**@localhost:5672// +- ** ---------- .> results: amqp:// +- *** --- * --- .> concurrency: 8 (prefork) +-- ******* ---- .> task events: OFF (enable -E to monitor tasks in this worker) +--- ***** ----- + -------------- [queues] + .> celery exchange=celery(direct) key=celery + + +[tasks] + . girder_worker.docker.tasks.docker_run + . miqa_worker_task.tasks.retrain_with_data_task + +[2021-01-14 17:41:41,600: INFO/MainProcess] Connected to amqp://guest:**@127.0.0.1:5672// +[2021-01-14 17:41:41,615: INFO/MainProcess] mingle: searching for neighbors +[2021-01-14 17:41:42,648: INFO/MainProcess] mingle: all alone +[2021-01-14 17:41:42,665: INFO/MainProcess] celery@_REDACTED_.local ready. +``` + #### Setup Girder with Miqa Server * Navigate to `localhost:8080/girder` * Create a user @@ -52,3 +139,16 @@ importing arbitrary datasets, see `IMPORTING_DATA.md`. * Navigate back to localhost:8081 and navigate to `Settings` tab * Suppose the repo is located at home diretory. * Set `import path` to `~/miqa/sample_data/sample.json` and set `export path` to something like `~/miqa/sample_data/sample-output.json` + +#### Running MRIQC on a dataset using MIQA `mriqc` module + +The `mriqc` module in this project can be used to generate quality metrics on a set of images and update the csv file with a new column containing the appropriately formatted value string. 
+ + python data2mriqc.py \ + --csv_input_path /Users/scott/miqa/mriqc_master_folder/input/scans_to_review-2019-01-23.csv \ + --root "" \ + --bids_output_dir /Users/scott/miqa/mriqc_master_folder/input/bids_output \ + --mriqc_output_path /Users/scott/miqa/mriqc_master_folder/input/mriqc_output \ + --csv_output_path /Users/scott/miqa/mriqc_master_folder/mriqc_output.csv + +If the csv file containing the images you want to QC already contain absolute paths, then the `--root` cli arg should be the empty string as shown above. Otherwise, `--root` should be the path to the folder which is the root of the scan paths given in the csv. diff --git a/miqa_worker_task/.gitignore b/miqa_worker_task/.gitignore new file mode 100644 index 0000000..d81445f --- /dev/null +++ b/miqa_worker_task/.gitignore @@ -0,0 +1,2 @@ +*.egg-info +*.py[cod] diff --git a/miqa_worker_task/README.rst b/miqa_worker_task/README.rst new file mode 100644 index 0000000..261cb5e --- /dev/null +++ b/miqa_worker_task/README.rst @@ -0,0 +1 @@ +Worker task module for MIQA (with learning) diff --git a/miqa_worker_task/miqa_worker_task/__init__.py b/miqa_worker_task/miqa_worker_task/__init__.py new file mode 100644 index 0000000..0462393 --- /dev/null +++ b/miqa_worker_task/miqa_worker_task/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +"""Top-level package for miqa_worker_task.""" + +__author__ = """Kitware Inc""" +__email__ = 'kitware@kitware.com' +__version__ = '0.0.0' + + +from girder_worker import GirderWorkerPluginABC + + +class MiqaWorkerTask(GirderWorkerPluginABC): + def __init__(self, app, *args, **kwargs): + self.app = app + + def task_imports(self): + # Return a list of python importable paths to the + # plugin's path directory + return ['miqa_worker_task.tasks'] diff --git a/miqa_worker_task/miqa_worker_task/dummy_learn.py b/miqa_worker_task/miqa_worker_task/dummy_learn.py new file mode 100644 index 0000000..ce31b6a --- /dev/null +++ b/miqa_worker_task/miqa_worker_task/dummy_learn.py @@ -0,0 +1,27 @@ +import os +import shutil +import tempfile +import time + + +def retrain(csv_path=None): + # pseudo code + # if not csv_path: + # retrain model with the master csv + # return None + # else: + # csv = read(csv_path) + # if csv has rows with decision: + # update master csv and retrain model + # evalute all rows with model regard less if there is a decision and save the score to a dedicate column + # return a path to the new csv file with new scores + + # dummy code below + with open(csv_path, 'r') as csv_file: + print(csv_file.read()) + new_path = os.path.join( + tempfile.mkdtemp(), 'new_session.csv') + time.sleep(4) + shutil.copyfile(csv_path, new_path) + print(csv_path, new_path) + return new_path diff --git a/miqa_worker_task/miqa_worker_task/tasks.py b/miqa_worker_task/miqa_worker_task/tasks.py new file mode 100644 index 0000000..0b0a24d --- /dev/null +++ b/miqa_worker_task/miqa_worker_task/tasks.py @@ -0,0 +1,17 @@ +import os + +from girder import logger +from girder_worker.app import app +from girder_worker.utils import girder_job + + +@girder_job(title='Retrain with data') +@app.task(bind=True) +def retrain_with_data_task(self, csv_file_path, learningMode): + from mriqc.active_learner import train + + # the path where the master_folder exists + master_path = os.environ['MIQA_MRIQC_PATH'] + + # logger.info('Calling train({0}, {1})'.format(master_path, csv_file_path)) + return train(master_path, csv_file_path, learningMode) diff --git a/miqa_worker_task/miqa_worker_task/transform.py 
b/miqa_worker_task/miqa_worker_task/transform.py new file mode 100644 index 0000000..3a27da7 --- /dev/null +++ b/miqa_worker_task/miqa_worker_task/transform.py @@ -0,0 +1,30 @@ +import os +import shutil +import tempfile + +from girder import logger +from girder_worker_utils.transforms.girder_io import GirderClientTransform + + +class TextToFile(GirderClientTransform): + def __init__(self, stringIO, **kwargs): + super(TextToFile, self).__init__(**kwargs) + self.stringIO = stringIO + + def _repr_model_(self): + return "{}".format(self.__class__.__name__) + + def transform(self): + self.file_path = os.path.join( + tempfile.mkdtemp(), 'session.csv') + + with open(self.file_path, "w") as csv_file: + self.stringIO.seek(0) + shutil.copyfileobj(self.stringIO, csv_file) + + return self.file_path + + def cleanup(self): + logger.info('TextToFile will cleanup {0}'.format(self.file_path)) + shutil.rmtree(os.path.dirname(self.file_path), + ignore_errors=True) diff --git a/miqa_worker_task/setup.py b/miqa_worker_task/setup.py new file mode 100644 index 0000000..e7d747c --- /dev/null +++ b/miqa_worker_task/setup.py @@ -0,0 +1,45 @@ +from setuptools import setup, find_packages + +with open('README.rst', 'r') as fh: + long_desc = fh.read() + +setup(name='miqa_worker_task', + version='0.0.1', + description='Miqa active learning task', + long_description=long_desc, + author='Kitware Inc', + author_email='kitware@kitware.com', + license='MIT license', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'License :: OSI Approved :: MIT License', + 'Topic :: Scientific/Engineering :: Medical Science Apps.', + 'Intended Audience :: Healthcare Industry', + 'Intended Audience :: Science/Research', + 'Natural Language :: English', + 'Programming Language :: Python' + ], + install_requires=[ + 'girder_worker', + 'girder_worker_utils' + # TODO: Add additional packages required by both + # producer and consumer side installations + ], + extras_require={ + 'girder': [ + # TODO: Add dependencies here that are required for the + # package to work on the producer (Girder) side. + ], + 'worker': [ + # TODO: Add dependencies here that are required for the + # package to work on the consumer (Girder Worker) side. + ] + }, + include_package_data=True, + entry_points={ + 'girder_worker_plugins': [ + 'miqa_worker_task = miqa_worker_task:MiqaWorkerTask', + ] + }, + packages=find_packages(), + zip_safe=False) diff --git a/mriqc/master_folder.zip b/mriqc/master_folder.zip new file mode 100755 index 0000000..1f56db2 Binary files /dev/null and b/mriqc/master_folder.zip differ diff --git a/mriqc/mriqc/active_learner.py b/mriqc/mriqc/active_learner.py index 8f32d0a..c845452 100644 --- a/mriqc/mriqc/active_learner.py +++ b/mriqc/mriqc/active_learner.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jul 26 09:38:23 2019 @@ -6,62 +6,80 @@ @author: dhruv.sharma """ import os - -from .model import Model -from .data_loader import Data -from .strategy import uncertainty_sampling from glob import glob -def predict(master_path, path): +from girder import logger + +from mriqc.model_rf import ModelRF +from mriqc.model_nn import ModelNN +from mriqc.data_loader import Data +from mriqc.strategy import uncertainty_sampling + + +def debug(msg): + logger.info(msg) + print(msg) + + +def predict(master_path, path, learningMode="randomForest"): ''' This function forms the engine for the first part of MIQA and AL. 
This function helps in making the predictions for the new data that has been input. - + Args: master_path: the absolute path to the master folder with all important directories like training_data.csv, model_weights, log files - path: path to the csv file with the new input data + path: path to the csv file with the new input data (must contain 'IQMs' col w/ quality metrics) Returns: path: the new path to the file with the new input data ''' - # read the variables and data, load the model + # read the variables and data, load the model weights_dir = 'saved_model' model_path = os.path.join(master_path, weights_dir) - - model = Model() + new_data = Data(path) - + # if the model weights folder isn't there, make one if not os.path.isdir(model_path): os.mkdir(model_path) - - # if the model hasn't been trained yet, train it - if(len(glob(os.path.join(model_path, '*.pkl'))) == 0): - train(master_path) - + + if learningMode == "randomForest": + model = ModelRF() + # if the model hasn't been trained yet, train it + if (len(glob(os.path.join(model_path, '*.pkl'))) == 0): + train(master_path) + elif learningMode == "neuralNetwork": + model = ModelNN() + # if the model hasn't been trained yet, train it + if (len(glob(os.path.join(model_path, '*.pth'))) == 0): + train(master_path) + else: + raise Exception("Unknown learningMode: " + learningMode) + # load the most recently saved model model.load_model(model_path) - + # get the features of the new data X_new_sub_ids, X_new_features = new_data.get_features() - + # make predictions X_new_preds = model.predict_proba(X_new_features) - + # add these predictions to the csv new_data.set_predictions(X_new_sub_ids, X_new_preds) - + # save the predictions to the csv new_data.save() - + return path -def train(master_path, csv_path=None): + +def train(master_path, csv_path=None, learningMode="randomForest"): ''' This function ins the engine to train the model with the new data just labeled by the user of MIQA. The model can also be trained on the previously available labeled data. 
- + Args: master_path: the absolute path to the master folder with all important directories like training_data.csv, model_weights, log files @@ -70,55 +88,75 @@ def train(master_path, csv_path=None): None ''' # read the variables and data, load the model - + debug("master_path: " + master_path) + debug("csv_path: " + csv_path) + debug("learningMode: " + learningMode) + weights_dir = 'saved_model' model_path = os.path.join(master_path, weights_dir) - + training_data_path = os.path.join(master_path, 'training_data.csv') training_data = Data(training_data_path) - + if not os.path.isdir(model_path): os.mkdir(model_path) - + if csv_path is not None: training_data_path = os.path.join(master_path, 'training_data.csv') training_data = Data(training_data_path) # load the data new_data = Data(csv_path) - + # get the query points + # debug('Going to get possible query points') idx, pred, _ = new_data.get_possible_query_points() query_idx = uncertainty_sampling(idx, good_preds=pred, n_instances=2) - + # extract the query points from data + # debug('Going to extract query points') query_data = new_data.extract_query_points(query_idx) + # debug(query_data) + # add this new data to training_ data + # debug('Going to add new query data to existing data') training_data.add_new_data(query_data) - + # debug(training_data) + + # debug('Going to save updated training data') training_data.save() - + # load the (updated) training data and get the features and labels training_data_path = os.path.join(master_path, 'training_data.csv') + # debug('Going to read training data in again {0}'.format(training_data_path)) training_data = Data(training_data_path) - + + # debug('Going to get features and labels') _, X, y = training_data.get_feature_and_labels() - + # debug('Got features and labels') + + # debug('got features') + # debug(X) + # debug('got labels') + # debug(y) + + if learningMode == "randomForest": + model = ModelRF() + elif learningMode == "neuralNetwork": + model = ModelNN() + else: + raise Exception("Unknown learningMode: " + learningMode) + # upload the model and fit over the dataset - model = Model() model.fit(X, y) model.save_model(model_path) - + # update the predictions for the new data with the newly trained model if csv_path is not None: return predict(master_path, csv_path) - + + if __name__ == '__main__': - master_path = '/home/dhruv.sharma/Projects/MRIQC_AL/master_folder' - csv_path = '/home/dhruv.sharma/Projects/MRIQC_AL/mriqc_output.csv' - decision_path = '/home/dhruv.sharma/Projects/MRIQC_AL/mriqc_output_decision_dummy.csv' - train(master_path, decision_path) -# predict(master_path, csv_path) - - - - \ No newline at end of file + master_path = r'M:\Dev\zarr\sample data new-2020-07' + decision_path = r'M:\Dev\zarr\sample data new-2020-07\scans_to_review_output.csv' + + train(master_path, decision_path, "neuralNetwork") diff --git a/mriqc/mriqc/data2bids/__init__.py b/mriqc/mriqc/data2bids/__init__.py index daa2b02..4205644 100644 --- a/mriqc/mriqc/data2bids/__init__.py +++ b/mriqc/mriqc/data2bids/__init__.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jun 21 16:41:47 2019 diff --git a/mriqc/mriqc/data2bids/bidsify.py b/mriqc/mriqc/data2bids/bidsify.py index 849d0df..e363388 100644 --- a/mriqc/mriqc/data2bids/bidsify.py +++ b/mriqc/mriqc/data2bids/bidsify.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Wed Jun 12 15:21:02 2019 diff --git 
a/mriqc/mriqc/data2bids/generate_bids_json.py b/mriqc/mriqc/data2bids/generate_bids_json.py index be81733..04e5f75 100644 --- a/mriqc/mriqc/data2bids/generate_bids_json.py +++ b/mriqc/mriqc/data2bids/generate_bids_json.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Tue Jun 11 17:43:09 2019 diff --git a/mriqc/mriqc/data2bids/restructure_files.py b/mriqc/mriqc/data2bids/restructure_files.py index b53003e..dc2448e 100644 --- a/mriqc/mriqc/data2bids/restructure_files.py +++ b/mriqc/mriqc/data2bids/restructure_files.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Tue Jun 11 11:58:18 2019 @@ -64,7 +64,7 @@ def group_by_subject(list_of_scans): return subject_wise -def get_initial_dict(name = "Project", License = "ABC", funding = "XYZ", +def get_initial_dict(name = "Project", License = "ABC", funding = ["UVW", "XYZ"], refs_and_links = ["paper1", "paper2"], DatabaseDOI = "mm/dd/yyyy", authors = ["author1", "author2", "author3"], bids_version="1.0.2", ): @@ -77,7 +77,7 @@ def get_initial_dict(name = "Project", License = "ABC", funding = "XYZ", "BIDSVersion": "1.0.2" or the version you are using, "License": License under which your data is distributed, "Authors": ["Author1", "Author2", "Author3", "etc."], - "funding": Put your funding sources here, + "funding": ["Put your funding sources here", "as a list"], "refs_and_links": ["e.g. data-paper", "(methods-)paper", "etc."], "DatasetDOI": DOI of dataset (if there is one) @@ -214,4 +214,4 @@ def main(): if __name__ == '__main__': main() - \ No newline at end of file + diff --git a/mriqc/mriqc/data2mriqc.py b/mriqc/mriqc/data2mriqc.py index 20987c4..0d50443 100644 --- a/mriqc/mriqc/data2mriqc.py +++ b/mriqc/mriqc/data2mriqc.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jun 21 16:39:21 2019 @@ -31,7 +31,7 @@ def main(): ap.add_argument("-bo", "--bids_output_dir", required=True, help = "path to the BIDS output directory") ap.add_argument("-mo", "--mriqc_output_path", required=True, - help = "path to the MRIQC output directory") + help = "path to the MRIQC output directory (*** MUST END IN A '/' ***)") ap.add_argument("-co", "--csv_output_path", required=True, help = "path to save the processed csv file") @@ -44,6 +44,7 @@ def main(): csv_output_path = args["csv_output_path"]# "../../mriqc_output.csv" header, csv_content = get_csv_contents(csv_input_path) + subject_wise_data = group_by_subject(csv_content) # print(subject_wise_data['E08706']) @@ -67,19 +68,20 @@ def main(): ############ #MRIQC command # input -> bids_output_path, mriqc_output_path - command = 'docker run -ti --rm -v '+bids_output_path+':/bids_dataset:ro -v '+mriqc_output+':/output poldracklab/mriqc:latest /bids_dataset /output participant --participant_label' + command = 'docker run -ti --rm -v '+bids_output_path+':/bids_dataset:ro -v '+mriqc_output+':/output poldracklab/mriqc:latest /bids_dataset /output participant --no-sub' os.system(command) ############ iqms, iqm_values = get_iqms(mriqc_output) iqm_dict = iqms_to_dict(iqms, iqm_values) - + header, csv_data = get_csv_contents(csv_input_path) + csv_dict = input_csv_to_dict(csv_data) - + generate_csv(csv_output_path, iqm_dict, csv_dict, header) if __name__ == '__main__': main() - \ No newline at end of file + diff --git a/mriqc/mriqc/data_loader.py b/mriqc/mriqc/data_loader.py index 4dbc406..5c4c638 100644 --- a/mriqc/mriqc/data_loader.py +++ 
b/mriqc/mriqc/data_loader.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Jul 25 18:22:44 2019 @@ -8,26 +8,28 @@ import pandas as pd +from girder import logger + class Data(): ''' This class is to handle all the operations related to the data in the MIQA platform. It considers the CSV file as the main input and manipulates it only, assuming that the paths to the data is static. The state of the data changes at many occasions throughout its course of the MIQA. I have tried to handle all those cases independently - + This class loads the data and maintains it as a class variable which is a Pandas dataframe. Args: csv_path: The input it takes is the absolute path to the CSV file with the data. - + ''' def __init__(self, csv_path): self.path = csv_path self.data = pd.read_csv(self.path) - + def save(self, path=None): ''' This function overwrites the current csv file with the data it has till now. - + Args: path: path to save the file Returns: @@ -37,39 +39,45 @@ def save(self, path=None): self.data.to_csv(self.path, index = False) else: self.data.to_csv(path, index = False) - + def _extract_iqm(self, iqms): ''' A helper function to extract the list of IQMs from the string format - + Args: iqms: a string containing th IQM key and value pairs Returns: iqm_vals: a list of the IQM values in float type ''' iqm_vals = [] + # logger.info('raw iqms') + # logger.info(iqms) iqms = iqms.split(';')[:-1] + # logger.info('trimmed iqms') + # logger.info(iqms) for iqm in iqms: label, val = iqm.split(':') val = float(val) iqm_vals.append(val) + # logger.info('parse iqm vals') + # logger.info(iqm_vals) return iqm_vals - + def get_features(self): ''' This function is to access the IQM featueres for the available dataset - + Args: None Returns: sub_ids: The subjects for which the IQMs are available features: a list of IQMs for each of the subjects in sub_ids ''' - + row_count = self.data.shape[0] features = [] sub_ids = [] - + for i in range(row_count): sub_id = self.data["xnat_experiment_id"][i] scan_type = self.data["scan_type"][i] @@ -78,14 +86,14 @@ def get_features(self): iqms = self._extract_iqm(iqm) sub_ids.append((sub_id,scan_type)) features.append(iqms) - + return sub_ids, features - + def set_predictions(self, sub_ids, predictions): ''' This function is to set the predictions corresponding to the subjects for which we have the predictions and update the current dataframe - + Args: sub_ids: The a list of tuples (subject, scan type) for which the IQMs are available predictions: the list of probabilities of the images being good @@ -94,19 +102,19 @@ def set_predictions(self, sub_ids, predictions): ''' if 'good_prob' not in list(self.data.columns.values): self.data['good_prob'] = None - + row_count = self.data.shape[0] for i in range(row_count): sub_id = self.data['xnat_experiment_id'][i] scan_type = self.data['scan_type'][i] if((sub_id, scan_type) in sub_ids): ind = sub_ids.index((sub_id, scan_type)) - self.data['good_prob'][i] = predictions[ind] - + self.data.loc[i, 'good_prob'] = predictions[ind] + def get_feature_and_labels(self): ''' This function is to extract the features and the labels for the data points - + Args: None Returns: @@ -118,72 +126,79 @@ def get_feature_and_labels(self): features = [] sub_ids = [] labels = [] - + for i in range(row_count): sub_id = self.data["xnat_experiment_id"][i] scan_type = self.data["scan_type"][i] iqm = self.data["IQMs"][i] + # logger.info('{0} -> {1}'.format(sub_id, iqm)) if(type(iqm)!=type(0.0)): try: 
+ iqms = self._extract_iqm(iqm) + sub_ids.append((sub_id,scan_type)) + features.append(iqms) if(self.data['decision'][i] == 1 or self.data['decision'][i] == 2): - iqms = self._extract_iqm(iqm) - sub_ids.append((sub_id,scan_type)) - features.append(iqms) labels.append(1) elif(self.data['decision'][i] == -1): - iqms = self._extract_iqm(iqm) - sub_ids.append((sub_id,scan_type)) - features.append(iqms) labels.append(0) + else: + logger.info('decision case fall through: {0}'.format(self.data['decision'][i])) except: - print("subject", sub_id, "decision", self.data.decision[i]) - + logger.info("encountered exception processing subject {0}, decision {1}, IQMs {2}".format( + sub_id, self.data.decision[i], iqm)) + + if len(features) != len(labels): + logger.info('Well this is going to be a problem ({0} feature arrays but {1} labels)'.format(len(features), len(labels))) + return sub_ids, features, labels - + def get_possible_query_points(self): ''' This function extracts the indices of the data points which have predictions and returns a list of their indices and their predictions. - + Args: None Returns: indices: a list of the index of the data points which have predictions made. - + preds: the predictions for the image being good quality. - + features: the features of the input data, would be useful for batch-mode sampling. ''' indices = [] preds = [] features = [] row_count = self.data.shape[0] - + + # logger.info(self.data.shape) + # logger.info(self.data) + for i in range(row_count): if(self.data["good_prob"][i]!= None and type(self.data.IQMs[i])!=type(0.0)): indices.append(i) preds.append(self.data.good_prob[i]) features.append(self._extract_iqm(self.data.IQMs[i])) - + return indices, preds, features - + def extract_query_points(self, indices): ''' This function extracts the query points from the data and returns this subset dataframe with all its attributes - + Args: indices: the indices corresponding to the query points Returns: query_data: a subset of the original dataframe with all the data info ''' return self.data.loc[indices] - + def add_new_data(self, new_data): ''' This function is to add new data to the current data. basically for the main training data CSV. - + Args: new_data: a Dataframe containing the new data info Returns: @@ -193,6 +208,5 @@ def add_new_data(self, new_data): ind = ~ new_data.xnat_experiment_id.isin(self.data.xnat_experiment_id) & new_data.scan_type.isin(self.data.scan_type) new_data = new_data.drop(columns = ['good_prob']) self.data = self.data.append(new_data[ind]) - - - \ No newline at end of file + + diff --git a/mriqc/mriqc/model_nn.py b/mriqc/mriqc/model_nn.py new file mode 100644 index 0000000..20e8ef4 --- /dev/null +++ b/mriqc/mriqc/model_nn.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from girder import logger +from datetime import datetime +import os +import glob + +import random +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader + +from sklearn.preprocessing import StandardScaler +from sklearn.model_selection import train_test_split +from sklearn.metrics import confusion_matrix, classification_report + +torch.manual_seed(1983) # maybe remove this later? 
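+# Overview of this module:
+#  - trainData / testData are thin Dataset wrappers around the IQM feature
+#    tensors (plus labels, for training) so they can be fed to a DataLoader.
+#  - binaryClassification is a small fully connected network (58 input IQM
+#    features -> 256 -> 64 -> 1 logit) with batch norm, ReLU, and dropout.
+#  - ModelNN trains it with BCEWithLogitsLoss/Adam and applies a sigmoid at
+#    prediction time to produce the probability that a scan is good.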
+ + +class trainData(Dataset): + def __init__(self, X_data, y_data): + self.X_data = X_data + self.y_data = y_data + + def __getitem__(self, index): + return self.X_data[index], self.y_data[index] + + def __len__(self): + return len(self.X_data) + + +class testData(Dataset): + def __init__(self, X_data): + self.X_data = X_data + + def __getitem__(self, index): + return self.X_data[index] + + def __len__(self): + return len(self.X_data) + + +class binaryClassification(nn.Module): + def __init__(self): + super(binaryClassification, self).__init__() + # First number is fixed: it is the number of input features + self.layer_1 = nn.Linear(58, 256) + self.layer_2 = nn.Linear(256, 64) + self.layer_out = nn.Linear(64, 1) + + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=0.1) + self.batchnorm1 = nn.BatchNorm1d(256) + self.batchnorm2 = nn.BatchNorm1d(64) + + def forward(self, inputs): + x = self.relu(self.layer_1(inputs)) + x = self.batchnorm1(x) + x = self.relu(self.layer_2(x)) + x = self.batchnorm2(x) + x = self.dropout(x) + x = self.layer_out(x) + + return x + + +def binary_acc(y_pred, y_test): + y_pred_tag = torch.round(torch.sigmoid(y_pred)) + + correct_results_sum = (y_pred_tag == y_test).sum().float() + acc = correct_results_sum / y_test.shape[0] + acc = torch.round(acc * 100) + + return acc + + +def debug(msg): + logger.info(msg) + print(msg) + + +class ModelNN(): + ''' + This class defines the model that we'll be using for our image quality prediction task. + Here we use neural network via PyTorch library. + ''' + + def __init__(self): + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = binaryClassification() + self.model.to(self.device) + + def fit(self, X, y): + ''' + This function is to fit the data to the model. + + Args: + X: features of the data to be fit + y: the true labels + Returns: + None + ''' + # debug('features') + # debug(X) + # debug('labels') + # debug(y) + + train_data = trainData(torch.FloatTensor(X), torch.FloatTensor(y)) + train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True) + + # debug(self.model) + criterion = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(self.model.parameters(), lr=0.001) + + self.model.train() + for e in range(0, 10): + epoch_loss = 0 + epoch_acc = 0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device) + optimizer.zero_grad() + + y_pred = self.model(X_batch) + + loss = criterion(y_pred, y_batch.unsqueeze(1)) + acc = binary_acc(y_pred, y_batch.unsqueeze(1)) + + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + epoch_acc += acc.item() + + print( + f'Epoch {e + 0:03}: | Loss: {epoch_loss / len(train_loader):.5f} | Acc: {epoch_acc / len(train_loader):.3f}') + + # save the current model + debug("Finished training") + debug(self.model) + + def predict_proba(self, X): + ''' + This function is to predict the probabilities of the input the data. 
+ + Args: + X: data to be predicted + Returns: + prob_good: array for the probabilities of data point being good + ''' + test_data = testData(torch.FloatTensor(X)) + test_loader = DataLoader(dataset=test_data, batch_size=1) + + y_pred_list = [] + self.model.eval() + with torch.no_grad(): + for X_batch in test_loader: + X_batch = X_batch.to(self.device) + y_test_pred = self.model(X_batch) + y_test_pred = torch.sigmoid(y_test_pred) + y_pred_list.append(y_test_pred.cpu().numpy()) + + y_pred_list = [a.squeeze().tolist() for a in y_pred_list] + return y_pred_list + + def save_model(self, model_path): + ''' + This function saves the trained model in the given path + + Args: + model_path: The path where the model needs to be saved + Returns: + None + ''' + if not os.path.isdir(model_path): + os.mkdir(model_path) + + file_name = os.path.join(model_path, 'nnc_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.pth') + torch.save(self.model.state_dict(), file_name) + + def load_model(self, model_path): + ''' + This function loads the most recently saved trained model in the given path + + Args: + model_path: The path where the model needs to be saved + Returns: + None + ''' + saved_models = glob.glob(os.path.join(model_path, '*.pth')) + chckpt = max(saved_models, key=os.path.getctime) + self.model.load_state_dict(torch.load(chckpt)) diff --git a/mriqc/mriqc/model.py b/mriqc/mriqc/model_rf.py similarity index 87% rename from mriqc/mriqc/model.py rename to mriqc/mriqc/model_rf.py index 613cfb6..a7251f8 100644 --- a/mriqc/mriqc/model.py +++ b/mriqc/mriqc/model_rf.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Jul 25 17:37:09 2019 @@ -7,42 +7,49 @@ """ from sklearn.ensemble import RandomForestClassifier -from sklearn.externals import joblib +import joblib +from girder import logger from datetime import datetime import os import glob -class Model(): +class ModelRF(): ''' This class defines the model that we'll be using for our image quality prediction task. - To start with, we are using Random Forest classifier. But, this is flexible and can be + To start with, we are using Random Forest classifier. But, this is flexible and can be changed for future purposes by defining a new model and making suitable changes to the functions to make them compatible to the new model. - + Args: n_estimators: The number of classifiers in the RandomForest ''' - + def __init__(self, n_estimators=30): self.model = RandomForestClassifier(n_estimators=n_estimators, max_features=5) - + def fit(self, X, y): ''' This function is to fit the data to the model. - + Args: X: features of the data to be fit y: the true labels Returns: None ''' + # logger.info('Inside model.fit') + # logger.info('features') + # logger.info(X) + # logger.info('labels') + # logger.info(y) + self.model.fit(X, y) - + def predict_proba(self, X): ''' This function is to predict the probabilities of the input the data. 
- + Args: X: data to be predicted Returns: @@ -50,11 +57,11 @@ def predict_proba(self, X): ''' predictions = self.model.predict_proba(X) return predictions[:,1] - + def save_model(self, model_path): ''' This function saves the trained model in the given path - + Args: model_path: The path where the model needs to be saved Returns: @@ -62,14 +69,14 @@ def save_model(self, model_path): ''' if not os.path.isdir(model_path): os.mkdir(model_path) - - file_name = os.path.join(model_path, 'rfc_'+datetime.now().isoformat()+'.pkl') + + file_name = os.path.join(model_path, 'rfc_'+datetime.now().strftime("%Y%m%d-%H%M%S")+'.pkl') joblib.dump(self.model, file_name) - + def load_model(self, model_path): ''' This function loads the most recently saved trained model in the given path - + Args: model_path: The path where the model needs to be saved Returns: @@ -79,5 +86,4 @@ def load_model(self, model_path): chckpt = max(saved_models, key=os.path.getctime) model = joblib.load(chckpt) self.model = model - - \ No newline at end of file + diff --git a/mriqc/mriqc/process_iqms.py b/mriqc/mriqc/process_iqms.py index fe9aab8..aaaee7f 100644 --- a/mriqc/mriqc/process_iqms.py +++ b/mriqc/mriqc/process_iqms.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jun 21 10:13:52 2019 @@ -42,7 +42,6 @@ def process_json_file(json_path): 'spacing_' not in k and 'size_' not in k): iqm_values[k] = data[k] iqms.append(k) - return iqms, iqm_values @@ -238,4 +237,4 @@ def main(): if __name__ == '__main__': main() - \ No newline at end of file + diff --git a/mriqc/mriqc/strategy.py b/mriqc/mriqc/strategy.py index 337f0a6..b7b3308 100644 --- a/mriqc/mriqc/strategy.py +++ b/mriqc/mriqc/strategy.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jul 26 09:08:03 2019 diff --git a/mriqc/requirements.txt b/mriqc/requirements.txt index 8ce37c6..7db728c 100644 --- a/mriqc/requirements.txt +++ b/mriqc/requirements.txt @@ -2,4 +2,4 @@ pandas numpy scipy scikit-learn - +torch diff --git a/mriqc/setup.py b/mriqc/setup.py index b553fb3..2998419 100644 --- a/mriqc/setup.py +++ b/mriqc/setup.py @@ -8,6 +8,7 @@ 'numpy', 'scipy', 'scikit-learn', + 'torch', ], url='', author='Kitware', diff --git a/server/miqa_server/__init__.py b/server/miqa_server/__init__.py index 2612080..f61beb0 100644 --- a/server/miqa_server/__init__.py +++ b/server/miqa_server/__init__.py @@ -1,5 +1,5 @@ import datetime -from girder import events, plugin +from girder import events, logger, plugin from girder.models.user import User from girder.utility import server @@ -7,6 +7,7 @@ from .session import Session from .email import Email from .setting import SettingResource +from .learning import Learning class GirderPlugin(plugin.GirderPlugin): @@ -21,5 +22,19 @@ def load(self, info): info['apiRoot'].miqa = Session() info['apiRoot'].miqa_email = Email() info['apiRoot'].miqa_setting = SettingResource() + info['apiRoot'].learning = Learning() -server.getStaticRoot = lambda: 'static' + events.bind('jobs.job.update.after', 'active_learning', afterJobUpdate) + + +def afterJobUpdate(event): + # learned from https://github.com/girder/large_image/blob/girder-3/girder/girder_large_image/__init__.py#L83 + # logger.info('afterJobUpdate event triggered') + job = event.info['job'] + meta = job.get('meta', {}) + if (meta.get('creator') != 'miqa' or not meta.get('itemId') or + meta.get('task') != 'learning_with_data'): + # logger.info('Ignoring unknown event: {0}'.format()) 
+ return + # logger.info('Calling Learning.afterJobUpdate()') + Learning.afterJobUpdate(job) diff --git a/server/miqa_server/conversion/csv_to_json.py b/server/miqa_server/conversion/csv_to_json.py index 8a1d72b..b159322 100644 --- a/server/miqa_server/conversion/csv_to_json.py +++ b/server/miqa_server/conversion/csv_to_json.py @@ -54,7 +54,7 @@ def csvContentToJsonObject(csvContent): nifti_folder = scan['nifti_folder'] subdir = nifti_folder.split(common_path_prefix)[1] if 'site' in scan: - site = site['scan'] + site = scan['site'] elif nifti_folder.startswith('/fs/storage/XNAT/archive/'): # Special case handling to match previous implementation splits = nifti_folder.split('/') @@ -73,6 +73,10 @@ def csvContentToJsonObject(csvContent): } if 'decision' in scan: scan_obj['decision'] = scan['decision'] + if 'IQMs' in scan: + scan_obj['iqms'] = scan['IQMs'] + if 'good_prob' in scan: + scan_obj['good_prob'] = scan['good_prob'] scans.append(scan_obj) # Build list of unique experiments diff --git a/server/miqa_server/conversion/json_to_csv.py b/server/miqa_server/conversion/json_to_csv.py new file mode 100644 index 0000000..4adf810 --- /dev/null +++ b/server/miqa_server/conversion/json_to_csv.py @@ -0,0 +1,151 @@ +import argparse +import csv +import io +import json +import os + + +def jsonObjectToCsvContent(jsonObject): + experiments = {} + for exper in jsonObject['experiments']: + experiments[exper['id']] = exper['note'] + + scans = jsonObject['scans'] + dataRoot = jsonObject['data_root'] + + rowList = [] + + optionalFields = { + 'decision': True, + 'IQMs': True, + 'scan_note': True, + 'good_prob': True + } + + for scan in scans: + expId = scan['experiment_id'] + + scanPath = scan['path'] + pathComps = scanPath.split(os.path.sep) + firstComps = pathComps[:-1] + lastComp = pathComps[-1] + + try: + splitIdx = lastComp.index('_') + except ValueError as valErr: + print('ERROR: the following scan cannot be converted to CSV row') + print(scan) + continue + + scanId = scan['id'] + parsedScanId = lastComp[:splitIdx] + if scanId != parsedScanId: + print('ERROR: expected scan id {0}, but found {1} instead'.format( + scanId, parsedScanId)) + print(scan) + continue + + scanType = scan['type'] + parsedScanType = lastComp[splitIdx + 1:] + if scanType != parsedScanType: + print('ERROR: expected scan type {0}, but found {1} instead'.format( + scanType, parsedScanType)) + print(scan) + continue + + niftiFolder = os.path.join(dataRoot, *firstComps) + + nextRow = { + 'xnat_experiment_id': expId, + 'nifti_folder': niftiFolder, + 'scan_id': scanId, + 'scan_type': scanType, + 'experiment_note': experiments[expId] + } + + if 'decision' in scan: + nextRow['decision'] = scan['decision'] + else: + optionalFields['decision'] = False + + if 'note' in scan: + nextRow['scan_note'] = scan['note'] + else: + optionalFields['note'] = False + + if 'iqms' in scan: + nextRow['IQMs'] = scan['iqms'] + else: + optionalFields['IQMs'] = False + + if 'good_prob' in scan: + nextRow['good_prob'] = scan['good_prob'] + else: + optionalFields['good_prob'] = False + + rowList.append(nextRow) + + fieldNames = [ + 'xnat_experiment_id', + 'nifti_folder', + 'scan_id', + 'scan_type', + 'experiment_note' + ] + + fieldNames.extend([key for (key, val) in optionalFields.items() if val]) + + # Now we have the dictionary representing the data and the field names, we + # can write them to the stringio object + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=fieldNames, dialect='unix') + writer.writeheader() + for row in rowList: + 
+        writer.writerow(row)
+
+    return output
+
+
+def jsonToCsv(jsonFilePath, csvFilePath):
+    print('Reading input json from {0}'.format(jsonFilePath))
+
+    with open(jsonFilePath) as jsonFile:
+        jsonObject = json.load(jsonFile)
+
+    csvContent = jsonObjectToCsvContent(jsonObject)
+
+    print('Writing output csv to {0}'.format(csvFilePath))
+
+    with open(csvFilePath, 'w') as fd:
+        fd.write(csvContent.getvalue())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Convert JSON file to original CSV format')
+    parser.add_argument('inputfile',
+                        type=str,
+                        help='Absolute system path to JSON file to be converted')
+    parser.add_argument('--output',
+                        type=str,
+                        default=None,
+                        help='Absolute system path where CSV output should be ' +
+                             'written. If not provided, the default is to write ' +
+                             'the output to the directory where the input file ' +
+                             'is located and replace the "json" extension with "csv".')
+    args = parser.parse_args()
+
+    inputFilePath = args.inputfile
+
+    if not os.path.isfile(inputFilePath):
+        print('Please provide an absolute path to the input JSON file; {0} is not valid'.format(inputFilePath))
+        sys.exit(1)
+
+    outputFilePath = args.output
+
+    if not outputFilePath:
+        filePath = os.path.dirname(inputFilePath)
+        fileName = os.path.basename(inputFilePath)
+        nameWithoutExtension = fileName[:fileName.rindex('.')]
+        outputFilePath = os.path.join(filePath, '{0}.csv'.format(nameWithoutExtension))
+
+    jsonToCsv(inputFilePath, outputFilePath)
diff --git a/server/miqa_server/learning.py b/server/miqa_server/learning.py
new file mode 100644
index 0000000..0b11580
--- /dev/null
+++ b/server/miqa_server/learning.py
@@ -0,0 +1,80 @@
+import datetime
+
+from girder.api.rest import Resource
+from girder.api import access
+from girder.api.describe import Description, autoDescribeRoute
+from girder import logger
+from girder.models.collection import Collection
+from girder.models.folder import Folder
+from girder.models.item import Item
+from girder.models.notification import Notification
+from girder.models.file import File
+from girder.models.user import User
+from girder_jobs import Job
+from girder_jobs.constants import JobStatus
+from girder.exceptions import RestException, ValidationException
+from miqa_worker_task.tasks import retrain_with_data_task
+from miqa_worker_task.transform import TextToFile
+from girder_worker_utils.transforms.girder_io import GirderUploadToItem, GirderClientTransform, GirderFileId
+
+from .util import findSessionsFolder, findTempFolder, getExportCSV, importCSV
+
+
+class Learning(Resource):
+    def __init__(self):
+        super(Learning, self).__init__()
+        self.resourceName = 'learning'
+
+        self.route('POST', ('retrain_with_data', ':learningMode'), self.retrainWithData)
+
+    @access.admin
+    @autoDescribeRoute(
+        Description('Retrain the learning model with the current session data and re-evaluate sessions.')
+        .param('learningMode', 'The learning mode (either "neuralNet" or "randomForest")', required=True)
+        .errorResponse())
+    def retrainWithData(self, learningMode, params):
+        # learned from https://github.com/girder/large_image/blob/girder-3/girder/girder_large_image/models/image_item.py#L108
+        user = self.getCurrentUser()
+        item = Item().createItem('temp', user, findTempFolder(user, True))
+        logger.info('learningMode = {0}'.format(learningMode))
+        result = retrain_with_data_task.delay(
+            TextToFile(getExportCSV()),
+            learningMode,
+            girder_job_other_fields={'meta': {
+                'creator': 'miqa',
+                'itemId': str(item['_id']),
+                'task': 'learning_with_data',
+            }},
+            girder_result_hooks=[
+                GirderUploadToItem(str(item['_id']),
delete_file=False), + ] + ) + return result.job + + @staticmethod + def afterJobUpdate(job): + # logger.info('^^^^^^^^^^ inside afterJobUpdate ^^^^^^^^^^') + if job['status'] != JobStatus.SUCCESS: + # logger.info('Sadly job status is {0} instead of SUCCESS'.format(job['status'])) + # logger.info(job) + return + meta = job.get('meta', {}) + item = Item().load(meta['itemId'], force=True) + folder = Folder().load(item['folderId'], force=True) + user = User().load(job.get('userId'), force=True) + file = Item().childFiles(item)[0] + csv_content = '' + for chunk in File().download(file)(): + csv_content += chunk.decode("utf-8") + # logger.info('Here is the json_content') + # logger.info(csv_content) + # For now just reimport instead of update, should update record instead + result = importCSV(csv_content, user) + Folder().remove(folder) + # Throws an error for some reason, not really needed anyway + # Job().updateJob(job, progressMessage="session re-evaluated") + Notification().createNotification( + type='miqa.learning_with_data', + data=result, + user={'_id': job.get('userId')}, + expires=datetime.datetime.utcnow() + datetime.timedelta(seconds=30)) diff --git a/server/miqa_server/schema/data_import.py b/server/miqa_server/schema/data_import.py index 5ce46f5..a272e4e 100644 --- a/server/miqa_server/schema/data_import.py +++ b/server/miqa_server/schema/data_import.py @@ -48,6 +48,13 @@ 'items': {'type': 'string'}, }, 'imagePattern': {'type': 'string'}, + 'iqms': {'type': 'string'}, + 'good_prob': { + 'oneOf': [ + {'type': 'number'}, + {'type': 'string'}, + ], + }, }, 'oneOf': [ { 'required': ['experiment_id', 'path', 'site_id', 'id', 'type', 'images'], }, diff --git a/server/miqa_server/session.py b/server/miqa_server/session.py index 43646ac..4e11508 100644 --- a/server/miqa_server/session.py +++ b/server/miqa_server/session.py @@ -1,8 +1,6 @@ import datetime import io import json -from jsonschema import validate -from jsonschema.exceptions import ValidationError as JSONValidationError import os from girder.api.rest import Resource, setResponseHeader, setContentDisposition @@ -10,18 +8,14 @@ from girder.constants import AccessType from girder.exceptions import RestException from girder.api.describe import Description, autoDescribeRoute -from girder import logger -from girder.models.collection import Collection -from girder.models.assetstore import Assetstore from girder.models.folder import Folder from girder.models.item import Item from girder.models.setting import Setting from girder.utility.progress import noProgress -from .conversion.csv_to_json import csvContentToJsonObject -from .setting import fileWritable, tryAddSites +from .setting import fileWritable from .constants import exportpathKey, importpathKey -from .schema.data_import import schema +from .util import findSessionsFolder, getExportJSON, importData class Session(Resource): @@ -42,7 +36,7 @@ def getSessions(self, params): def _getSessions(self): user = self.getCurrentUser() - sessionsFolder = self.findSessionsFolder() + sessionsFolder = findSessionsFolder() if not sessionsFolder: return [] experiments = [] @@ -99,127 +93,12 @@ def _getSessions(self): Description('') .errorResponse()) def dataImport(self, params): - user = self.getCurrentUser() importpath = os.path.expanduser(Setting().get(importpathKey)) if not os.path.isfile(importpath): raise RestException('import path does not exist ({0}'.format(importpath), code=404) - json_content = None - - if importpath.endswith('.csv'): - with open(importpath) as fd: - csv_content = 
fd.read() - try: - json_content = csvContentToJsonObject(csv_content) - validate(json_content, schema) - except (JSONValidationError, Exception) as inst: - return { - "error": 'Invalid CSV file: {0}'.format(inst.message), - "success": successCount, - "failed": failedCount - } - else: - with open(importpath) as json_file: - json_content = json.load(json_file) - try: - validate(json_content, schema) - except JSONValidationError as inst: - return { - "error": 'Invalid JSON file: {0}'.format(inst.message), - "success": successCount, - "failed": failedCount - } - - existingSessionsFolder = self.findSessionsFolder(user) - if existingSessionsFolder: - existingSessionsFolder['name'] = 'sessions_' + \ - datetime.datetime.now().strftime("%Y-%m-%d %I:%M:%S %p") - Folder().save(existingSessionsFolder) - sessionsFolder = self.findSessionsFolder(user, True) - Item().createItem('json', user, sessionsFolder, description=json.dumps(json_content)) + return importData(importpath, self.getCurrentUser()) - datasetRoot = json_content['data_root'] - experiments = json_content['experiments'] - sites = json_content['sites'] - - successCount = 0 - failedCount = 0 - sites = set() - for scan in json_content['scans']: - experimentId = scan['experiment_id'] - experimentNote = '' - for experiment in experiments: - if experiment['id'] == experimentId: - experimentNote = experiment['note'] - scanPath = scan['path'] - site = scan['site_id'] - sites.add(site) - scanId = scan['id'] - scanType = scan['type'] - scanName = scanId+'_'+scanType - niftiFolder = os.path.expanduser(os.path.join(datasetRoot, scanPath)) - if not os.path.isdir(niftiFolder): - failedCount += 1 - continue - experimentFolder = Folder().createFolder( - sessionsFolder, experimentId, parentType='folder', reuseExisting=True) - scanFolder = Folder().createFolder( - experimentFolder, scanName, parentType='folder', reuseExisting=True) - meta = { - 'experimentId': experimentId, - 'experimentNote': experimentNote, - 'site': site, - 'scanId': scanId, - 'scanType': scanType - } - # Merge note and rating if record exists - if existingSessionsFolder: - existingMeta = self.tryGetExistingSessionMeta( - existingSessionsFolder, experimentId, scanName) - if(existingMeta and (existingMeta.get('note', None) or existingMeta.get('rating', None))): - meta['note'] = existingMeta.get('note', None) - meta['rating'] = existingMeta.get('rating', None) - Folder().setMetadata(scanFolder, meta) - currentAssetstore = Assetstore().getCurrent() - if 'images' in scan: - scanImages = scan['images'] - # Import images one at a time because the user provided a list - for scanImage in scanImages: - absImagePath = os.path.join(niftiFolder, scanImage) - Assetstore().importData( - currentAssetstore, parent=scanFolder, parentType='folder', params={ - 'fileIncludeRegex': '^{0}$'.format(scanImage), - 'importPath': niftiFolder, - }, progress=noProgress, user=user, leafFoldersAsItems=False) - imageOrderDescription = { - 'orderDescription': { - 'images': scanImages - } - } - else: - scanImagePattern = scan['imagePattern'] - # Import all images in directory at once because user provide a file pattern - Assetstore().importData( - currentAssetstore, parent=scanFolder, parentType='folder', params={ - 'fileIncludeRegex': scanImagePattern, - 'importPath': niftiFolder, - }, progress=noProgress, user=user, leafFoldersAsItems=False) - imageOrderDescription = { - 'orderDescription': { - 'imagePattern': scanImagePattern - } - } - Item().createItem(name='imageOrderDescription', - creator=user, - folder=scanFolder, 
- reuseExisting=True, - description=json.dumps(imageOrderDescription)) - successCount += 1 - tryAddSites(sites, self.getCurrentUser()) - return { - "success": successCount, - "failed": failedCount - } @access.user @autoDescribeRoute( @@ -229,63 +108,6 @@ def dataExport(self, params): exportpath = os.path.expanduser(Setting().get(exportpathKey)) if not fileWritable(exportpath): raise RestException('export json file is not writable', code=500) - output = self.getExportJSON() + output = getExportJSON() with open(exportpath, 'w') as json_file: json_file.write(output) - - def getExportJSON(self): - def convertRatingToDecision(rating): - return { - None: 0, - 'questionable': 0, - 'good': 1, - 'usableExtra': 2, - 'bad': -1 - }[rating] - sessionsFolder = self.findSessionsFolder() - items = list(Folder().childItems(sessionsFolder, filters={'name': 'json'})) - if not len(items): - raise RestException('doesn\'t contain a json item', code=404) - jsonItem = items[0] - # Next TODO: read, format, and stream back the json version of the export - logger.info(jsonItem) - original_json_object = json.loads(jsonItem['description']) - - for scan in original_json_object['scans']: - experiment = Folder().findOne({ - 'name': scan['experiment_id'], - 'parentId': sessionsFolder['_id'] - }) - if not experiment: - continue - session = Folder().findOne({ - 'name': '{0}_{1}'.format(scan['id'], scan['type']), - 'parentId': experiment['_id'] - }) - if not session: - continue - scan['decision'] = convertRatingToDecision(session.get('meta', {}).get('rating', None)) - scan['note'] = session.get('meta', {}).get('note', None) - - return json.dumps(original_json_object) - - def findSessionsFolder(self, user=None, create=False): - collection = Collection().findOne({'name': 'miqa'}) - sessionsFolder = Folder().findOne({'name': 'sessions', 'baseParentId': collection['_id']}) - if not create: - return sessionsFolder - else: - if not sessionsFolder: - return Folder().createFolder(collection, 'sessions', - parentType='collection', creator=user) - - def tryGetExistingSessionMeta(self, sessionsFolder, experimentId, scan): - experimentFolder = Folder().findOne( - {'name': experimentId, 'parentId': sessionsFolder['_id']}) - if not experimentFolder: - return None - sessionFolder = Folder().findOne( - {'name': scan, 'parentId': experimentFolder['_id']}) - if not sessionFolder: - return None - return sessionFolder.get('meta', {}) diff --git a/server/miqa_server/util.py b/server/miqa_server/util.py new file mode 100644 index 0000000..b45d709 --- /dev/null +++ b/server/miqa_server/util.py @@ -0,0 +1,259 @@ +import datetime +import json +from jsonschema import validate +from jsonschema.exceptions import ValidationError as JSONValidationError +import os + +from girder.exceptions import RestException +from girder import logger +from girder.models.assetstore import Assetstore +from girder.models.collection import Collection +from girder.models.folder import Folder +from girder.models.item import Item +from girder.utility.progress import noProgress + +from .conversion.csv_to_json import csvContentToJsonObject +from .conversion.json_to_csv import jsonObjectToCsvContent +from .setting import tryAddSites +from .schema.data_import import schema + + +def parseIQM(iqm): + if not iqm: + return None + rows = iqm.split(';') + metrics = [] + for row in rows: + if not row: + continue + [key, value] = row.split(':') + value = float(value) + elements = key.split('_') + if len(elements) == 1: + metrics.append({key: value}) + else: + type_ = 
'_'.join(elements[:-1]) + subType = elements[-1] + if(list(metrics[-1].keys())[0]) != type_: + metrics.append({type_: []}) + subTypes = metrics[-1][type_] + subTypes.append({subType: value}) + return metrics + + +def findFolder(name, user=None, create=False): + collection = Collection().findOne({'name': 'miqa'}) + folder = Folder().findOne({'name': name, 'baseParentId': collection['_id']}) + if not create: + return folder + elif not folder: + return Folder().createFolder(collection, name, + parentType='collection', creator=user) + else: + return folder + + +def findSessionsFolder(user=None, create=False): + return findFolder('sessions', user, create) + + +def findTempFolder(user=None, create=False): + return findFolder('temp', user, create) + + +def tryGetExistingSessionMeta(sessionsFolder, experimentId, scan): + experimentFolder = Folder().findOne( + {'name': experimentId, 'parentId': sessionsFolder['_id']}) + if not experimentFolder: + return None + sessionFolder = Folder().findOne( + {'name': scan, 'parentId': experimentFolder['_id']}) + if not sessionFolder: + return None + return sessionFolder.get('meta', {}) + + +def importJson(json_content, user): + existingSessionsFolder = findSessionsFolder(user) + if existingSessionsFolder: + existingSessionsFolder['name'] = 'sessions_' + \ + datetime.datetime.now().strftime("%Y-%m-%d %I:%M:%S %p") + Folder().save(existingSessionsFolder) + sessionsFolder = findSessionsFolder(user, True) + Item().createItem('json', user, sessionsFolder, description=json.dumps(json_content)) + + datasetRoot = json_content['data_root'] + experiments = json_content['experiments'] + sites = json_content['sites'] + + successCount = 0 + failedCount = 0 + sites = set() + for scan in json_content['scans']: + experimentId = scan['experiment_id'] + experimentNote = '' + for experiment in experiments: + if experiment['id'] == experimentId: + experimentNote = experiment['note'] + scanPath = scan['path'] + site = scan['site_id'] + sites.add(site) + scanId = scan['id'] + scanType = scan['type'] + scanName = scanId+'_'+scanType + niftiFolder = os.path.expanduser(os.path.join(datasetRoot, scanPath)) + if not os.path.isdir(niftiFolder): + failedCount += 1 + continue + experimentFolder = Folder().createFolder( + sessionsFolder, experimentId, parentType='folder', reuseExisting=True) + scanFolder = Folder().createFolder( + experimentFolder, scanName, parentType='folder', reuseExisting=True) + meta = { + 'experimentId': experimentId, + 'experimentNote': experimentNote, + 'site': site, + 'scanId': scanId, + 'scanType': scanType + } + # Merge note and rating if record exists + if existingSessionsFolder: + existingMeta = tryGetExistingSessionMeta( + existingSessionsFolder, experimentId, scanName) + if(existingMeta and (existingMeta.get('note', None) or existingMeta.get('rating', None))): + meta['note'] = existingMeta.get('note', None) + meta['rating'] = existingMeta.get('rating', None) + Folder().setMetadata(scanFolder, meta) + currentAssetstore = Assetstore().getCurrent() + if 'images' in scan: + scanImages = scan['images'] + # Import images one at a time because the user provided a list + for scanImage in scanImages: + absImagePath = os.path.join(niftiFolder, scanImage) + Assetstore().importData( + currentAssetstore, parent=scanFolder, parentType='folder', params={ + 'fileIncludeRegex': '^{0}$'.format(scanImage), + 'importPath': niftiFolder, + }, progress=noProgress, user=user, leafFoldersAsItems=False) + imageOrderDescription = { + 'orderDescription': { + 'images': scanImages + } + } 
+        else:
+            scanImagePattern = scan['imagePattern']
+            # Import all images in the directory at once because the user provided a file pattern
+            Assetstore().importData(
+                currentAssetstore, parent=scanFolder, parentType='folder', params={
+                    'fileIncludeRegex': scanImagePattern,
+                    'importPath': niftiFolder,
+                }, progress=noProgress, user=user, leafFoldersAsItems=False)
+            imageOrderDescription = {
+                'orderDescription': {
+                    'imagePattern': scanImagePattern
+                }
+            }
+        Item().createItem(name='imageOrderDescription',
+                          creator=user,
+                          folder=scanFolder,
+                          reuseExisting=True,
+                          description=json.dumps(imageOrderDescription))
+        itemMeta = {}
+        iqm = parseIQM(scan.get('iqms'))
+        if iqm:
+            itemMeta['iqm'] = iqm
+        good_prob = None
+        try:
+            good_prob = float(scan['good_prob'])
+        except (KeyError, TypeError, ValueError):
+            pass
+        if good_prob is not None:
+            itemMeta['goodProb'] = good_prob
+        if itemMeta:
+            item = list(Folder().childItems(scanFolder, limit=1))[0]
+            Item().setMetadata(item, itemMeta, allowNull=True)
+        successCount += 1
+
+    tryAddSites(sites, user)
+
+    return {
+        "success": successCount,
+        "failed": failedCount
+    }
+
+
+def importCSV(csv_content, user):
+    try:
+        json_content = csvContentToJsonObject(csv_content)
+        validate(json_content, schema)
+    except Exception as inst:
+        return {
+            "error": 'Invalid CSV file: {0}'.format(getattr(inst, 'message', str(inst))),
+        }
+
+    return importJson(json_content, user)
+
+
+def importData(importpath, user):
+    json_content = None
+
+    if importpath.endswith('.csv'):
+        with open(importpath) as fd:
+            return importCSV(fd.read(), user)
+    else:
+        with open(importpath) as json_file:
+            json_content = json.load(json_file)
+
+    try:
+        validate(json_content, schema)
+    except JSONValidationError as inst:
+        return {
+            "error": 'Invalid JSON file: {0}'.format(inst.message),
+        }
+
+    return importJson(json_content, user)
+
+
+def getExportJSONObject():
+    def convertRatingToDecision(rating):
+        return {
+            None: 0,
+            'questionable': 0,
+            'good': 1,
+            'usableExtra': 2,
+            'bad': -1
+        }[rating]
+    sessionsFolder = findSessionsFolder()
+    items = list(Folder().childItems(sessionsFolder, filters={'name': 'json'}))
+    if not len(items):
+        raise RestException('doesn\'t contain a json item', code=404)
+    jsonItem = items[0]
+    # Next TODO: read, format, and stream back the json version of the export
+    # logger.info(jsonItem)
+    original_json_object = json.loads(jsonItem['description'])
+
+    for scan in original_json_object['scans']:
+        experiment = Folder().findOne({
+            'name': scan['experiment_id'],
+            'parentId': sessionsFolder['_id']
+        })
+        if not experiment:
+            continue
+        session = Folder().findOne({
+            'name': '{0}_{1}'.format(scan['id'], scan['type']),
+            'parentId': experiment['_id']
+        })
+        if not session:
+            continue
+        scan['decision'] = convertRatingToDecision(session.get('meta', {}).get('rating', None))
+        scan['note'] = session.get('meta', {}).get('note', None)
+
+    return original_json_object
+
+
+def getExportJSON():
+    return json.dumps(getExportJSONObject())
+
+
+def getExportCSV():
+    return jsonObjectToCsvContent(getExportJSONObject())
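
Editor's note (illustrative sketch, not part of the patch): the new parseIQM() helper in server/miqa_server/util.py expects the per-scan IQM metadata as a single 'key:value;key:value;...' string and turns it into a list of single-key dicts, grouping keys that share a prefix before their last underscore (e.g. fwhm_x, fwhm_y) so the client can render them hierarchically. The standalone sketch below mirrors that logic under those assumptions; the metric names (cjv, cnr, fwhm_x, fwhm_y) are sample values chosen only for illustration, and a guard against an empty metrics list is added before peeking at the previous group.

    def parse_iqm_sketch(iqm):
        """Parse 'key:value;...' into a list of single-key dicts, grouping by prefix."""
        if not iqm:
            return None
        metrics = []
        for row in iqm.split(';'):
            if not row:
                continue
            key, value = row.split(':')
            value = float(value)
            parts = key.split('_')
            if len(parts) == 1:
                # Simple metric: keep it as a scalar entry.
                metrics.append({key: value})
            else:
                group, sub = '_'.join(parts[:-1]), parts[-1]
                # Start a new group unless the previous entry already belongs to it.
                if not metrics or list(metrics[-1].keys())[0] != group:
                    metrics.append({group: []})
                metrics[-1][group].append({sub: value})
        return metrics


    if __name__ == '__main__':
        sample = 'cjv:0.42;cnr:3.1;fwhm_x:2.9;fwhm_y:3.0;'
        print(parse_iqm_sketch(sample))
        # [{'cjv': 0.42}, {'cnr': 3.1}, {'fwhm': [{'x': 2.9}, {'y': 3.0}]}]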