-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_online.py
25 lines (20 loc) · 1.05 KB
/
train_online.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
from rasa_core.agent import Agent
from rasa_core.channels.console import ConsoleInputChannel
from rasa_core.interpreter import RegexInterpreter
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.interpreter import RasaNLUInterpreter
logger = logging.getLogger(__name__)


def run_weather_online(input_channel, interpreter,
                       domain_file="weather_domain.yml",
                       training_data_file="data/stories.md",
                       max_history=2, batch_size=50, epochs=200,
                       max_training_samples=300):
    """Build the weather bot agent and run rasa_core online training.

    Creates an ``Agent`` from ``domain_file`` with a ``MemoizationPolicy``
    and a ``KerasPolicy``, then calls ``Agent.train_online`` on
    ``training_data_file``, using ``input_channel`` for the interactive
    session.

    Args:
        input_channel: channel handed to ``train_online`` for the
            interactive session (this script passes a
            ``ConsoleInputChannel``).
        interpreter: NLU interpreter given to the ``Agent``.
        domain_file: path to the domain file passed to ``Agent``.
        training_data_file: path to the training stories passed to
            ``train_online``.
        max_history, batch_size, epochs, max_training_samples: forwarded
            unchanged as keyword arguments to ``Agent.train_online``. The
            defaults are the values this script previously hard-coded, so
            existing callers see identical behavior.

    Returns:
        The ``Agent`` instance, after ``train_online`` has returned.
    """
    agent = Agent(domain_file,
                  policies=[MemoizationPolicy(), KerasPolicy()],
                  interpreter=interpreter)
    agent.train_online(training_data_file,
                       input_channel=input_channel,
                       max_history=max_history,
                       batch_size=batch_size,
                       epochs=epochs,
                       max_training_samples=max_training_samples)
    return agent
if __name__ == "__main__":
    # Emit INFO-level log records to the console during the session.
    logging.basicConfig(level="INFO")

    # Load the NLU interpreter first, then create the console channel the
    # interactive training session will use.
    weather_nlu = RasaNLUInterpreter('./models/nlu/default/weathernlu')
    console = ConsoleInputChannel()

    run_weather_online(console, weather_nlu)