Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add --num-threads argument to train cli #5086

Merged
merged 7 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/5086.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a ``--num-threads`` CLI argument that can be passed to ``rasa train`` and will be used to train NLU components.
15 changes: 15 additions & 0 deletions rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def set_train_arguments(parser: argparse.ArgumentParser):
add_augmentation_param(parser)
add_debug_plots_param(parser)

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)
add_force_param(parser)
Expand Down Expand Up @@ -48,6 +50,8 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser):

add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")

add_num_threads_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)

Expand Down Expand Up @@ -120,6 +124,17 @@ def add_debug_plots_param(
)


def add_num_threads_param(
    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
):
    """Register the ``--num-threads`` option on ``parser``.

    The option is an ``int`` capping the threads used during training;
    it defaults to ``1`` when not supplied on the command line.
    """
    option_kwargs = {
        "type": int,
        "default": 1,
        "help": "Maximum amount of threads to use when training.",
    }
    parser.add_argument("--num-threads", **option_kwargs)


def add_model_name_param(parser: argparse.ArgumentParser):
parser.add_argument(
"--fixed-model-name",
Expand Down
17 changes: 14 additions & 3 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def train(args: argparse.Namespace) -> Optional[Text]:
force_training=args.force,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_additional_arguments(args),
core_additional_arguments=extract_core_additional_arguments(args),
nlu_additional_arguments=extract_nlu_additional_arguments(args),
)


Expand All @@ -92,7 +93,7 @@ def train_core(
story_file = get_validated_path(
args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
)
additional_arguments = extract_additional_arguments(args)
additional_arguments = extract_core_additional_arguments(args)

# Policies might be a list for the compare training. Do normal training
# if only list item was passed.
Expand Down Expand Up @@ -138,10 +139,11 @@ def train_nlu(
train_path=train_path,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_nlu_additional_arguments(args),
)


def extract_additional_arguments(args: argparse.Namespace) -> Dict:
def extract_core_additional_arguments(args: argparse.Namespace) -> Dict:
arguments = {}

if "augmentation" in args:
Expand All @@ -152,6 +154,15 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:
return arguments


def extract_nlu_additional_arguments(args: argparse.Namespace) -> Dict:
    """Pull NLU-specific training parameters out of parsed CLI arguments.

    Returns a dict containing ``num_threads`` when that attribute is
    present on ``args``; otherwise an empty dict.
    """
    if "num_threads" in args:
        return {"num_threads": args.num_threads}
    return {}


def _get_valid_config(
config: Optional[Text],
mandatory_keys: List[Text],
Expand Down
48 changes: 36 additions & 12 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def train(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Optional[Text]:
if loop is None:
Expand All @@ -47,7 +48,8 @@ def train(
force_training=force_training,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)
)

Expand All @@ -60,7 +62,8 @@ async def train_async(
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU).

Expand All @@ -73,7 +76,9 @@ async def train_async(
fixed_model_name: Name of model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.

Returns:
Path of the trained model archive.
Expand All @@ -98,7 +103,8 @@ async def train_async(
force_training,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)


Expand All @@ -122,7 +128,8 @@ async def _train_async_internal(
force_training: bool,
fixed_model_name: Optional[Text],
persist_nlu_training_data: bool,
additional_arguments: Optional[Dict],
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU). Use only from `train_async`.

Expand All @@ -131,10 +138,12 @@ async def _train_async_internal(
train_path: Directory in which to train the model.
output_path: Output path.
force_training: If `True` retrain model even if data has not changed.
fixed_model_name: Name of model to be stored.
tmbo marked this conversation as resolved.
Show resolved Hide resolved
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
fixed_model_name: Name of model to be stored.
additional_arguments: Additional training parameters.
core_additional_arguments: Additional training parameters for core training.
nlu_additional_arguments: Additional training parameters forwarded to training
method of each NLU component.

Returns:
Path of the trained model archive.
Expand All @@ -158,6 +167,7 @@ async def _train_async_internal(
output=output_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)

if nlu_data.is_empty():
Expand All @@ -166,7 +176,7 @@ async def _train_async_internal(
file_importer,
output=output_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)

new_fingerprint = await model.model_fingerprint(file_importer)
Expand All @@ -185,7 +195,8 @@ async def _train_async_internal(
fingerprint_comparison_result=fingerprint_comparison,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
core_additional_arguments=core_additional_arguments,
nlu_additional_arguments=nlu_additional_arguments,
)

return model.package_model(
Expand All @@ -209,7 +220,8 @@ async def _do_training(
fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
):
if not fingerprint_comparison_result:
fingerprint_comparison_result = FingerprintComparisonResult()
Expand All @@ -220,7 +232,7 @@ async def _do_training(
output=output_path,
train_path=train_path,
fixed_model_name=fixed_model_name,
additional_arguments=additional_arguments,
additional_arguments=core_additional_arguments,
)
elif fingerprint_comparison_result.should_retrain_nlg():
print_color(
Expand All @@ -243,6 +255,7 @@ async def _do_training(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
)
else:
print_color(
Expand Down Expand Up @@ -383,6 +396,7 @@ def train_nlu(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Trains an NLU model.

Expand All @@ -395,6 +409,8 @@ def train_nlu(
fixed_model_name: Name of the model to be stored.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
additional_arguments: Additional training parameters which will be passed to
the `train` method of each component.


Returns:
Expand All @@ -412,6 +428,7 @@ def train_nlu(
train_path,
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
)
)

Expand All @@ -423,6 +440,7 @@ async def _train_nlu_async(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
):
if not nlu_data:
print_error(
Expand Down Expand Up @@ -451,6 +469,7 @@ async def _train_nlu_async(
train_path=train_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=additional_arguments,
)


Expand All @@ -460,11 +479,15 @@ async def _train_nlu_with_validated_data(
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
) -> Optional[Text]:
"""Train NLU with validated training and config data."""

import rasa.nlu.train

if additional_arguments is None:
additional_arguments = {}

with ExitStack() as stack:
if train_path:
# If the train path was provided, do nothing on exit.
Expand All @@ -480,6 +503,7 @@ async def _train_nlu_with_validated_data(
_train_path,
fixed_model_name="nlu",
persist_nlu_training_data=persist_nlu_training_data,
**additional_arguments,
)
print_color("NLU model training completed.", color=bcolors.OKBLUE)

Expand Down
4 changes: 3 additions & 1 deletion tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def test_train_help(run):
help_text = """usage: rasa train [-h] [-v] [-vv] [--quiet] [--data DATA [DATA ...]]
[-c CONFIG] [-d DOMAIN] [--out OUT]
[--augmentation AUGMENTATION] [--debug-plots]
[--num-threads NUM_THREADS]
[--fixed-model-name FIXED_MODEL_NAME] [--persist-nlu-data]
[--force]
{core,nlu} ..."""
Expand All @@ -340,7 +341,8 @@ def test_train_nlu_help(run: Callable[..., RunResult]):
output = run("train", "nlu", "--help")

help_text = """usage: rasa train nlu [-h] [-v] [-vv] [--quiet] [-c CONFIG] [--out OUT]
[-u NLU] [--fixed-model-name FIXED_MODEL_NAME]
[-u NLU] [--num-threads NUM_THREADS]
[--fixed-model-name FIXED_MODEL_NAME]
[--persist-nlu-data]"""

lines = help_text.split("\n")
Expand Down