Skip to content

Commit

Permalink
WIP: CLI tool
Browse files Browse the repository at this point in the history
  • Loading branch information
rossant committed Mar 19, 2016
1 parent b33638b commit 730c390
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
63 changes: 60 additions & 3 deletions klusta/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from pprint import pformat
import shutil

import click
import numpy as np

from .traces import SpikeDetekt
from .kwik.creator import create_kwik, KwikCreator
from .kwik.model import KwikModel
from .klustakwik import klustakwik
from .utils import _ensure_dir_exists
from .__init__ import __version_git__

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,15 +125,28 @@ def on_iter(sc):

def klusta(prm_file,
output_dir=None,
do_detect=True,
do_cluster=True,
interval=None,
channel_group=None):
channel_group=None,
detect_only=False,
cluster_only=False,
):
# Detection and/or clustering.
do_detect = not cluster_only
do_cluster = not detect_only
assert do_detect or do_cluster

# Parse the interval (pair of floats).
# if isinstance(interval, six.string_types):
if interval:
# interval = tuple(map(float, interval.split(',')))
assert 0 <= interval[0] < interval[1]

# Ensure the kwik file exists. Doesn't overwrite it if it exists.
kwik_path = create_kwik(prm_file=prm_file, output_dir=output_dir)

# Detection.
if do_detect:
logger.info("Starting spike detection.")
# NOTE: always detect on all shanks.
model = KwikModel(kwik_path)
out = detect(model, interval=interval)
Expand All @@ -140,6 +155,7 @@ def klusta(prm_file,
# Add the spikes to the kwik file.
creator = KwikCreator(kwik_path)
creator.add_spikes_after_detection(out)
logger.info("Spike detection done!")

# List of channel groups.
if channel_group is None:
Expand All @@ -151,6 +167,8 @@ def klusta(prm_file,
if do_cluster:
# Cluster every channel group.
for channel_group in channel_groups:
logger.info("Starting clustering on shank %d/%d.",
channel_group, len(channel_groups))
model = KwikModel(kwik_path, channel_group=channel_group)
logger.info("Clustering group %d (%d spikes).",
channel_group, model.n_spikes)
Expand All @@ -166,5 +184,44 @@ def klusta(prm_file,
model.copy_clustering('main', 'original')
model.clustering_metadata.update(metadata)
model.close()
logger.info("Clustering done!")

return kwik_path


@click.command()
@click.argument('prm_file',
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option('--output-dir',
help='Output directory.')
@click.option('--interval',
type=click.Tuple([float, float]),
help='Interval in seconds, e.g. `--interval 0 2`.')
@click.option('--channel-group',
type=click.INT,
help='Channel group to cluster (all by default).')
@click.option('--detect-only',
help='Only do spike detection.',
is_flag=True)
@click.option('--cluster-only',
help='Only do automatic clustering.',
is_flag=True)
@click.version_option(version=__version_git__)
@click.help_option()
def main(*args, **kwargs):
"""Spikesort a dataset.
By default, perform spike detection (with SpikeDetekt) and automatic
clustering (with KlustaKwik2). You can also choose to run only one step.
You need to specify three pieces of information to spikesort your data:
* The raw data file: typically a `.dat` file.
* The PRM file: a Python file with the `.prm` extension, containing the parameters for your sorting session.
* The PRB file: a Python file with the `.prb` extension, containing the layout of your probe.
""" # noqa
return klusta(*args, **kwargs)
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def _package_tree(pkgroot):
package_data={
'klusta': ['*.txt', '*.prb'],
},
# entry_points={
# 'console_scripts': [
# 'klusta = klusta.utils.cli:klusta'
# ],
# },
entry_points={
'console_scripts': [
'klusta = klusta.launch:main'
],
},
include_package_data=True,
keywords='klusta,neuroscience,spike sorting,klustakwik',
classifiers=[
Expand Down

0 comments on commit 730c390

Please sign in to comment.