diff --git a/openwakeword/utils.py b/openwakeword/utils.py index c032c5a..8da8048 100644 --- a/openwakeword/utils.py +++ b/openwakeword/utils.py @@ -269,8 +269,9 @@ def _get_melspectrogram_batch(self, x, batch_size=128, ncpu=1): result = self._get_melspectrogram(batch) elif pool: + chunksize = batch.shape[0]//ncpu if batch.shape[0] >= ncpu else 1 result = np.array(pool.map(self._get_melspectrogram, - batch, chunksize=batch.shape[0]//ncpu)) + batch, chunksize=chunksize)) melspecs[i:i+batch_size, :, :] = result.squeeze() @@ -330,8 +331,9 @@ def _get_embeddings_batch(self, x, batch_size=128, ncpu=1): result = self.embedding_model_predict(batch) elif pool: + chunksize = batch.shape[0]//ncpu if batch.shape[0] >= ncpu else 1 result = np.array(pool.map(self._get_embeddings_from_melspec, - batch, chunksize=batch.shape[0]//ncpu)) + batch, chunksize=chunksize)) for j, ndx2 in zip(range(0, result.shape[0], n_frames), ndcs): embeddings[ndx2, :, :] = result[j:j+n_frames]