diff --git a/rasa/cli/arguments/train.py b/rasa/cli/arguments/train.py
index c3b4741b555e..6ac8fecd2ea2 100644
--- a/rasa/cli/arguments/train.py
+++ b/rasa/cli/arguments/train.py
@@ -21,6 +21,8 @@ def set_train_arguments(parser: argparse.ArgumentParser):
     add_debug_plots_param(parser)
     add_dump_stories_param(parser)
 
+    add_num_threads_param(parser)
+
     add_model_name_param(parser)
     add_persist_nlu_data_param(parser)
     add_force_param(parser)
@@ -50,6 +52,8 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser):
 
     add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")
 
+    add_num_threads_param(parser)
+
     add_model_name_param(parser)
     add_persist_nlu_data_param(parser)
 
@@ -133,6 +137,17 @@ def add_debug_plots_param(
     )
 
 
+def add_num_threads_param(
+    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
+):
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Maximum amount of threads to use when training.",
+    )
+
+
 def add_model_name_param(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--fixed-model-name",
diff --git a/rasa/cli/train.py b/rasa/cli/train.py
index 0da1b5ebfd9a..e618d95ee923 100644
--- a/rasa/cli/train.py
+++ b/rasa/cli/train.py
@@ -73,7 +73,8 @@ def train(args: argparse.Namespace) -> Optional[Text]:
         force_training=args.force,
         fixed_model_name=args.fixed_model_name,
         persist_nlu_training_data=args.persist_nlu_data,
-        additional_arguments=extract_additional_arguments(args),
+        core_additional_arguments=extract_core_additional_arguments(args),
+        nlu_additional_arguments=extract_nlu_additional_arguments(args),
     )
 
 
@@ -92,7 +93,7 @@ def train_core(
     story_file = get_validated_path(
         args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
     )
-    additional_arguments = extract_additional_arguments(args)
+    additional_arguments = extract_core_additional_arguments(args)
 
     # Policies might be a list for the compare training. Do normal training
     # if only list item was passed.
@@ -138,10 +139,11 @@ def train_nlu(
         train_path=train_path,
         fixed_model_name=args.fixed_model_name,
         persist_nlu_training_data=args.persist_nlu_data,
+        additional_arguments=extract_nlu_additional_arguments(args),
     )
 
 
-def extract_additional_arguments(args: argparse.Namespace) -> Dict:
+def extract_core_additional_arguments(args: argparse.Namespace) -> Dict:
     arguments = {}
 
     if "augmentation" in args:
@@ -154,6 +156,15 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:
     return arguments
 
 
+def extract_nlu_additional_arguments(args: argparse.Namespace) -> Dict:
+    arguments = {}
+
+    if "num_threads" in args:
+        arguments["num_threads"] = args.num_threads
+
+    return arguments
+
+
 def _get_valid_config(
     config: Optional[Text],
     mandatory_keys: List[Text],
diff --git a/rasa/train.py b/rasa/train.py
index b18a3018a10d..7ccc69aa0c05 100644
--- a/rasa/train.py
+++ b/rasa/train.py
@@ -28,7 +28,8 @@ def train(
     force_training: bool = False,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
-    additional_arguments: Optional[Dict] = None,
+    core_additional_arguments: Optional[Dict] = None,
+    nlu_additional_arguments: Optional[Dict] = None,
     loop: Optional[asyncio.AbstractEventLoop] = None,
 ) -> Optional[Text]:
     if loop is None:
@@ -47,7 +48,8 @@ def train(
             force_training=force_training,
             fixed_model_name=fixed_model_name,
             persist_nlu_training_data=persist_nlu_training_data,
-            additional_arguments=additional_arguments,
+            core_additional_arguments=core_additional_arguments,
+            nlu_additional_arguments=nlu_additional_arguments,
         )
     )
 
@@ -60,7 +62,8 @@ async def train_async(
     force_training: bool = False,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
-    additional_arguments: Optional[Dict] = None,
+    core_additional_arguments: Optional[Dict] = None,
+    nlu_additional_arguments: Optional[Dict] = None,
 ) -> Optional[Text]:
     """Trains a Rasa model (Core and NLU).
 
@@ -73,7 +76,9 @@ async def train_async(
         fixed_model_name: Name of model to be stored.
         persist_nlu_training_data: `True` if the NLU training data should be persisted
                                    with the model.
-        additional_arguments: Additional training parameters.
+        core_additional_arguments: Additional training parameters for core training.
+        nlu_additional_arguments: Additional training parameters forwarded to training
+                                  method of each NLU component.
 
     Returns:
         Path of the trained model archive.
@@ -98,7 +103,8 @@ async def train_async(
             force_training,
             fixed_model_name,
             persist_nlu_training_data,
-            additional_arguments,
+            core_additional_arguments=core_additional_arguments,
+            nlu_additional_arguments=nlu_additional_arguments,
         )
 
 
@@ -122,7 +128,8 @@ async def _train_async_internal(
     force_training: bool,
     fixed_model_name: Optional[Text],
     persist_nlu_training_data: bool,
-    additional_arguments: Optional[Dict],
+    core_additional_arguments: Optional[Dict] = None,
+    nlu_additional_arguments: Optional[Dict] = None,
 ) -> Optional[Text]:
     """Trains a Rasa model (Core and NLU). Use only from `train_async`.
 
@@ -131,10 +138,12 @@ async def _train_async_internal(
         train_path: Directory in which to train the model.
         output_path: Output path.
         force_training: If `True` retrain model even if data has not changed.
+        fixed_model_name: Name of model to be stored.
         persist_nlu_training_data: `True` if the NLU training data should be persisted
                                    with the model.
-        fixed_model_name: Name of model to be stored.
-        additional_arguments: Additional training parameters.
+        core_additional_arguments: Additional training parameters for core training.
+        nlu_additional_arguments: Additional training parameters forwarded to training
+                                  method of each NLU component.
 
     Returns:
         Path of the trained model archive.
@@ -158,6 +167,7 @@ async def _train_async_internal(
             output=output_path,
             fixed_model_name=fixed_model_name,
             persist_nlu_training_data=persist_nlu_training_data,
+            additional_arguments=nlu_additional_arguments,
         )
 
     if nlu_data.is_empty():
@@ -166,7 +176,7 @@ async def _train_async_internal(
             file_importer,
             output=output_path,
             fixed_model_name=fixed_model_name,
-            additional_arguments=additional_arguments,
+            additional_arguments=core_additional_arguments,
         )
 
     new_fingerprint = await model.model_fingerprint(file_importer)
@@ -185,7 +195,8 @@ async def _train_async_internal(
             fingerprint_comparison_result=fingerprint_comparison,
             fixed_model_name=fixed_model_name,
             persist_nlu_training_data=persist_nlu_training_data,
-            additional_arguments=additional_arguments,
+            core_additional_arguments=core_additional_arguments,
+            nlu_additional_arguments=nlu_additional_arguments,
         )
 
     return model.package_model(
@@ -209,7 +220,8 @@ async def _do_training(
     fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
-    additional_arguments: Optional[Dict] = None,
+    core_additional_arguments: Optional[Dict] = None,
+    nlu_additional_arguments: Optional[Dict] = None,
 ):
     if not fingerprint_comparison_result:
         fingerprint_comparison_result = FingerprintComparisonResult()
@@ -220,7 +232,7 @@ async def _do_training(
             output=output_path,
             train_path=train_path,
             fixed_model_name=fixed_model_name,
-            additional_arguments=additional_arguments,
+            additional_arguments=core_additional_arguments,
         )
     elif fingerprint_comparison_result.should_retrain_nlg():
         print_color(
@@ -243,6 +255,7 @@ async def _do_training(
             train_path=train_path,
             fixed_model_name=fixed_model_name,
             persist_nlu_training_data=persist_nlu_training_data,
+            additional_arguments=nlu_additional_arguments,
         )
     else:
         print_color(
@@ -383,6 +396,7 @@ def train_nlu(
     train_path: Optional[Text] = None,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
+    additional_arguments: Optional[Dict] = None,
 ) -> Optional[Text]:
     """Trains an NLU model.
 
@@ -395,6 +409,8 @@ def train_nlu(
         fixed_model_name: Name of the model to be stored.
         persist_nlu_training_data: `True` if the NLU training data should be persisted
                                    with the model.
+        additional_arguments: Additional training parameters which will be passed to
+                              the `train` method of each component.
 
     Returns:
@@ -412,6 +428,7 @@ def train_nlu(
             train_path,
             fixed_model_name,
             persist_nlu_training_data,
+            additional_arguments,
         )
     )
 
@@ -423,6 +440,7 @@ async def _train_nlu_async(
     train_path: Optional[Text] = None,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
+    additional_arguments: Optional[Dict] = None,
 ):
     # training NLU only hence the training files still have to be selected
     file_importer = TrainingDataImporter.load_nlu_importer_from_config(
@@ -443,6 +461,7 @@ async def _train_nlu_async(
         train_path=train_path,
         fixed_model_name=fixed_model_name,
         persist_nlu_training_data=persist_nlu_training_data,
+        additional_arguments=additional_arguments,
     )
 
 
@@ -452,11 +471,15 @@ async def _train_nlu_with_validated_data(
     train_path: Optional[Text] = None,
     fixed_model_name: Optional[Text] = None,
     persist_nlu_training_data: bool = False,
+    additional_arguments: Optional[Dict] = None,
 ) -> Optional[Text]:
     """Train NLU with validated training and config data."""
 
     import rasa.nlu.train
 
+    if additional_arguments is None:
+        additional_arguments = {}
+
     with ExitStack() as stack:
         if train_path:
             # If the train path was provided, do nothing on exit.
@@ -472,6 +495,7 @@ async def _train_nlu_with_validated_data(
             _train_path,
             fixed_model_name="nlu",
             persist_nlu_training_data=persist_nlu_training_data,
+            **additional_arguments,
         )
         print_color("NLU model training completed.", color=bcolors.OKBLUE)
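
For reviewers, a minimal usage sketch of how the new parameter travels end to end. It is not part of the diff: the file paths are placeholders, and the leading positional parameters of `train_nlu` (config, NLU data, output directory) are assumed to keep their existing order from the unmodified part of the signature.

```python
# Illustrative sketch only; the three path arguments below are placeholders.
from rasa.train import train_nlu

model_path = train_nlu(
    "config.yml",    # placeholder: NLU pipeline config
    "data/nlu.md",   # placeholder: NLU training data
    "models",        # placeholder: output directory
    additional_arguments={"num_threads": 4},  # forwarded to each component's `train` method
)

# CLI equivalent introduced by this change:
#   rasa train nlu --num-threads 4
# `extract_nlu_additional_arguments` turns the flag into {"num_threads": 4}, and
# `_train_nlu_with_validated_data` splats that dict into `rasa.nlu.train(...)`
# via `**additional_arguments`.
```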