From 730c390ff1ef64758953dbc7f9e25b7facd595f1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Mar 2016 16:14:25 +0100 Subject: [PATCH] WIP: CLI tool --- klusta/launch.py | 63 +++++++++++++++++++++++++++++++++++++++++++++--- setup.py | 10 ++++---- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/klusta/launch.py b/klusta/launch.py index 59c56b7..37e61a1 100644 --- a/klusta/launch.py +++ b/klusta/launch.py @@ -11,6 +11,7 @@ from pprint import pformat import shutil +import click import numpy as np from .traces import SpikeDetekt @@ -18,6 +19,7 @@ from .kwik.model import KwikModel from .klustakwik import klustakwik from .utils import _ensure_dir_exists +from .__init__ import __version_git__ logger = logging.getLogger(__name__) @@ -123,15 +125,28 @@ def on_iter(sc): def klusta(prm_file, output_dir=None, - do_detect=True, - do_cluster=True, interval=None, - channel_group=None): + channel_group=None, + detect_only=False, + cluster_only=False, + ): + # Detection and/or clustering. + do_detect = not cluster_only + do_cluster = not detect_only + assert do_detect or do_cluster + + # Parse the interval (pair of floats). + # if isinstance(interval, six.string_types): + if interval: + # interval = tuple(map(float, interval.split(','))) + assert 0 <= interval[0] < interval[1] + # Ensure the kwik file exists. Doesn't overwrite it if it exists. kwik_path = create_kwik(prm_file=prm_file, output_dir=output_dir) # Detection. if do_detect: + logger.info("Starting spike detection.") # NOTE: always detect on all shanks. model = KwikModel(kwik_path) out = detect(model, interval=interval) @@ -140,6 +155,7 @@ def klusta(prm_file, # Add the spikes to the kwik file. creator = KwikCreator(kwik_path) creator.add_spikes_after_detection(out) + logger.info("Spike detection done!") # List of channel groups. if channel_group is None: @@ -151,6 +167,8 @@ def klusta(prm_file, if do_cluster: # Cluster every channel group. for channel_group in channel_groups: + logger.info("Starting clustering on shank %d/%d.", + channel_group, len(channel_groups)) model = KwikModel(kwik_path, channel_group=channel_group) logger.info("Clustering group %d (%d spikes).", channel_group, model.n_spikes) @@ -166,5 +184,44 @@ def klusta(prm_file, model.copy_clustering('main', 'original') model.clustering_metadata.update(metadata) model.close() + logger.info("Clustering done!") return kwik_path + + +@click.command() +@click.argument('prm_file', + type=click.Path(exists=True, file_okay=True, dir_okay=False), + ) +@click.option('--output-dir', + help='Output directory.') +@click.option('--interval', + type=click.Tuple([float, float]), + help='Interval in seconds, e.g. `--interval 0 2`.') +@click.option('--channel-group', + type=click.INT, + help='Channel group to cluster (all by default).') +@click.option('--detect-only', + help='Only do spike detection.', + is_flag=True) +@click.option('--cluster-only', + help='Only do automatic clustering.', + is_flag=True) +@click.version_option(version=__version_git__) +@click.help_option() +def main(*args, **kwargs): + """Spikesort a dataset. + + By default, perform spike detection (with SpikeDetekt) and automatic + clustering (with KlustaKwik2). You can also choose to run only one step. + + You need to specify three pieces of information to spikesort your data: + + * The raw data file: typically a `.dat` file. + + * The PRM file: a Python file with the `.prm` extension, containing the parameters for your sorting session. + + * The PRB file: a Python file with the `.prb` extension, containing the layout of your probe. + + """ # noqa + return klusta(*args, **kwargs) diff --git a/setup.py b/setup.py index 55e11c4..8f0a1c5 100644 --- a/setup.py +++ b/setup.py @@ -51,11 +51,11 @@ def _package_tree(pkgroot): package_data={ 'klusta': ['*.txt', '*.prb'], }, - # entry_points={ - # 'console_scripts': [ - # 'klusta = klusta.utils.cli:klusta' - # ], - # }, + entry_points={ + 'console_scripts': [ + 'klusta = klusta.launch:main' + ], + }, include_package_data=True, keywords='klusta,neuroscience,spike sorting,klustakwik', classifiers=[