Skip to content

Commit

Permalink
Allow coremltools to be import even if kmeans1d is broken. Also add b…
Browse files Browse the repository at this point in the history
…ack sklearn parameters
  • Loading branch information
TobyRoseman committed Jun 5, 2023
1 parent da74098 commit b7375e2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
8 changes: 7 additions & 1 deletion coremltools/_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@

from packaging import version

from . import kmeans1d as _kmeans1d
from coremltools import _logger as logger

_HAS_KMEANS1D = True
try:
from . import kmeans1d as _kmeans1d
except:
_kmeans1d = None
_HAS_KMEANS1D = False


def _get_version(version):
# matching 1.6.1, and 1.6.1rc, 1.6.1.dev
Expand Down
37 changes: 26 additions & 11 deletions coremltools/models/neural_network/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@

import numpy as _np

from coremltools import ComputeUnit as _ComputeUnit
from coremltools._deps import kmeans1d as _kmeans1d
from coremltools import (
ComputeUnit as _ComputeUnit,
_logger
)
from coremltools._deps import (
_HAS_KMEANS1D,
_kmeans1d
)
from coremltools.models import (
_QUANTIZATION_MODE_CUSTOM_LOOKUP_TABLE,
_QUANTIZATION_MODE_DEQUANTIZE,
Expand Down Expand Up @@ -383,7 +389,17 @@ def _get_kmeans_lookup_table_and_weight(nbits, w, force_kmeans1d=False):
wf = w.reshape(-1, 1)
lut = _np.zeros(lut_len)

if not force_kmeans1d and (num_weights < 10_000 or w.dtype != _np.float16):
is_better_to_use_kmeans1d = (num_weights >= 10_000 and w.dtype == _np.float16)

if (is_better_to_use_kmeans1d and _HAS_KMEANS1D) or force_kmeans1d:
# Cluster with kmeans1d
assert(_HAS_KMEANS1D)
values, indices, counts = _np.unique(wf, return_inverse=True, return_counts=True)
n_clusters = min(len(values), lut_len)
kmeans_results = _kmeans1d.cluster(values, n_clusters, weights=counts)
lut[:n_clusters] = kmeans_results.centroids
wq = _np.array(kmeans_results.clusters)[indices]
else:
# Cluster with scikit-learn
try:
from sklearn.cluster import KMeans
Expand All @@ -393,17 +409,16 @@ def _get_kmeans_lookup_table_and_weight(nbits, w, force_kmeans1d=False):
" To install, run: \"pip install scikit-learn\"."
)

if is_better_to_use_kmeans1d:
_logger.warning("It would be better to use kmeans1d but that is not available."
" Using scikit-learn for K-means.")

n_clusters = min(num_weights, lut_len)
kmeans = KMeans(n_clusters).fit(wf)
kmeans = KMeans(
n_clusters, init="k-means++", tol=1e-2, n_init=1, random_state=0
).fit(wf)
wq = kmeans.labels_[:num_weights]
lut[:n_clusters] = kmeans.cluster_centers_.flatten()
else:
# Cluster with kmeans1d
values, indices, counts = _np.unique(wf, return_inverse=True, return_counts=True)
n_clusters = min(len(values), lut_len)
kmeans_results = _kmeans1d.cluster(values, n_clusters, weights=counts)
lut[:n_clusters] = kmeans_results.centroids
wq = _np.array(kmeans_results.clusters)[indices]

return lut, wq

Expand Down

0 comments on commit b7375e2

Please sign in to comment.