-
Notifications
You must be signed in to change notification settings - Fork 1
/
ca-taudata.py
executable file
·151 lines (117 loc) · 4.11 KB
/
ca-taudata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import sys
from ca.nix import item_of_type
from collections import defaultdict
def get_related_metadata(da, metadata, nf):
    """Look up the metadata entry named *metadata* for the data array *da*.

    The section attached to the data array itself is used when there is one.
    Otherwise the section is taken from ``nf.sections``, keyed by the first
    nine characters of the data array's name.
    """
    section = da.metadata
    if section is not None:
        return section[metadata]
    # No section on the array itself: fall back to the file-level section
    # stored under the nine-character name prefix.
    return nf.sections[da.name[:9]][metadata]
def get_condition(da, nf):
    """Return the lower-cased 'condition' metadata value for *da*."""
    return get_related_metadata(da, 'condition', nf).lower()
def get_age(da, nf):
    """Return the lower-cased 'age' metadata value for *da*."""
    return get_related_metadata(da, 'age', nf).lower()
def filter_pulses(block, pulse):
    """Return the data arrays of *block* whose name contains *pulse*.

    The order of ``block.data_arrays`` is preserved in the returned list.
    """
    def _matches(candidate):
        return pulse in candidate.name

    return list(filter(_matches, block.data_arrays))
def group_by_name(result, x):
    """Reducer step that groups consecutive items by their name prefix.

    *result* is a non-empty list of non-empty groups. If the first nine
    characters of ``x.name`` match those of the last item in the last group,
    *x* is appended to that group; otherwise a new single-item group is
    started. *result* is modified in place and also returned.
    """
    current_group = result[-1]
    if current_group[-1].name[:9] != x.name[:9]:
        result.append([x])
    else:
        current_group.append(x)
    return result
def mk_sorter_by_age(nf):
    """Build a sort-key function that orders groups of data arrays by age.

    The returned function takes a list of data arrays and returns
    ``get_age`` of its first element, looked up with the file *nf*.
    """
    def _get_age(list_of_dataarrays):
        # The first data array stands in for the whole group.
        return get_age(list_of_dataarrays[0], nf)

    return _get_age
def da_name_get_number(da):
    """Return the third dot-separated field of the data array's name.

    The field is returned as a string, so using this function as a sort key
    orders values lexicographically rather than numerically.
    """
    fields = da.name.split('.')
    return fields[2]
def read_exclude(path):
    """Read the exclude list from a comma-separated text file.

    Each non-blank line has the form ``name[,item[,item...]]``. The first
    field becomes the key and the remaining fields become the list of items
    stored under it; a line with only a name maps to an empty list.

    Blank lines are skipped and surrounding whitespace is removed from every
    field, so ``"name, 2"`` yields the item ``'2'``. If a name occurs more
    than once, the last line wins.

    Parameters:
        path: location of the file; a leading ``~`` is expanded.

    Returns:
        A ``defaultdict(list)`` mapping each name to its list of item strings.
    """
    excludes = defaultdict(list)
    with open(os.path.expanduser(path)) as pf:
        for raw_line in pf:
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise register an entry under ''.
                continue
            fields = [field.strip() for field in line.split(',')]
            excludes[fields[0]] = fields[1:]
    return excludes
def mk_filter_excludes(names):
    """Build a predicate that decides whether a group should be kept.

    *names* maps nine-character name prefixes to lists of excluded items.
    The returned predicate takes a group (a list of data arrays) and looks
    up the prefix of its first element's name:

    - prefix not in *names*: returns ``True`` (group is kept);
    - prefix in *names*: returns the length of its list, so an empty list
      (falsy) drops the whole group while a non-empty list keeps it.
    """
    def _filter(g):
        prefix = g[0].name[:9]
        if prefix not in names:
            return True
        return len(names[prefix])

    return _filter
