-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
85 lines (70 loc) · 3.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import logging.config
import pandas as pd
from traceback import format_exc
from raif_hack.model import BenchmarkModel
from raif_hack.settings import MODEL_PARAMS,MODEL_RF_PARAMS, LOGGING_CONFIG, NUM_FEATURES, CATEGORICAL_OHE_FEATURES, \
CATEGORICAL_STE_FEATURES, TARGET
from raif_hack.utils import PriceTypeEnum
from raif_hack.metrics import metrics_stat
from raif_hack.features import prepare_categorical, prepare_floor, add_economic, prepare_square, prepare_building, \
prepare_amenity, prepare_historic, add_metro, change_region
# Apply the logging configuration dict imported from raif_hack.settings
# before any log record is emitted by this script.
logging.config.dictConfig(LOGGING_CONFIG)
# Module-level logger used by every step of the training script.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what the script did before this parameter existed.

    Returns:
        argparse.Namespace with two attributes:
        ``d``  - value of ``--train_data`` / ``-d`` (path to the training dataset);
        ``mp`` - value of ``--model_path`` / ``-mp`` (where to save the trained model).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            the arguments cannot be parsed.
    """
    parser = argparse.ArgumentParser(
        description="""
Бенчмарк для хакатона по предсказанию стоимости коммерческой недвижимости от "Райффайзенбанк"
Скрипт для обучения модели
Примеры:
1) с poetry - poetry run python3 train.py --train_data /path/to/train/data --model_path /path/to/model
2) без poetry - python3 train.py --train_data /path/to/train/data --model_path /path/to/model
""",
        # Keep the line breaks of the description as written above.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--train_data", "-d",
        type=str,
        dest="d",
        required=True,
        help="Путь до обучающего датасета",
    )
    parser.add_argument(
        "--model_path", "-mp",
        type=str,
        dest="mp",
        required=True,
        help="Куда сохранить обученную ML модель",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def _preprocess(df):
    """Run the feature-preparation steps on *df* in their fixed order and return the result."""
    steps = (
        prepare_square,
        prepare_building,
        prepare_amenity,
        prepare_historic,
        prepare_categorical,
        prepare_floor,
        add_economic,
        add_metro,
    )
    for step in steps:
        df = step(df)
    return df


def _select_by_price_type(df, price_type):
    """Return ``(X, y)`` for the rows of *df* whose ``price_type`` equals *price_type*."""
    # Build the row mask once and reuse it for both features and target.
    subset = df[df.price_type == price_type]
    feature_columns = NUM_FEATURES + CATEGORICAL_OHE_FEATURES + CATEGORICAL_STE_FEATURES
    return subset[feature_columns], subset[TARGET]


def main():
    """Train the benchmark model, save it and log metrics on the training data.

    Reads the CSV given by ``--train_data``, applies the preprocessing steps,
    splits rows into offer-price and manual-price subsets, fits
    ``BenchmarkModel`` on both, saves it to ``--model_path`` and logs
    ``metrics_stat`` for each subset.
    """
    logger.info('START train.py')
    args = vars(parse_args())

    logger.info('Load train df')
    train_df = pd.read_csv(args['d'])
    logger.info(f'Input shape: {train_df.shape}')

    train_df = _preprocess(train_df)
    logger.info(f'Preprocessed shape: {train_df.shape}')

    X_offer, y_offer = _select_by_price_type(train_df, PriceTypeEnum.OFFER_PRICE)
    X_manual, y_manual = _select_by_price_type(train_df, PriceTypeEnum.MANUAL_PRICE)
    logger.info(
        f'X_offer {X_offer.shape} y_offer {y_offer.shape}\tX_manual {X_manual.shape} y_manual {y_manual.shape}')

    model = BenchmarkModel(
        numerical_features=NUM_FEATURES,
        ohe_categorical_features=CATEGORICAL_OHE_FEATURES,
        ste_categorical_features=CATEGORICAL_STE_FEATURES,
        model_params=MODEL_PARAMS,
    )
    logger.info('Fit model')
    model.fit(X_offer, y_offer, X_manual, y_manual)
    logger.info('Save model')
    model.save(args['mp'])

    predictions_offer = model.predict(X_offer)
    # For the offer-price training subset, quality is evaluated without the
    # correction coefficient, so the predictions are divided by (1 + corr_coef).
    metrics = metrics_stat(y_offer.values, predictions_offer / (1 + model.corr_coef))
    logger.info(f'Metrics stat for training data with offers prices: {metrics}')

    predictions_manual = model.predict(X_manual)
    metrics = metrics_stat(y_manual.values, predictions_manual)
    logger.info(f'Metrics stat for training data with manual prices: {metrics}')


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Log the full traceback, then re-raise the active exception unchanged
        # so the process still exits with an error.
        logger.error(format_exc())
        raise
    # Reached only when training finished without an exception.
    logger.info('END train.py')