Skip to content

Commit

Permalink
Merge pull request #64 from cokelaer/master
Browse files Browse the repository at this point in the history
Fix progress bar (threading with joblib/tqdm issue)
  • Loading branch information
cokelaer authored Sep 14, 2022
2 parents 1493b5f + 5b4bf54 commit 8910276
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ on:
jobs:
build-n-publish:
name: Build and publish to PyPI and TestPyPI
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@master
- uses: actions/checkout@main
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
Expand All @@ -26,14 +26,14 @@ jobs:
python setup.py sdist
- name: Publish distribution to Test PyPI
uses: pypa/gh-action-pypi-publish@master
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository_url: https://test.pypi.org/legacy/
- name: Publish distribution to PyPI
if: startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@master
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
55 changes: 44 additions & 11 deletions src/fitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@
.. sectionauthor:: Thomas Cokelaer, Aug 2014-2020
"""
import logging
import sys
import contextlib
import threading
from datetime import datetime

import logging

import numpy as np
import pandas as pd
import pylab
import scipy.stats

import joblib
from joblib.parallel import Parallel, delayed
from tqdm import tqdm
from scipy.stats import entropy as kl_div, kstest

Expand All @@ -35,6 +40,40 @@
__all__ = ["get_common_distributions", "get_distributions", "Fitter"]



# A solution to wrap joblib parallel call in tqdm from
# https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution/58936697#58936697
# and https://github.com/louisabraham/tqdm_joblib
@contextlib.contextmanager
def tqdm_joblib(*args, **kwargs):
    """Context manager that patches joblib to report progress into a tqdm bar.

    All positional and keyword arguments are forwarded to ``tqdm.tqdm``.
    While the context is active, joblib's ``BatchCompletionCallBack`` is
    replaced so that each completed batch advances the bar by its batch size.

    Yields:
        The ``tqdm`` instance, so callers may interact with the bar directly.

    On exit the original joblib callback is restored and the bar is closed,
    even if the wrapped parallel call raised.
    Adapted from https://stackoverflow.com/a/58936697 and
    https://github.com/louisabraham/tqdm_joblib
    """

    tqdm_object = tqdm(*args, **kwargs)

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        # NOTE: no __init__ override needed — the parent's constructor is
        # sufficient; we only hook batch completion.
        def __call__(self, *args, **kwargs):
            # Advance the bar by the number of tasks in this batch, then
            # delegate to joblib's normal completion bookkeeping.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey-patch and close the bar, success or failure.
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()








def get_distributions():
distributions = []
for this in dir(scipy.stats):
Expand Down Expand Up @@ -309,11 +348,8 @@ def _fit_single_distribution(self, distribution):
self._aic[distribution] = np.inf
self._bic[distribution] = np.inf
self._kldiv[distribution] = np.inf
#if srogress:
# self._fit_i += 1
# #self.pb.animate(self._fit_i)

def fit(self, progress=False, n_jobs=-1):
def fit(self, progress=False, n_jobs=-1, max_workers=-1):
r"""Loop over distributions and find best parameter to fit the data for each
When a distribution is fitted onto the data, we populate a set of
Expand All @@ -331,13 +367,10 @@ def fit(self, progress=False, n_jobs=-1):

warnings.filterwarnings("ignore", category=RuntimeWarning)

from tqdm.contrib.concurrent import thread_map

result = thread_map(self._fit_single_distribution, self.distributions, max_workers=4, disable=not progress)
N = len(self.distributions)
with tqdm_joblib(desc=f"Fitting {N} distributions", total=N) as progress_bar:
Parallel(n_jobs=max_workers, backend='threading')(delayed(self._fit_single_distribution)(dist) for dist in self.distributions)

#jobs = (delayed(self._fit_single_distribution)(dist, progress) for dist in self.distributions)
#pool = Parallel(n_jobs=n_jobs, backend="threading")
#_ = pool(jobs)

self.df_errors = pd.DataFrame(
{
Expand Down

0 comments on commit 8910276

Please sign in to comment.