def mk_condition_name(name):
    """Map a condition name to its two-letter abbreviation.

    The lookup is case-insensitive: 'control' gives 'CT' and 'noisebox'
    gives 'NB'. Any other name raises ``KeyError``.
    """
    abbreviations = {'control': 'CT', 'noisebox': 'NB'}
    key = name.lower()
    return abbreviations[key]
def main():
    """Export the selected traces of a NIX file as CSV on stdout.

    Command line: ``[--style ...] [--pulse P] [--age N] [--exclude PATH] file``.
    ``--style`` and ``--age`` are parsed but not used by this function.

    Steps:
      1. Read the exclude list from ``--exclude``.
      2. Open ``file`` read-only, take the block returned by
         ``item_of_type(nf.blocks, "dff.mean")`` and keep the data arrays
         whose name contains the pulse.
      3. Group the arrays by the first nine characters of their name, drop
         groups that are excluded as a whole, and sort the groups by age.
      4. Within each group, sort by ``da_name_get_number`` and skip the
         arrays whose 1-based position is listed in the exclude entry.
      5. Write one CSV column per remaining array: a header line with the
         generated column names, then one line per sample. Shorter traces
         are padded with NaN.

    Progress and diagnostics go to stderr. If no data array matches the
    pulse, the program exits with an error message instead of a traceback.
    """
    # ``reduce`` is a builtin only on Python 2; functools provides it on both.
    from functools import reduce

    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--style', nargs='*', type=str, default=['ck'])
    parser.add_argument('--pulse', default='ap25')
    parser.add_argument('--age', default=None, type=int)
    parser.add_argument('--exclude', default='~/Data/exludes_tau.csv')
    parser.add_argument("file")
    args = parser.parse_args()

    excludes = read_exclude(args.exclude)
    print('%d excludes read!' % len(excludes), file=sys.stderr)

    names = []
    columns = []
    nf = nix.File.open(args.file, nix.FileMode.ReadOnly)
    try:
        data = item_of_type(nf.blocks, "dff.mean")
        images = filter_pulses(data, args.pulse)
        if not images:
            # The grouping below needs at least one element to seed it.
            sys.exit("no data arrays found for pulse '%s'" % args.pulse)
        images = sorted(images, key=lambda x: x.name[:9])
        grouped = reduce(group_by_name, images[1:], [[images[0]]])

        grouped_filtered = list(filter(mk_filter_excludes(excludes), grouped))
        n_removed = len(grouped) - len(grouped_filtered)
        print('removed %d neuros' % (n_removed),
              file=sys.stderr)

        grouped_sorted = sorted(grouped_filtered, key=mk_sorter_by_age(nf))

        for g in grouped_sorted:
            first_da = g[0]
            neuron_name = first_da.name[:9]
            age = get_age(first_da, nf)
            condition = mk_condition_name(get_condition(first_da, nf))

            imgs = sorted(g, key=da_name_get_number)
            exlist = excludes.get(neuron_name, [])
            rep = 1
            for i, img in enumerate(imgs):
                # Exclude entries refer to the 1-based position in the group.
                if str(i+1) in exlist:
                    print("\t - %s excluded" % img.name, file=sys.stderr)
                    continue
                name = "%s_p%s_%s_%d" % (neuron_name, age, condition, rep)
                print("\t - [%s -> %s]" % (name, img.name), file=sys.stderr)
                # Read the data while the file is still open.
                columns.append(np.array(img[:]))
                names.append(name)
                rep += 1
    finally:
        nf.close()

    # Size the matrix from the traces actually kept, so that excluded arrays
    # do not leave header-less all-NaN columns at the end of every row.
    longest = max([trace.shape[0] for trace in columns]) if columns else 0
    alldata = np.empty((longest, len(columns)))
    alldata[:] = np.nan
    print("alldata matrix: %s" % str(alldata.shape), file=sys.stderr)
    for col, trace in enumerate(columns):
        # Only the leading rows are filled; the rest stays NaN as padding.
        alldata[:trace.shape[0], col] = trace

    outfile = sys.stdout

    # output data
    print(",".join(names), file=outfile)
    for row in alldata:
        print(",".join(map(str, row)), file=outfile)
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()