Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateless dist map offshoot #10

Merged
merged 9 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ authors = [
{ name = "Predictive Analytics Lab (PAL)", email = "info@predictive-analytics-lab.com" },
]
license = "Apache-2.0"
packages = [{ include = "src" }]
dependencies = [
"tqdm>=4.66.4",
"scikit-learn>=1.5.1",
Expand All @@ -16,6 +15,7 @@ dependencies = [
"loguru>=0.7.2",
"ranzen>=2.5.1",
"hydra-zen>=0.13.0",
"beartype>=0.18.5",
]
classifiers = [
"Programming Language :: Python :: 3.10",
Expand All @@ -24,7 +24,7 @@ classifiers = [
"Operating System :: OS Independent",
"Typing :: Typed",
]
repository = "https://github.com/wearepal/fescher"
urls = {github="https://github.com/wearepal/fescher"}
readme = "README.md"
requires-python = ">= 3.10, <3.12"

Expand All @@ -40,7 +40,7 @@ dev-dependencies = [
]

[tool.rye.scripts]
rrm = {cmd="python run/rrm_credit_lr.py"}
rrm = {cmd="python -m src.rrm_credit_lr"}

[tool.hatch.metadata]
allow-direct-references = true
Expand Down Expand Up @@ -150,4 +150,4 @@ addopts = ["--import-mode=importlib"]
filterwarnings = [
"ignore::DeprecationWarning:_pytest.*:",
"ignore::DeprecationWarning:pkg_resources.*:"
]
]
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
beartype==0.18.5
# via fescher
cloudpickle==3.0.0
# via gymnasium
contourpy==1.2.1
Expand Down
2 changes: 2 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
beartype==0.18.5
# via fescher
cloudpickle==3.0.0
# via gymnasium
contourpy==1.2.1
Expand Down
259 changes: 0 additions & 259 deletions run/rrm_credit_lr.py

This file was deleted.

74 changes: 70 additions & 4 deletions src/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,73 @@
from __future__ import annotations
import sys
from typing import Final
from typing_extensions import Self

__all__ = ["TESTING"]
import numpy as np
import pytest

TESTING: Final[bool] = "pytest" in sys.modules
from src.dynamics.registration import make_env
from src.dynamics.state import State
from src.models.lr import Model
from src.types import FloatArray, IntArray


@pytest.fixture
def mock_state():
features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
labels = np.array([1, 0], dtype=np.uint8)
return State(features=features, labels=labels)


@pytest.fixture
def mock_env(mock_state: State):
return make_env(initial_state=mock_state, epsilon=0.1, memory=False, changeable_features=[0])


@pytest.fixture
def mock_model():
class MockModel(Model):
def __init__(self):
self._weights = np.array([0.1, 0.2])

@property
def weights(self) -> FloatArray:
return self._weights

@weights.setter
def weights(self, value: FloatArray) -> None:
self._weights = value

def fit(
self,
*,
x: FloatArray,
y: IntArray,
tol: float = 1e-7,
) -> Self:
self.weights = np.array([0.3, 0.4])
return self

def acc(self, *, features: FloatArray, labels: IntArray) -> float:
return 0.5

def loss(
self,
*,
x: FloatArray,
y: IntArray,
) -> float:
return 0.6

@property
def l2_penalty(self) -> float:
raise NotImplementedError

def logits(self, features: FloatArray) -> FloatArray:
raise NotImplementedError

def preds(self, features: FloatArray) -> IntArray:
raise NotImplementedError

def probs(self, features: FloatArray) -> FloatArray:
raise NotImplementedError

return MockModel()
Loading