Skip to content

Commit

Permalink
add a few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhihang Dong committed Oct 14, 2023
1 parent 0d9a342 commit e3bf2f3
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/mdp/utils/els.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import logging

class ELSValidation:
"""
A class to check various gradient conditions for EL-S optimization theories.
"""
def __init__(self, F_grad, L, beta=None, mu=None, verbose=False):
"""
Initialize with the gradient of function F, Lipschitz constant L, and optionally beta and mu.
"""
self.F_grad = F_grad
self.L = L
self.beta = beta
self.mu = mu
self.verbose = verbose
self.logger = logging.getLogger(__name__)
if self.verbose:
logging.basicConfig(level=logging.INFO)

def _validate_input(self, theta, delta_theta):
assert theta.shape == delta_theta.shape, "Mismatched shapes between theta and delta_theta!"
assert isinstance(theta, np.ndarray) and isinstance(delta_theta, np.ndarray), "Inputs must be numpy arrays!"

def _log(self, message):
if self.verbose:
self.logger.info(message)

def check_lipschitz(self, theta, delta_theta):
"""
Check the Lipschitz continuity condition.
"""
self._validate_input(theta, delta_theta)
grad_diff = self.F_grad(theta + delta_theta) - self.F_grad(theta)
condition = np.linalg.norm(grad_diff) <= self.L * np.linalg.norm(delta_theta)
self._log(f"Lipschitz check: {condition}")
return condition

# ... (other methods remain the same)

# Usage remains similar
67 changes: 67 additions & 0 deletions src/mdp/utils/softmax_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
from typing import Tuple

class SoftmaxTabularPolicy:
def __init__(self, action_value: np.ndarray, temperature: float = 1.0) -> None:
"""
Initialize the SoftmaxTabularPolicy.
Parameters:
- action_value (np.ndarray): A 2D array representing the action value Q(s, a),
with shape [num_states, num_actions].
- temperature (float, optional): Temperature parameter to control exploration.
Higher values encourage more exploration. Default is 1.0.
"""
self.Q: np.ndarray = np.asarray(action_value)
self.temperature: float = float(temperature)
self.num_states, self.num_actions = self.Q.shape

@staticmethod
def _softmax(x: np.ndarray, temperature: float) -> np.ndarray:
"""
Compute softmax values for a given state-action value array considering numerical stability.
Parameters:
- x (np.ndarray): The input array.
- temperature (float): The temperature parameter.
Returns:
np.ndarray: The computed softmax values.
"""
e_x: np.ndarray = np.exp((x - np.max(x)) / temperature)
return e_x / e_x.sum(axis=-1, keepdims=True)

def action_probabilities(self, state: int) -> np.ndarray:
"""
Compute the action probabilities under the softmax policy for a given state.
Parameters:
- state (int): The index of the state.
Returns:
np.ndarray: A 1D array of action probabilities.
"""
if not (0 <= state < self.num_states):
raise ValueError("Invalid state index!")
return self._softmax(self.Q[state, :], self.temperature)

def sample_action(self, state: int) -> int:
"""
Sample an action from the softmax policy for a given state.
Parameters:
- state (int): The index of the state.
Returns:
int: The index of the sampled action.
"""
return np.random.choice(self.num_actions, p=self.action_probabilities(state))

def set_temperature(self, temperature: float) -> None:
"""
Update the temperature of the softmax policy.
Parameters:
- temperature (float): The new temperature value.
"""
self.temperature = float(temperature)

0 comments on commit e3bf2f3

Please sign in to comment.