Skip to content

Commit

Permalink
add --num-threads argument to train cli
Browse files Browse the repository at this point in the history
  • Loading branch information
mludv committed Jan 26, 2020
1 parent 31f2357 commit 806822d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 15 deletions.
15 changes: 15 additions & 0 deletions rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def set_train_arguments(parser: argparse.ArgumentParser):
add_debug_plots_param(parser)
add_dump_stories_param(parser)

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)
add_force_param(parser)
Expand Down Expand Up @@ -50,6 +52,8 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser):

add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)

Expand Down Expand Up @@ -133,6 +137,17 @@ def add_debug_plots_param(
)


def add_num_threads_param(
    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
):
    """Register the ``--num-threads`` CLI option on *parser*.

    The option is an ``int`` capping the number of threads used during
    training; it defaults to ``1`` when not given on the command line.
    """
    parser.add_argument(
        "--num-threads",
        default=1,
        type=int,
        help="Maximum amount of threads to use when training.",
    )


def add_model_name_param(parser: argparse.ArgumentParser):
parser.add_argument(
"--fixed-model-name",
Expand Down
17 changes: 14 additions & 3 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def train(args: argparse.Namespace) -> Optional[Text]:
force_training=args.force,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_additional_arguments(args),
core_additional_arguments=extract_core_additional_arguments(args),
nlu_additional_arguments=extract_nlu_additional_arguments(args),
)


Expand All @@ -92,7 +93,7 @@ def train_core(
story_file = get_validated_path(
args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
)
additional_arguments = extract_additional_arguments(args)
additional_arguments = extract_core_additional_arguments(args)

# Policies might be a list for the compare training. Do normal training
# if only list item was passed.
Expand Down Expand Up @@ -138,10 +139,11 @@ def train_nlu(
train_path=train_path,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_nlu_additional_arguments(args),
)


def extract_additional_arguments(args: argparse.Namespace) -> Dict:
def extract_core_additional_arguments(args: argparse.Namespace) -> Dict:
arguments = {}

if "augmentation" in args:
Expand All @@ -154,6 +156,15 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:
return arguments


def extract_nlu_additional_arguments(args: argparse.Namespace) -> Dict:
    """Collect NLU-specific training parameters from parsed CLI arguments.

    Returns a dict holding ``num_threads`` when that attribute is present
    on *args*; otherwise an empty dict.
    """
    if "num_threads" in args:
        return {"num_threads": args.num_threads}
    return {}


def _get_valid_config(
config: Optional[Text],
mandatory_keys: List[Text],
Expand Down
48 changes: 36 additions & 12 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def train(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Optional[Text]:
if loop is None:
Expand All @@ -43,7 +44,8 @@ def train(
force_training=force_training,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)
)

Expand All @@ -56,7 +58,8 @@ async def train_async(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU).
Expand All @@ -69,7 +72,9 @@ async def train_async(
fixed_model_name: Name of model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.
Returns:
Path of the trained model archive.
Expand All @@ -94,7 +99,8 @@ async def train_async(
force_training,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)


Expand All @@ -118,7 +124,8 @@ async def _train_async_internal(
force_training: bool,
fixed_model_name: Optional[Text],
persist_nlu_training_data: bool,
additional_arguments: Optional[Dict],
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU). Use only from `train_async`.
Expand All @@ -127,10 +134,12 @@ async def _train_async_internal(
train_path: Directory in which to train the model.
output_path: Output path.
force_training: If `True` retrain model even if data has not changed.
fixed_model_name: Name of model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
fixed_model_name: Name of model to be stored.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.
Returns:
Path of the trained model archive.
Expand All @@ -154,6 +163,7 @@ async def _train_async_internal(
output=output_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)

if nlu_data.is_empty():
Expand All @@ -162,7 +172,7 @@ async def _train_async_internal(
file_importer,
output=output_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)

new_fingerprint = await model.model_fingerprint(file_importer)
Expand All @@ -181,7 +191,8 @@ async def _train_async_internal(
fingerprint_comparison_result=fingerprint_comparison,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)

return model.package_model(
Expand All @@ -205,7 +216,8 @@ async def _do_training(
fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
):
if not fingerprint_comparison_result:
fingerprint_comparison_result = FingerprintComparisonResult()
Expand All @@ -216,7 +228,7 @@ async def _do_training(
output=output_path,
train_path=train_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)
elif fingerprint_comparison_result.should_retrain_nlg():
print_color(
Expand All @@ -239,6 +251,7 @@ async def _do_training(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)
else:
print_color(
Expand Down Expand Up @@ -379,6 +392,7 @@ def train_nlu(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains an NLU model.
Expand All @@ -391,6 +405,8 @@ def train_nlu(
fixed_model_name: Name of the model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters which will be passed to
the `train` method of each component.
Returns:
Expand All @@ -408,6 +424,7 @@ def train_nlu(
train_path,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
)
)

Expand All @@ -419,6 +436,7 @@ async def _train_nlu_async(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
):
# training NLU only hence the training files still have to be selected
file_importer = TrainingDataImporter.load_nlu_importer_from_config(
Expand All @@ -439,6 +457,7 @@ async def _train_nlu_async(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
)


Expand All @@ -448,11 +467,15 @@ async def _train_nlu_with_validated_data(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Train NLU with validated training and config data."""

import rasa.nlu.train

if additional_arguments is None:
additional_arguments = {}

with ExitStack() as stack:
if train_path:
# If the train path was provided, do nothing on exit.
Expand All @@ -468,6 +491,7 @@ async def _train_nlu_with_validated_data(
_train_path,
fixed_model_name="nlu",
persist_nlu_training_data=persist_nlu_training_data,
**additional_arguments,
)
print_color("NLU model training completed.", color=bcolors.OKBLUE)

Expand Down

0 comments on commit 806822d

Please sign in to comment.