Skip to content

Commit

Permalink
Merge pull request #1 from settylab/dev
Browse files Browse the repository at this point in the history
Version 1.1.0
 * numeric stabilization of nn-distribution
 * logging accessibility
 * improved warnings
 * cleaned base estimator class
  • Loading branch information
katosh authored Mar 16, 2023
2 parents 70967e4 + 28e423d commit 6460152
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 88 deletions.
14 changes: 13 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,27 @@
#
import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.abspath('../..'))

this_directory = Path(__file__).parent

def get_version(rel_path):
for line in (this_directory / rel_path).read_text().splitlines():
if line.startswith('__version__'):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
else:
raise RuntimeError("Unable to find version string.")
# -- Project information -----------------------------------------------------

project = 'Mellon'
copyright = '2022, Setty Lab'
author = 'Setty Lab'

# The full version, including alpha/beta/rc tags
release = '1.0.2'
release = get_version('../../mellon/__init__.py')


# -- General configuration ---------------------------------------------------
Expand Down
9 changes: 3 additions & 6 deletions mellon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import logging

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

from jax.config import config as jaxconfig

jaxconfig.update("jax_enable_x64", True)
jaxconfig.update("jax_platform_name", "cpu")

__version__ = "1.1.0"

from .base_cov import Covariance
from .util import stabilize, mle, distance
from .util import stabilize, mle, distance, Log
from .cov import Matern32, Matern52, ExpQuad, Exponential, RatQuad
from .inference import (
compute_transform,
Expand Down
27 changes: 6 additions & 21 deletions mellon/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from jax.numpy import sum as arraysum
from jax.numpy.linalg import eigh, cholesky, qr
from jax.scipy.linalg import solve_triangular
from .util import stabilize, DEFAULT_JITTER
from . import logger
from .util import stabilize, DEFAULT_JITTER, Log


DEFAULT_RANK = 0.999
DEFAULT_RANK = 0.99
DEFAULT_METHOD = "auto"

logger = Log()

def _check_method(rank, full, method):
R"""
Expand Down Expand Up @@ -43,22 +43,6 @@ def _check_method(rank, full, method):
message = f"""The argument method={method} does not match the rank={rank}.
The detected method from the rank is 'fixed'."""
raise ValueError(message)
elif rank == 1 and method == "auto":
if percent:
message = """rank is 1.0, which is ambiguous. Because
rank is a float, it is interpreted as the percentage of
eigenvalues to include in the low rank approximation.
To bypass this warning, explictly set method='percent'.
If this is not the intended behavior, explicitly set
method='fixed'."""
else:
message = """rank is 1, which is ambiguous. Because
rank is an int, it is interpreted as the number of
eigenvectors to include in the low rank approximation.
To bypass this warning, explictly set method='fixed'.
If this is not the intended behavior, explicitly set
method='percent'."""
logger.warning(message)
if percent:
return "percent"
else:
Expand Down Expand Up @@ -100,8 +84,9 @@ def _eigendecomposition(A, rank=DEFAULT_RANK, method=DEFAULT_METHOD):
p = 1
else:
p = min(rank, p)
frac = summed[p]/summed[-1]
logger.info(f'Recovering {frac:%} variance in rank reduction.')
if (method == "percent" and rank < 1) or rank < len(summed):
frac = summed[p]/summed[-1]
logger.info(f'Recovering {frac:%} variance in eigendecomposition.')
s_ = s[-p:]
v_ = v[:, -p:]
return s_, v_
Expand Down
9 changes: 4 additions & 5 deletions mellon/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,12 @@ def _nearest_neighbors(r, d):
:return: The likelihood function.
:rtype: function
"""
constant1 = pi ** (d / 2) / exp(gammaln(d / 2 + 1))
V = constant1 * (r**d)
constant2 = log(d + 1e-16) + (d * log(pi) / 2) - gammaln(d / 2 + 1)
Vdr = constant2 + ((d - 1) * log(r))
const = (d * log(pi) / 2) - gammaln(d / 2 + 1)
V = log(r) * d + const
Vdr = log(d) + ((d - 1) * log(r)) + const

def logpdf(log_density):
A = exp(log_density) * V
A = exp(log_density + V)
B = log_density + Vdr
return arraysum(B - A)

Expand Down
30 changes: 11 additions & 19 deletions mellon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
from .util import (
DEFAULT_JITTER,
vector_map,
configure_logger,
Log,
)
from . import logger


DEFAULT_COV_FUNC = Matern52

logger = Log()

class BaseEstimator:
R"""
Expand All @@ -50,7 +50,6 @@ def __init__(
cov_func_curry=DEFAULT_COV_FUNC,
n_landmarks=DEFAULT_N_LANDMARKS,
rank=DEFAULT_RANK,
method=DEFAULT_METHOD,
jitter=DEFAULT_JITTER,
optimizer=DEFAULT_OPTIMIZER,
n_iter=DEFAULT_N_ITER,
Expand All @@ -65,11 +64,9 @@ def __init__(
L=None,
initial_value=None,
):
configure_logger(logger)
self.cov_func_curry = cov_func_curry
self.n_landmarks = n_landmarks
self.rank = rank
self.method = method
self.jitter = jitter
self.landmarks = landmarks
self.nn_distances = nn_distances
Expand All @@ -79,7 +76,6 @@ def __init__(
self.cov_func = cov_func
self.L = L
self.x = None
self.logger = logger

def __str__(self):
return self.__repr__()
Expand All @@ -91,9 +87,7 @@ def __repr__(self):
f"cov_func_curry={self.cov_func_curry}, "
f"n_landmarks={self.n_landmarks}, "
f"rank={self.rank}, "
f"method='{self.method}', "
f"jitter={self.jitter}, "
f"n_iter={self.n_iter}, "
f"landmarks={self.landmarks}, "
)
if self.nn_distances is None:
Expand Down Expand Up @@ -124,6 +118,7 @@ def _compute_landmarks(self):

def _compute_nn_distances(self):
x = self.x
logger.info('Computing nearest neighbor distances.')
nn_distances = compute_nn_distances(x)
return nn_distances

Expand All @@ -148,7 +143,7 @@ def _compute_L(self):
rank = self.rank
method = self.method
jitter = self.jitter
if isinstance(rank, float):
if isinstance(rank, float) and method != 'fixed':
logger.info(
f'Computing rank reduction using "{method}" method '
f"retaining > {rank:.2%} of variance."
Expand All @@ -161,7 +156,12 @@ def _compute_L(self):
x, cov_func, landmarks=landmarks, rank=rank, method=method, jitter=jitter
)
new_rank = L.shape[1]
if new_rank > (0.8 * n_landmarks):
if not (
type(rank) is int
and rank == n_landmarks
or type(rank) is float
and rank == 1.0
) and method != 'fixed' and new_rank > (rank * 0.8 * n_landmarks):
logger.warning(
f"Shallow rank reduction from {n_landmarks:,} to {new_rank:,} "
"indicates underrepresentation by landmarks. Consider "
Expand Down Expand Up @@ -448,7 +448,6 @@ def __init__(
cov_func_curry=cov_func_curry,
n_landmarks=n_landmarks,
rank=rank,
method=method,
jitter=jitter,
landmarks=landmarks,
nn_distances=nn_distances,
Expand All @@ -458,6 +457,7 @@ def __init__(
cov_func=cov_func,
L=L,
)
self.method = method
self.optimizer = optimizer
self.n_iter = n_iter
self.init_learn_rate = init_learn_rate
Expand Down Expand Up @@ -550,7 +550,6 @@ def _compute_loss_func(self):
def _set_log_density_x(self):
pre_transformation = self.pre_transformation
transform = self.transform
logger.info("Decoding latent density representation.")
log_density_x = compute_log_density_x(pre_transformation, transform)
self.log_density_x = log_density_x

Expand Down Expand Up @@ -776,9 +775,6 @@ class FunctionEstimator(BaseEstimator):
k(x, y) :math:`\rightarrow` float. If None, automatically generates the covariance
function cov_func = cov_func_curry(ls). Defaults to None.
:type cov_func: function or None
:param L: A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the covariance matrix.
If None, automatically computes L. Defaults to None.
:type L: array-like or None
:param sigma: The white moise standard deviation. Defaults to 0.
:type sigma: float
:ivar n_landmarks: The number of landmark points.
Expand Down Expand Up @@ -819,22 +815,19 @@ def __init__(
ls=None,
ls_factor=1,
cov_func=None,
L=None,
sigma=0,
):
super().__init__(
cov_func_curry=cov_func_curry,
n_landmarks=n_landmarks,
rank=rank,
method=method,
jitter=jitter,
landmarks=landmarks,
nn_distances=nn_distances,
mu=mu,
ls=ls,
ls_factor=ls_factor,
cov_func=cov_func,
L=L,
)
self.sigma = sigma

Expand Down Expand Up @@ -871,7 +864,6 @@ def prepare_inference(self, x):
self._prepare_attribute("ls")
self._prepare_attribute("cov_func")
self._prepare_attribute("landmarks")
self._prepare_attribute("L")
return

def compute_conditional(self, x=None, y=None):
Expand Down
63 changes: 28 additions & 35 deletions mellon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,38 +80,31 @@ def vector_map(fun, X, in_axis=0):
return vfun(X)


def logger_is_configured(logger):
"""
Checks if the logger has any other handlers than the NullHandler.
:param logger: A logger from the logging module.
:type logger: logging.Logger
:return: If the logger is configured.
:rtype: bool
"""
for handler in logger.handlers:
if not isinstance(handler, logging.NullHandler):
return True
return False


def configure_logger(logger, force=False):
"""
Applies default configuration to the logger if it is not configured yet.
:param logger: A logger from the logging module.
:type logger: logging.Logger
:param force: If True, apply configuratuion even if configured.
Defaults to False.
:type force: bool
:return: The passed logger.
:rtype: logging.Logger
"""
if force or not logger_is_configured(logger):
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("[%(asctime)s] [%(levelname)-8s] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
return logger
class Log(object):
"""Access the Mellon logging/verbosity. Log() returns the singelon logger and
Log.off() and Log.on() disable or enable logging respectively."""

def __new__(cls):
"""Return the singelton Logger."""
if not hasattr(cls, "logger"):
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
cls.handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("[%(asctime)s] [%(levelname)-8s] %(message)s")
cls.handler.setFormatter(formatter)
logger.addHandler(cls.handler)
logger.propagate = False
cls.logger = logger
return cls.logger

@classmethod
def off(cls):
"""Turn off all logging."""
cls.__new__(cls)
cls.logger.setLevel(logging.CRITICAL + 1)

@classmethod
def on(cls):
"""Turn on logging."""
cls.__new__(cls)
cls.logger.setLevel(logging.INFO)
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@

this_directory = Path(__file__).parent

def get_version(rel_path):
for line in (this_directory / rel_path).read_text().splitlines():
if line.startswith('__version__'):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
else:
raise RuntimeError("Unable to find version string.")

setup(
name="mellon",
version="1.0.2",
version=get_version('mellon/__init__.py'),
description="Non-parametric density estimator.",
url="https://github.com/settylab/mellon",
author="Setty Lab",
Expand Down

0 comments on commit 6460152

Please sign in to comment.