Skip to content

Commit

Permalink
add --num-threads argument to train cli (#5086)
Browse files Browse the repository at this point in the history
* add `--num-threads` argument to train cli

* Update rasa/cli/train.py

* Create 5086.feature.rst

* fixed tests

Co-authored-by: Tom Bocklisch <tom@rasa.com>
  • Loading branch information
mludv and tmbo authored May 18, 2020
1 parent 6c18ff1 commit 0d34b00
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 16 deletions.
1 change: 1 addition & 0 deletions changelog/5086.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a ``--num-threads`` CLI argument that can be passed to ``rasa train`` and will be used to train NLU components.
15 changes: 15 additions & 0 deletions rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def set_train_arguments(parser: argparse.ArgumentParser):
add_augmentation_param(parser)
add_debug_plots_param(parser)

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)
add_force_param(parser)
Expand Down Expand Up @@ -48,6 +50,8 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser):

add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)

Expand Down Expand Up @@ -120,6 +124,17 @@ def add_debug_plots_param(
)


def add_num_threads_param(
    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
) -> None:
    """Register the ``--num-threads`` option on the given parser.

    The option is parsed as an integer and defaults to ``1`` when it is not
    supplied on the command line.

    Args:
        parser: Parser (or argument group) that receives the option.
    """
    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Maximum amount of threads to use when training.",
    )


def add_model_name_param(parser: argparse.ArgumentParser):
parser.add_argument(
"--fixed-model-name",
Expand Down
17 changes: 14 additions & 3 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def train(args: argparse.Namespace) -> Optional[Text]:
force_training=args.force,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_additional_arguments(args),
core_additional_arguments=extract_core_additional_arguments(args),
nlu_additional_arguments=extract_nlu_additional_arguments(args),
)


Expand All @@ -92,7 +93,7 @@ def train_core(
story_file = get_validated_path(
args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
)
additional_arguments = extract_additional_arguments(args)
additional_arguments = extract_core_additional_arguments(args)

# Policies might be a list for the compare training. Do normal training
# if only list item was passed.
Expand Down Expand Up @@ -138,10 +139,11 @@ def train_nlu(
train_path=train_path,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_nlu_additional_arguments(args),
)


def extract_additional_arguments(args: argparse.Namespace) -> Dict:
def extract_core_additional_arguments(args: argparse.Namespace) -> Dict:
arguments = {}

if "augmentation" in args:
Expand All @@ -152,6 +154,15 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:
return arguments


def extract_nlu_additional_arguments(args: argparse.Namespace) -> Dict:
    """Collect the parsed CLI options that are handed on to NLU training.

    Args:
        args: Parsed command line arguments.

    Returns:
        A dictionary containing ``num_threads`` when that attribute exists on
        ``args``; an empty dictionary otherwise.
    """
    arguments = {}

    # `Namespace` supports membership tests, so a parser that never
    # registered `--num-threads` simply contributes no entry.
    if "num_threads" in args:
        arguments["num_threads"] = args.num_threads

    return arguments


def _get_valid_config(
config: Optional[Text],
mandatory_keys: List[Text],
Expand Down
48 changes: 36 additions & 12 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def train(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Optional[Text]:
if loop is None:
Expand All @@ -47,7 +48,8 @@ def train(
force_training=force_training,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)
)

Expand All @@ -60,7 +62,8 @@ async def train_async(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU).
Expand All @@ -73,7 +76,9 @@ async def train_async(
fixed_model_name: Name of model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.
Returns:
Path of the trained model archive.
Expand All @@ -98,7 +103,8 @@ async def train_async(
force_training,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)


Expand All @@ -122,7 +128,8 @@ async def _train_async_internal(
force_training: bool,
fixed_model_name: Optional[Text],
persist_nlu_training_data: bool,
additional_arguments: Optional[Dict],
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU). Use only from `train_async`.
Expand All @@ -131,10 +138,12 @@ async def _train_async_internal(
train_path: Directory in which to train the model.
output_path: Output path.
force_training: If `True` retrain model even if data has not changed.
fixed_model_name: Name of model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
fixed_model_name: Name of model to be stored.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.
Returns:
Path of the trained model archive.
Expand All @@ -158,6 +167,7 @@ async def _train_async_internal(
output=output_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)

if nlu_data.is_empty():
Expand All @@ -166,7 +176,7 @@ async def _train_async_internal(
file_importer,
output=output_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)

new_fingerprint = await model.model_fingerprint(file_importer)
Expand All @@ -185,7 +195,8 @@ async def _train_async_internal(
fingerprint_comparison_result=fingerprint_comparison,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)

return model.package_model(
Expand All @@ -209,7 +220,8 @@ async def _do_training(
fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
):
if not fingerprint_comparison_result:
fingerprint_comparison_result = FingerprintComparisonResult()
Expand All @@ -220,7 +232,7 @@ async def _do_training(
output=output_path,
train_path=train_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)
elif fingerprint_comparison_result.should_retrain_nlg():
print_color(
Expand All @@ -243,6 +255,7 @@ async def _do_training(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)
else:
print_color(
Expand Down Expand Up @@ -383,6 +396,7 @@ def train_nlu(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains an NLU model.
Expand All @@ -395,6 +409,8 @@ def train_nlu(
fixed_model_name: Name of the model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters which will be passed to
the `train` method of each component.
Returns:
Expand All @@ -412,6 +428,7 @@ def train_nlu(
train_path,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
)
)

Expand All @@ -423,6 +440,7 @@ async def _train_nlu_async(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
):
if not nlu_data:
print_error(
Expand Down Expand Up @@ -451,6 +469,7 @@ async def _train_nlu_async(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
)


Expand All @@ -460,11 +479,15 @@ async def _train_nlu_with_validated_data(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Train NLU with validated training and config data."""

import rasa.nlu.train

if additional_arguments is None:
additional_arguments = {}

with ExitStack() as stack:
if train_path:
# If the train path was provided, do nothing on exit.
Expand All @@ -480,6 +503,7 @@ async def _train_nlu_with_validated_data(
_train_path,
fixed_model_name="nlu",
persist_nlu_training_data=persist_nlu_training_data,
**additional_arguments,
)
print_color("NLU model training completed.", color=bcolors.OKBLUE)

Expand Down
4 changes: 3 additions & 1 deletion tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def test_train_help(run):
help_text = """usage: rasa train [-h] [-v] [-vv] [--quiet] [--data DATA [DATA ...]]
[-c CONFIG] [-d DOMAIN] [--out OUT]
[--augmentation AUGMENTATION] [--debug-plots]
[--num-threads NUM_THREADS]
[--fixed-model-name FIXED_MODEL_NAME] [--persist-nlu-data]
[--force]
{core,nlu} ..."""
Expand All @@ -340,7 +341,8 @@ def test_train_nlu_help(run: Callable[..., RunResult]):
output = run("train", "nlu", "--help")

help_text = """usage: rasa train nlu [-h] [-v] [-vv] [--quiet] [-c CONFIG] [--out OUT]
[-u NLU] [--fixed-model-name FIXED_MODEL_NAME]
[-u NLU] [--num-threads NUM_THREADS]
[--fixed-model-name FIXED_MODEL_NAME]
[--persist-nlu-data]"""

lines = help_text.split("\n")
Expand Down

0 comments on commit 0d34b00

Please sign in to comment.