-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
I supplemented the documentation with a paragraph about the work of t…
…he framework with the optimal selection of two real and one discrete parameters. Corrected the problem code for finding real and discrete parameters. (#167)
- Loading branch information
Showing
18 changed files
with
36,941 additions
and
9 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
40 changes: 40 additions & 0 deletions
40
examples/Machine_learning/SVC/_2D/Example_SVC_2D_Transformators_State.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from iOpt.output_system.listeners.static_painters import StaticPainterNDListener | ||
from iOpt.output_system.listeners.animate_painters import AnimatePainterNDListener | ||
from iOpt.output_system.listeners.console_outputers import ConsoleOutputListener | ||
|
||
from iOpt.solver import Solver | ||
from iOpt.solver_parametrs import SolverParameters | ||
from examples.Machine_learning.SVC._2D.Problems import SVC_2D_Transformators_State | ||
from sklearn.utils import shuffle | ||
import numpy as np | ||
import pandas as pd | ||
import csv | ||
|
||
def factory_dataset(): | ||
x = [] | ||
y = [] | ||
with open(r"../Datasets/transformator_state.csv") as rrrr_file: | ||
file_reader = csv.reader(rrrr_file, delimiter=",") | ||
for row in file_reader: | ||
x_row = [] | ||
for i in range(len(row)-1): | ||
x_row.append(row[i]) | ||
x.append(x_row) | ||
y.append(row[len(row)-1]) | ||
return shuffle(np.array(x), np.array(y), random_state=42) | ||
|
||
|
||
if __name__ == "__main__": | ||
X, Y = factory_dataset() | ||
regularization_value_bound = {'low': 5, 'up': 9} | ||
kernel_coefficient_bound = {'low': -3, 'up': 1} | ||
problem = SVC_2D_Transformators_State.SVC_2D_Transformators_State(X, Y, regularization_value_bound, kernel_coefficient_bound) | ||
method_params = SolverParameters(r=np.double(2.0), iters_limit=100) | ||
solver = Solver(problem, parameters=method_params) | ||
#apl = AnimatePainterNDListener("svc2d_transformator_state_anim.png", "output", vars_indxs=[0, 1]) | ||
#solver.add_listener(apl) | ||
#spl = StaticPainterNDListener("svc2d_transformator_state_stat.png", "output", vars_indxs=[0, 1], mode="surface", calc="interpolation") | ||
#solver.add_listener(spl) | ||
cfol = ConsoleOutputListener(mode='full') | ||
solver.add_listener(cfol) | ||
solver_info = solver.solve() |
58 changes: 58 additions & 0 deletions
58
examples/Machine_learning/SVC/_2D/Problems/SVC_2D_Transformators_State.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
from iOpt.trial import Point | ||
from iOpt.trial import FunctionValue | ||
from iOpt.problem import Problem | ||
from sklearn.svm import SVC | ||
from sklearn.model_selection import cross_val_score | ||
from typing import Dict | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
class SVC_2D_Transformators_State(Problem): | ||
""" | ||
Класс SVC_2D представляет возможность поиска оптимального набора гиперпараметров алгоритма | ||
C-Support Vector Classification. | ||
Найденные параметры являются оптимальными при варьировании параматра регуляризации | ||
(Regularization parameter С) значения коэфицента ядра (gamma) | ||
""" | ||
|
||
def __init__(self, x_dataset: np.ndarray, y_dataset: np.ndarray, | ||
regularization_bound: Dict[str, float], | ||
kernel_coefficient_bound: Dict[str, float]): | ||
""" | ||
Конструктор класса SVC_2D | ||
:param x_dataset: входные данные обучающе выборки метода SVC | ||
:param y_dataset: выходные данные обучающе выборки метода SVC | ||
:param kernel_coefficient_bound: Значение параметра регуляризации | ||
:param regularization_bound: Границы изменения значений коэфицента ядра (low - нижняя граница, up - верхняя) | ||
""" | ||
super(SVC_2D_Transformators_State, self).__init__() | ||
self.dimension = 2 | ||
self.number_of_float_variables = 2 | ||
self.number_of_discrete_variables = 0 | ||
self.number_of_objectives = 1 | ||
self.number_of_constraints = 0 | ||
if x_dataset.shape[0] != y_dataset.shape[0]: | ||
raise ValueError('The input and output sample sizes do not match.') | ||
self.x = x_dataset | ||
self.y = y_dataset | ||
self.float_variable_names = np.array(["Regularization parameter", "Kernel coefficient"], dtype=str) | ||
self.lower_bound_of_float_variables = np.array([regularization_bound['low'], kernel_coefficient_bound['low']], | ||
dtype=np.double) | ||
self.upper_bound_of_float_variables = np.array([regularization_bound['up'], kernel_coefficient_bound['up']], | ||
dtype=np.double) | ||
|
||
|
||
|
||
def calculate(self, point: Point, function_value: FunctionValue) -> FunctionValue: | ||
""" | ||
Метод расчёта значения целевой функции в точке | ||
:param point: Точка испытания | ||
:param function_value: объект хранения значения целевой функции в точке | ||
""" | ||
cs, gammas = point.float_variables[0], point.float_variables[1] | ||
clf = SVC(C=10 ** cs, gamma=10 ** gammas) | ||
function_value.value = -cross_val_score(clf, self.x, self.y, scoring='accuracy').mean() | ||
return function_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.