-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zhihang Dong
committed
Oct 14, 2023
1 parent
0d9a342
commit e3bf2f3
Showing
2 changed files
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
import logging | ||
|
||
class ELSValidation:
    """
    A class to check various gradient conditions for EL-S optimization theories.

    The instance stores a gradient callable together with the constants used
    by the checks. Of the checks, only the Lipschitz-continuity test of the
    gradient is implemented in this section of the file.
    """

    def __init__(self, F_grad, L, beta=None, mu=None, verbose=False):
        """
        Initialize with the gradient of function F, Lipschitz constant L, and optionally beta and mu.

        Parameters:
        - F_grad: Callable taking a numpy array and returning the gradient of F
          at that point.
        - L: Lipschitz constant used by ``check_lipschitz``.
        - beta: Optional constant; stored as-is and not used by the methods
          visible in this section.
        - mu: Optional constant; stored as-is and not used by the methods
          visible in this section.
        - verbose: When true, results of the checks are logged at INFO level.
        """
        self.F_grad = F_grad
        self.L = L
        self.beta = beta
        self.mu = mu
        self.verbose = verbose
        self.logger = logging.getLogger(__name__)
        if self.verbose:
            # NOTE: this configures the *root* logger; per the logging docs it
            # is a no-op when the root logger already has handlers.
            logging.basicConfig(level=logging.INFO)

    def _validate_input(self, theta, delta_theta):
        """
        Ensure both inputs are numpy arrays of identical shape.

        Raises:
            AssertionError: If either input is not a ``np.ndarray`` or the
                shapes differ. The exception type is kept for backward
                compatibility, but it is raised explicitly so the validation
                is not stripped when Python runs with ``-O``.
        """
        # Check the types first: reading ``.shape`` on a non-array (e.g. a
        # list) would otherwise fail with an unrelated AttributeError.
        if not (isinstance(theta, np.ndarray) and isinstance(delta_theta, np.ndarray)):
            raise AssertionError("Inputs must be numpy arrays!")
        if theta.shape != delta_theta.shape:
            raise AssertionError("Mismatched shapes between theta and delta_theta!")

    def _log(self, message):
        """Log ``message`` at INFO level when the instance is verbose."""
        if self.verbose:
            self.logger.info(message)

    def check_lipschitz(self, theta, delta_theta):
        """
        Check the Lipschitz continuity condition.

        Tests whether
        ``||F_grad(theta + delta_theta) - F_grad(theta)|| <= L * ||delta_theta||``
        holds for the given point and perturbation, using ``np.linalg.norm``
        with its default norm.

        Parameters:
        - theta (np.ndarray): Point at which the condition is evaluated.
        - delta_theta (np.ndarray): Perturbation with the same shape as ``theta``.

        Returns:
            The boolean result of the comparison.

        Raises:
            AssertionError: If the inputs fail ``_validate_input``.
        """
        self._validate_input(theta, delta_theta)
        grad_diff = self.F_grad(theta + delta_theta) - self.F_grad(theta)
        condition = np.linalg.norm(grad_diff) <= self.L * np.linalg.norm(delta_theta)
        self._log(f"Lipschitz check: {condition}")
        return condition
|
||
# ... (other methods remain the same) | ||
|
||
# Usage remains similar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
from typing import Tuple | ||
|
||
class SoftmaxTabularPolicy:
    """
    Softmax policy over a tabular action-value function Q(s, a).

    Action probabilities for a state are the softmax of that state's row of
    the Q table, scaled by a temperature.
    """

    def __init__(self, action_value: np.ndarray, temperature: float = 1.0) -> None:
        """
        Initialize the SoftmaxTabularPolicy.

        Parameters:
        - action_value (np.ndarray): A 2D array representing the action value Q(s, a),
          with shape [num_states, num_actions].
        - temperature (float, optional): Temperature parameter to control exploration.
          Higher values encourage more exploration. Must be positive. Default is 1.0.

        Raises:
            ValueError: If ``action_value`` is not 2-dimensional or
                ``temperature`` is not a positive number.
        """
        self.Q: np.ndarray = np.asarray(action_value)
        if self.Q.ndim != 2:
            # Fail with a descriptive message instead of an opaque
            # tuple-unpacking error on ``self.Q.shape``.
            raise ValueError(
                "action_value must be a 2D array of shape "
                f"[num_states, num_actions], got shape {self.Q.shape}"
            )
        self.num_states, self.num_actions = self.Q.shape
        self.temperature: float = self._checked_temperature(temperature)

    @staticmethod
    def _checked_temperature(temperature: float) -> float:
        """Return ``temperature`` as a float, rejecting non-positive or NaN values."""
        value = float(temperature)
        # ``not value > 0`` is also true for NaN, which a plain ``<= 0`` misses.
        if not value > 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        return value

    @staticmethod
    def _softmax(x: np.ndarray, temperature: float) -> np.ndarray:
        """
        Compute softmax values for a given state-action value array considering numerical stability.

        The softmax is taken along the last axis, so a 1D array yields one
        distribution and a 2D array yields one distribution per row.

        Parameters:
        - x (np.ndarray): The input array.
        - temperature (float): The temperature parameter.

        Returns:
            np.ndarray: The computed softmax values, same shape as ``x``.
        """
        # Subtract the max along the same axis that is normalised. Using the
        # global max would let rows far below it underflow to all zeros and
        # produce 0/0 = NaN for batched input.
        shifted = x - np.max(x, axis=-1, keepdims=True)
        e_x: np.ndarray = np.exp(shifted / temperature)
        return e_x / e_x.sum(axis=-1, keepdims=True)

    def action_probabilities(self, state: int) -> np.ndarray:
        """
        Compute the action probabilities under the softmax policy for a given state.

        Parameters:
        - state (int): The index of the state.

        Returns:
            np.ndarray: A 1D array of action probabilities.

        Raises:
            ValueError: If ``state`` is outside ``[0, num_states)``.
        """
        if not (0 <= state < self.num_states):
            raise ValueError("Invalid state index!")
        return self._softmax(self.Q[state, :], self.temperature)

    def sample_action(self, state: int) -> int:
        """
        Sample an action from the softmax policy for a given state.

        Parameters:
        - state (int): The index of the state.

        Returns:
            int: The index of the sampled action.

        Raises:
            ValueError: If ``state`` is outside ``[0, num_states)``.
        """
        probabilities = self.action_probabilities(state)
        # np.random.choice returns a NumPy integer; convert to honour the
        # annotated ``int`` return type.
        return int(np.random.choice(self.num_actions, p=probabilities))

    def set_temperature(self, temperature: float) -> None:
        """
        Update the temperature of the softmax policy.

        Parameters:
        - temperature (float): The new temperature value. Must be positive.

        Raises:
            ValueError: If ``temperature`` is not a positive number.
        """
        self.temperature = self._checked_temperature(temperature)