Skip to content

Commit

Permalink
Merge pull request #7 from yu9824/dev
Browse files Browse the repository at this point in the history
Fix #6.
  • Loading branch information
yu9824 authored Apr 25, 2023
2 parents 0d9d5e9 + d72c80c commit d85a077
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 50 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/python-package-conda.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Regrences
# - https://enu23456.hatenablog.com/entry/2022/11/24/195744
name: Python Package using Conda

on: [push]

jobs:
build-linux:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
max-parallel: 5

steps:
- uses: actions/checkout@v3
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file requirements.txt --name base
conda install pip pandas
pip install -e . --user --no-deps
- name: Test with pytest
run: |
conda install pytest
pytest
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
[![PyPI version](https://badge.fury.io/py/kennard-stone.svg)](https://pypi.org/project/kennard-stone/)
[![Downloads](https://pepy.tech/badge/kennard-stone)](https://pepy.tech/project/kennard-stone)

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Anaconda-Server Badge](https://anaconda.org/conda-forge/kennard-stone/badges/version.svg)](https://anaconda.org/conda-forge/kennard-stone)
[![Anaconda-platform badge](https://anaconda.org/conda-forge/kennard-stone/badges/platforms.svg)](https://anaconda.org/conda-forge/kennard-stone)
[![Anaconda-license badge](https://anaconda.org/conda-forge/kennard-stone/badges/license.svg)](https://anaconda.org/conda-forge/kennard-stone)

## What is this?

Expand Down
2 changes: 1 addition & 1 deletion kennard_stone/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .kennard_stone import KFold, train_test_split

__version__ = "2.0.0"
__version__ = "2.0.1"
__license__ = "MIT"
__author__ = "yu9824"
__copyright__ = "Copyright © 2021 yu9824"
Expand Down
8 changes: 4 additions & 4 deletions kennard_stone/kennard_stone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""

from typing import List, Union, Optional
from itertools import chain
import warnings

import numpy as np

from itertools import chain
from sklearn.model_selection._split import BaseShuffleSplit
from sklearn.model_selection._split import _BaseKFold
from sklearn.model_selection._split import _validate_shuffle_split
Expand Down Expand Up @@ -194,8 +194,8 @@ def get_indexes(self, X) -> List[List[int]]:
X=X, lst_idx_selected=lst_idx_selected, idx_remaining=idx_remaining
)
assert (
len(sum(indexes, start=[]))
== len(set(sum(indexes, start=[])))
len(list(chain.from_iterable(indexes)))
== len(set(chain.from_iterable(indexes)))
== len(self._original_X)
)

Expand All @@ -208,7 +208,7 @@ def _sort(
idx_remaining: Union[List[int], np.ndarray],
) -> List[List[int]]:
samples_selected: np.ndarray = self._original_X[
sum(lst_idx_selected, start=[])
list(chain.from_iterable(lst_idx_selected))
]

# まだ選択されていない各サンプルにおいて、これまで選択されたすべてのサンプルとの間で
Expand Down
17 changes: 0 additions & 17 deletions tests/check_new_old_match.py

This file was deleted.

27 changes: 0 additions & 27 deletions tests/test.py

This file was deleted.

38 changes: 38 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_validate
import pytest

from kennard_stone import train_test_split, KFold


@pytest.fixture
def prepare_data():
diabetes = load_diabetes(as_frame=True)
X = diabetes.data
y = diabetes.target
return (X, y)


def test_train_test_split(prepare_data):
X, y = prepare_data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
assert X_train.index[0] == y_train.index[0] == 224
assert X_train.index[-1] == y_train.index[-1] == 123
assert X_test.index[0] == y_test.index[0] == 68
assert X_test.index[-1] == y_test.index[-1] == 203


def test_KFold(prepare_data):
X, y = prepare_data
estimator = RandomForestRegressor(random_state=334, n_jobs=-1)
kf = KFold(n_splits=5, shuffle=True)
cross_validate(
estimator,
X,
y,
scoring="neg_mean_squared_error",
n_jobs=-1,
cv=kf,
return_train_score=True,
)
14 changes: 14 additions & 0 deletions tests/test_new_old_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sklearn.datasets import load_diabetes

from kennard_stone import kennard_stone
from kennard_stone import _deprecated


def test_new_old_match():
diabetes = load_diabetes(as_frame=True)
X = diabetes.data

ks_old = _deprecated._KennardStone()
ks_new = kennard_stone._KennardStone(n_groups=1)

assert ks_old._get_indexes(X) == ks_new.get_indexes(X)[0]

0 comments on commit d85a077

Please sign in to comment.