-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
Copy pathmodels.py
934 lines (788 loc) · 35.2 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
import time
import random
import tensorflow as tf
import numpy as np
import logging
import os
from collections import defaultdict
from typing import List, Text, Dict, Tuple, Union, Optional, Any, TYPE_CHECKING
from keras.utils import tf_utils
from keras import Model
from rasa.shared.constants import DIAGNOSTIC_DATA
from rasa.utils.tensorflow.constants import (
LABEL,
IDS,
INTENT_CLASSIFICATION,
SENTENCE,
SEQUENCE_LENGTH,
RANDOM_SEED,
EMBEDDING_DIMENSION,
REGULARIZATION_CONSTANT,
SIMILARITY_TYPE,
CONNECTION_DENSITY,
NUM_NEG,
LOSS_TYPE,
MAX_POS_SIM,
MAX_NEG_SIM,
USE_MAX_NEG_SIM,
NEGATIVE_MARGIN_SCALE,
SCALE_LOSS,
LEARNING_RATE,
CONSTRAIN_SIMILARITIES,
MODEL_CONFIDENCE,
RUN_EAGERLY,
)
from rasa.utils.tensorflow.model_data import (
RasaModelData,
FeatureSignature,
FeatureArray,
)
import rasa.utils.train_utils
from rasa.utils.tensorflow import layers
from rasa.utils.tensorflow import rasa_layers
from rasa.utils.tensorflow.data_generator import (
RasaDataGenerator,
RasaBatchDataGenerator,
)
from rasa.shared.nlu.constants import TEXT
from rasa.shared.exceptions import RasaException
from rasa.utils.tensorflow.types import BatchData, MaybeNestedBatchData
if TYPE_CHECKING:
from tensorflow.python.types.core import GenericFunction
# Module-level logger.
logger = logging.getLogger(__name__)
# Label key / sub-key handed to `RasaModelData` when intent classification is on;
# plain aliases of the shared `LABEL` and `IDS` constants.
LABEL_KEY = LABEL
LABEL_SUB_KEY = IDS
# noinspection PyMethodOverriding
class RasaModel(Model):
    """Abstract custom Keras model.

    This model overwrites the following methods:
    - train_step
    - test_step
    - predict_step
    - save
    - load

    Subclasses must implement `batch_loss` and `batch_predict` (and may
    override `prepare_for_predict`).

    Cannot be used as tf.keras.Model.
    """

    # Phase flag read by subclasses while building the graph: `True` inside
    # `train_step`, `False` inside `test_step` / `predict_step` /
    # `_rasa_predict`. `train_step` and `test_step` reset it to `None` when done.
    _training: Optional[bool]

    def __init__(self, random_seed: Optional[int] = None, **kwargs: Any) -> None:
        """Initialize the RasaModel.

        Args:
            random_seed: set the random seed to get reproducible results. When
                `None`, the current Unix time (whole seconds) is used instead.
            **kwargs: forwarded unchanged to the Keras `Model` constructor.
        """
        # make sure that keras releases resources from previously trained model
        tf.keras.backend.clear_session()
        super().__init__(**kwargs)

        self.total_loss = tf.keras.metrics.Mean(name="t_loss")
        self.metrics_to_log = ["t_loss"]

        self._training = None  # training phase should be defined when building a graph

        if random_seed is None:
            random_seed = int(time.time())
        self.random_seed = random_seed
        self._set_random_seed()

        # lazily created `tf.function` wrapping `predict_step` (see `_rasa_predict`)
        self._tf_predict_step: Optional["GenericFunction"] = None
        self.prepared_for_prediction = False

        self._checkpoint = tf.train.Checkpoint(model=self)

    def _set_random_seed(self) -> None:
        """Seeds Python's `random`, NumPy, TensorFlow and Keras with `random_seed`."""
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        tf.random.set_seed(self.random_seed)
        tf.experimental.numpy.random.seed(self.random_seed)
        tf.keras.utils.set_random_seed(self.random_seed)
        # Set a fixed value for the hash seed. Note that changing this variable at
        # runtime only affects Python processes started afterwards, not the hash
        # randomization of the interpreter that is already running.
        os.environ["PYTHONHASHSEED"] = str(self.random_seed)

    def batch_loss(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> tf.Tensor:
        """Calculates the loss for the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The loss of the given batch.
        """
        raise NotImplementedError

    def prepare_for_predict(self) -> None:
        """Prepares tf graph for prediction.

        This method should contain necessary tf calculations
        and set self variables that are used in `batch_predict`.
        For example, pre calculation of `self.all_labels_embed`.
        """
        pass

    def batch_predict(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
        """Predicts the output of the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The output to predict.
        """
        raise NotImplementedError

    def train_step(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, float]:
        """Performs a train step using the given batch.

        Supervision (prediction) and regularization losses are differentiated
        separately so that the regularization gradient can be dropped for every
        weight entry that received no gradient from the prediction loss.

        Args:
            batch_in: The batch input.

        Returns:
            Training metrics.
        """
        self._training = True

        # calculate supervision and regularization losses separately
        with tf.GradientTape(persistent=True) as tape:
            prediction_loss = self.batch_loss(batch_in)
            regularization_loss = tf.math.add_n(self.losses)
            total_loss = prediction_loss + regularization_loss

        self.total_loss.update_state(total_loss)

        # calculate the gradients that come from supervision signal
        prediction_gradients = tape.gradient(prediction_loss, self.trainable_variables)
        # calculate the gradients that come from regularization
        regularization_gradients = tape.gradient(
            regularization_loss, self.trainable_variables
        )
        # delete gradient tape manually
        # since it was created with `persistent=True` option
        del tape

        gradients = []
        for pred_grad, reg_grad in zip(prediction_gradients, regularization_gradients):
            if pred_grad is not None and reg_grad is not None:
                # Remove the regularization gradient only where the prediction
                # gradient is exactly zero (i.e. the entry was not touched by the
                # supervision signal). Negative prediction gradients are real
                # updates and must keep their regularization term.
                gradients.append(
                    pred_grad
                    + tf.where(
                        tf.not_equal(pred_grad, 0), reg_grad, tf.zeros_like(reg_grad)
                    )
                )
            else:
                gradients.append(pred_grad)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self._training = None

        return self._get_metric_results()

    def test_step(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, float]:
        """Tests the model using the given batch.

        This method is used during validation.

        Args:
            batch_in: The batch input.

        Returns:
            Testing metrics.
        """
        self._training = False

        prediction_loss = self.batch_loss(batch_in)
        regularization_loss = tf.math.add_n(self.losses)
        total_loss = prediction_loss + regularization_loss
        self.total_loss.update_state(total_loss)

        self._training = None

        return self._get_metric_results()

    def predict_step(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, tf.Tensor]:
        """Predicts the output for the given batch.

        Args:
            batch_in: The batch to predict.

        Returns:
            Prediction output.
        """
        self._training = False

        if not self.prepared_for_prediction:
            # in case the model is used for prediction without loading, e.g. directly
            # after training, we need to prepare the model for prediction once
            self.prepare_for_predict()
            self.prepared_for_prediction = True

        return self.batch_predict(batch_in)

    @staticmethod
    def _dynamic_signature(
        batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> List[List[tf.TensorSpec]]:
        """Builds a `tf.function` input signature for `batch_in`.

        Every dimension except the last is left unknown (`None`) so the traced
        function can be reused for batches of different sizes / lengths; 1-D
        inputs get a fully unknown shape `[None]`.
        """
        element_spec = []
        for tensor in batch_in:
            if len(tensor.shape) > 1:
                shape: List[Union[None, int]] = [None] * (len(tensor.shape) - 1)
                shape += [tensor.shape[-1]]
            else:
                shape = [None]
            element_spec.append(tf.TensorSpec(shape, tensor.dtype))
        # batch_in is a list of tensors, therefore we need to wrap element_spec into
        # the list
        return [element_spec]

    def _rasa_predict(
        self, batch_in: Tuple[np.ndarray, ...]
    ) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
        """Custom prediction method that builds tf graph on the first call.

        Args:
            batch_in: Prepared batch ready for input to `predict_step` method of model.

        Return:
            Prediction output, including diagnostic data.
        """
        self._training = False
        if not self.prepared_for_prediction:
            # in case the model is used for prediction without loading, e.g. directly
            # after training, we need to prepare the model for prediction once
            self.prepare_for_predict()
            self.prepared_for_prediction = True

        if self._run_eagerly:
            # Once we take advantage of TF's distributed training, this is where
            # scheduled functions will be forced to execute and return actual values.
            outputs = tf_utils.sync_to_numpy_or_python_type(self.predict_step(batch_in))
            if DIAGNOSTIC_DATA in outputs:
                outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
                    outputs[DIAGNOSTIC_DATA]
                )
            return outputs

        # trace `predict_step` once with a shape-relaxed signature and reuse it
        if self._tf_predict_step is None:
            self._tf_predict_step = tf.function(
                self.predict_step, input_signature=self._dynamic_signature(batch_in)
            )

        # Once we take advantage of TF's distributed training, this is where
        # scheduled functions will be forced to execute and return actual values.
        outputs = tf_utils.sync_to_numpy_or_python_type(self._tf_predict_step(batch_in))
        if DIAGNOSTIC_DATA in outputs:
            outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
                outputs[DIAGNOSTIC_DATA]
            )
        return outputs

    def run_inference(
        self,
        model_data: RasaModelData,
        batch_size: Union[int, List[int]] = 1,
        output_keys_expected: Optional[List[Text]] = None,
    ) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
        """Implements bulk inferencing through the model.

        Args:
            model_data: Input data to be fed to the model.
            batch_size: Size of batches that the generator should create.
            output_keys_expected: Keys which are expected in the output.
                The output should be filtered to have only these keys before
                merging it with the output across all batches.

        Returns:
            Model outputs corresponding to the inputs fed.
        """
        outputs: Dict[Text, Union[np.ndarray, Dict[Text, Any]]] = {}
        (data_generator, _) = rasa.utils.train_utils.create_data_generators(
            model_data=model_data, batch_sizes=batch_size, epochs=1, shuffle=False
        )
        data_iterator = iter(data_generator)
        while True:
            # Only `next()` may legitimately signal the end of the data; keeping
            # the `try` this narrow stops a stray `StopIteration` raised during
            # prediction or merging from silently truncating the results.
            try:
                # data_generator is a tuple of 2 elements - input and output.
                # We only need input, since output is always None and not
                # consumed by our TF graphs.
                batch_in = next(data_iterator)[0]
            except StopIteration:
                # Generator ran out of batches, time to finish inferencing
                break

            batch_out: Dict[
                Text, Union[np.ndarray, Dict[Text, Any]]
            ] = self._rasa_predict(batch_in)

            if output_keys_expected:
                batch_out = {
                    key: output
                    for key, output in batch_out.items()
                    if key in output_keys_expected
                }

            outputs = self._merge_batch_outputs(outputs, batch_out)
        return outputs

    @staticmethod
    def _merge_batch_outputs(
        all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, Any]]],
        batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    ) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
        """Merges a batch's output into the output for all batches.

        Function assumes that the schema of batch output remains the same,
        i.e. keys and their value types do not change from one batch's
        output to another. Values that are neither arrays nor dicts are kept
        from the first batch.

        Args:
            all_outputs: Existing output for all previous batches.
                Mutated in place when non-empty.
            batch_output: Output for a batch.

        Returns:
            Merged output with the output for current batch stacked
            below the output for all previous batches.
        """
        if not all_outputs:
            return batch_output

        for key, val in batch_output.items():
            if isinstance(val, np.ndarray):
                all_outputs[key] = np.concatenate(
                    [all_outputs[key], batch_output[key]], axis=0
                )
            elif isinstance(val, dict):
                # recurse and merge the inner dict first
                all_outputs[key] = RasaModel._merge_batch_outputs(all_outputs[key], val)

        return all_outputs

    @staticmethod
    def _empty_lists_to_none_in_dict(input_dict: Dict[Text, Any]) -> Dict[Text, Any]:
        """Recursively replaces empty list or np array with None in a dictionary."""

        def _recurse(
            x: Union[Dict[Text, Any], List[Any], np.ndarray]
        ) -> Optional[Union[Dict[Text, Any], List[Any], np.ndarray]]:
            if isinstance(x, dict):
                return {k: _recurse(v) for k, v in x.items()}
            elif (isinstance(x, list) or isinstance(x, np.ndarray)) and np.size(x) == 0:
                return None
            return x

        return {k: _recurse(v) for k, v in input_dict.items()}

    def _get_metric_results(self, prefix: Optional[Text] = "") -> Dict[Text, float]:
        """Returns current results of the metrics listed in `metrics_to_log`.

        Args:
            prefix: String prepended to every metric name in the returned keys.
        """
        return {
            f"{prefix}{metric.name}": metric.result()
            for metric in self.metrics
            if metric.name in self.metrics_to_log
        }

    def save(self, model_file_name: Text, overwrite: bool = True) -> None:
        """Save the model to the given file.

        Only the weights are saved (TensorFlow checkpoint format); use `load`
        to rebuild the architecture and restore them.

        Args:
            model_file_name: The file name to save the model to.
            overwrite: If 'True' an already existing model with the same file name will
                be overwritten.
        """
        self.save_weights(model_file_name, overwrite=overwrite, save_format="tf")

    @classmethod
    def load(
        cls,
        model_file_name: Text,
        model_data_example: RasaModelData,
        predict_data_example: Optional[RasaModelData] = None,
        finetune_mode: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> "RasaModel":
        """Loads a model from the given weights.

        Args:
            model_file_name: Path to file containing model weights.
            model_data_example: Example data point to construct the model architecture.
            predict_data_example: Example data point to speed up prediction during
                inference.
            finetune_mode: Indicates whether to load the model for further finetuning.
            *args: Any other non key-worded arguments.
            **kwargs: Any other key-worded arguments. A `config` entry, if present,
                supplies `LEARNING_RATE` (default 0.001) and `RUN_EAGERLY`.

        Returns:
            Loaded model with weights appropriately set.
        """
        logger.debug(
            f"Loading the model from {model_file_name} "
            f"with finetune_mode={finetune_mode}..."
        )
        # create empty model
        model = cls(*args, **kwargs)
        learning_rate = kwargs.get("config", {}).get(LEARNING_RATE, 0.001)
        run_eagerly = kwargs.get("config", {}).get(RUN_EAGERLY)
        # need to train on 1 example to build weights of the correct size
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate), run_eagerly=run_eagerly
        )
        data_generator = RasaBatchDataGenerator(model_data_example, batch_size=1)
        model.fit(data_generator, verbose=False)
        # load trained weights
        model.load_weights(model_file_name)

        # predict on one data example to speed up prediction during inference
        # the first prediction always takes a bit longer to trace tf function
        if not finetune_mode and predict_data_example:
            model.run_inference(predict_data_example)

        logger.debug("Finished loading the model.")
        return model

    @staticmethod
    def batch_to_model_data_format(
        batch: MaybeNestedBatchData,
        data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
    ) -> Dict[Text, Dict[Text, List[tf.Tensor]]]:
        """Convert input batch tensors into batch data format.

        Batch contains any number of batch data. The order is equal to the
        key-value pairs in session data. As sparse data were converted into (indices,
        data, shape) before, this method converts them into sparse tensors. Dense
        data is kept.

        Args:
            batch: Flat batch, or an (input, target) tuple whose first element is
                the flat batch.
            data_signature: Signature describing, per key and sub-key, whether
                each feature is sparse, its feature dimension and its rank.

        Returns:
            A nested `defaultdict` mapping key -> sub-key -> list of tensors.
        """
        # during training batch is a tuple of input and target data
        # as our target data is inside the input data, we are just interested in the
        # input data
        unpacked_batch = batch[0] if isinstance(batch[0], tuple) else batch

        batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]] = defaultdict(
            lambda: defaultdict(list)
        )

        # `idx` walks the flat batch; sparse features consume three entries
        idx = 0
        for key, values in data_signature.items():
            for sub_key, signature in values.items():
                for is_sparse, feature_dimension, number_of_dimensions in signature:
                    # we converted all 4D features to 3D features before
                    number_of_dimensions = (
                        number_of_dimensions if number_of_dimensions != 4 else 3
                    )
                    if is_sparse:
                        tensor, idx = RasaModel._convert_sparse_features(
                            unpacked_batch, feature_dimension, idx, number_of_dimensions
                        )
                    else:
                        tensor, idx = RasaModel._convert_dense_features(
                            unpacked_batch, feature_dimension, idx, number_of_dimensions
                        )
                    batch_data[key][sub_key].append(tensor)

        return batch_data

    @staticmethod
    def _convert_dense_features(
        batch: BatchData,
        feature_dimension: int,
        idx: int,
        number_of_dimensions: int,
    ) -> Tuple[tf.Tensor, int]:
        """Returns the dense feature at `batch[idx]` as a tensor plus the next index.

        An existing `tf.Tensor` gets its last dimension pinned to
        `feature_dimension` when that dimension is statically unknown; any other
        value is converted with `tf.constant(..., dtype=tf.float32)`.
        """
        batch_at_idx = batch[idx]
        if isinstance(batch_at_idx, tf.Tensor):
            # explicitly substitute last dimension in shape with known
            # static value
            if number_of_dimensions > 1 and (
                batch_at_idx.shape is None or batch_at_idx.shape[-1] is None
            ):
                shape: List[Optional[int]] = [None] * (number_of_dimensions - 1)
                shape.append(feature_dimension)
                batch_at_idx.set_shape(shape)

            return batch_at_idx, idx + 1

        # convert to Tensor
        return (
            tf.constant(batch[idx], dtype=tf.float32, shape=batch[idx].shape),
            idx + 1,
        )

    @staticmethod
    def _convert_sparse_features(
        batch: BatchData,
        feature_dimension: int,
        idx: int,
        number_of_dimensions: int,
    ) -> Tuple[tf.SparseTensor, int]:
        """Builds a `tf.SparseTensor` from three consecutive batch entries.

        `batch[idx]`, `batch[idx + 1]` and `batch[idx + 2]` are used as the
        indices, values and dense shape respectively; the returned index is
        `idx + 3`.
        """
        # explicitly substitute last dimension in shape with known
        # static value
        shape = [batch[idx + 2][i] for i in range(number_of_dimensions - 1)] + [
            feature_dimension
        ]
        return tf.SparseTensor(batch[idx], batch[idx + 1], shape), idx + 3

    def call(
        self,
        inputs: Union[tf.Tensor, List[tf.Tensor]],
        training: Optional[tf.Tensor] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> Union[tf.Tensor, List[tf.Tensor]]:
        """Calls the model on new inputs.

        Arguments:
            inputs: A tensor or list of tensors.
            training: Boolean or boolean scalar tensor, indicating whether to run
              the `Network` in training mode or inference mode.
            mask: A mask or list of masks. A mask can be
                either a tensor or None (no mask).

        Returns:
            A tensor if there is a single output, or
            a list of tensors if there are more than one outputs.
        """
        # This method needs to be implemented, otherwise the super class is raising a
        # NotImplementedError('When subclassing the `Model` class, you should
        #   implement a `call` method.')
        pass
# noinspection PyMethodOverriding
class TransformerRasaModel(RasaModel):
    """Base class for transformer-based Rasa models.

    Provides helpers to build the embedding, feed-forward, dot-product loss and
    CRF layers (stored in `self._tf_layers`) and to adjust sparse layers for
    incremental training. Subclasses implement `_check_data`, `_prepare_layers`,
    `batch_loss` and `batch_predict`.
    """

    def __init__(
        self,
        name: Text,
        config: Dict[Text, Any],
        data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
        label_data: RasaModelData,
    ) -> None:
        """Initializes the model.

        Args:
            name: Name of the Keras model.
            config: Model configuration; `config[RANDOM_SEED]` seeds the model.
            data_signature: Signature of the training data.
            label_data: Label data; it is converted once into tensors and stored
                as `self.tf_label_data`.
        """
        super().__init__(name=name, random_seed=config[RANDOM_SEED])

        self.config = config
        self.data_signature = data_signature
        self.label_signature = label_data.get_signature()
        self._check_data()

        label_batch = RasaDataGenerator.prepare_batch(label_data.data)
        self.tf_label_data = self.batch_to_model_data_format(
            label_batch, self.label_signature
        )

        # set up tf layers
        self._tf_layers: Dict[Text, tf.keras.layers.Layer] = {}

    def adjust_for_incremental_training(
        self,
        data_example: Dict[Text, Dict[Text, List[FeatureArray]]],
        new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
        old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    ) -> None:
        """Adjusts the model for incremental training.

        First we should check if any of the sparse feature sizes has decreased
        and raise an exception if this happens.
        If none of them have decreased and any of them has increased, then the
        function updates `DenseForSparse` layers, compiles the model, fits a sample
        data on it to activate adjusted layer(s) and updates the data signatures.

        New and old sparse feature sizes could look like this:
        {TEXT: {FEATURE_TYPE_SEQUENCE: [4, 24, 128], FEATURE_TYPE_SENTENCE: [4, 128]}}

        Args:
            data_example: a data example that is stored with the ML component.
            new_sparse_feature_sizes: sizes of current sparse features.
            old_sparse_feature_sizes: sizes of sparse features the model was
                                      previously trained on.

        Raises:
            RasaException: When any of the sparse feature sizes decreased.
        """
        self._check_if_sparse_feature_sizes_decreased(
            new_sparse_feature_sizes=new_sparse_feature_sizes,
            old_sparse_feature_sizes=old_sparse_feature_sizes,
        )
        if self._sparse_feature_sizes_have_increased(
            new_sparse_feature_sizes=new_sparse_feature_sizes,
            old_sparse_feature_sizes=old_sparse_feature_sizes,
        ):
            self._update_dense_for_sparse_layers(
                new_sparse_feature_sizes, old_sparse_feature_sizes
            )
            self._compile_and_fit(data_example)

    @staticmethod
    def _check_if_sparse_feature_sizes_decreased(
        new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
        old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    ) -> None:
        """Checks if the sizes of sparse features have decreased during fine-tuning.

        Sparse feature sizes might decrease after changing the training data.
        This can happen for example with `LexicalSyntacticFeaturizer`.
        We don't support this behaviour and we raise an exception if this happens.

        Args:
            new_sparse_feature_sizes: sizes of current sparse features.
            old_sparse_feature_sizes: sizes of sparse features the model was
                                      previously trained on.

        Raises:
            RasaException: When any of the sparse feature sizes decrease
                           from the last time training was run.
        """
        for attribute, new_feature_sizes in new_sparse_feature_sizes.items():
            old_feature_sizes = old_sparse_feature_sizes[attribute]
            for feature_type, new_sizes in new_feature_sizes.items():
                old_sizes = old_feature_sizes[feature_type]
                for new_size, old_size in zip(new_sizes, old_sizes):
                    if new_size < old_size:
                        raise RasaException(
                            "Sparse feature sizes have decreased from the last time "
                            "training was run. The training data was changed in a way "
                            "that resulted in some features not being present in the "
                            "data anymore. This can happen if you had "
                            "`LexicalSyntacticFeaturizer` in your pipeline. "
                            "The pipeline cannot support incremental training "
                            "in this setting. We recommend you to retrain "
                            "the model from scratch."
                        )

    @staticmethod
    def _sparse_feature_sizes_have_increased(
        new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
        old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    ) -> bool:
        """Checks if the sizes of sparse features have increased during fine-tuning.

        If there's any sparse feature size that has increased after changing the
        training data, we need to look for the corresponding `DenseForSparse` layer
        and adjust it. On the other hand, if none of them have increased, we don't
        need to change anything. This function helps us with making the decision.

        Sizes are compared position by position (the same pairing used by
        `_check_if_sparse_feature_sizes_decreased`). Note that the function
        assumes that none of the sparse feature sizes have decreased. In other
        words, it should get valid arguments in order to function well.

        Args:
            new_sparse_feature_sizes: sizes of current sparse features.
            old_sparse_feature_sizes: sizes of sparse features the model was
                                      previously trained on.

        Returns:
            `True` if any of the sparse feature sizes has increased, `False` otherwise.
        """
        for attribute, new_feature_sizes in new_sparse_feature_sizes.items():
            old_feature_sizes = old_sparse_feature_sizes[attribute]
            for feature_type, new_sizes in new_feature_sizes.items():
                old_sizes = old_feature_sizes[feature_type]
                # element-wise: a single grown feature is enough
                if any(
                    new_size > old_size
                    for new_size, old_size in zip(new_sizes, old_sizes)
                ):
                    return True
        return False

    def _update_dense_for_sparse_layers(
        self,
        new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
        old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    ) -> None:
        """Updates `DenseForSparse` layers.

        Updates sizes of `DenseForSparse` layers by comparing current sparse feature
        sizes to old ones. This must be done before fine-tuning starts to account
        for any change in the size of sparse features that might have happened
        because of addition of new data.

        Args:
            new_sparse_feature_sizes: sizes of current sparse features.
            old_sparse_feature_sizes: sizes of sparse features the model was
                                      previously trained on.
        """
        for layer in self._tf_layers.values():
            # `if` condition is necessary because only `RasaCustomLayer`
            # can adjust sparse layers for incremental training by default.
            if isinstance(layer, rasa_layers.RasaCustomLayer):
                layer.adjust_sparse_layers_for_incremental_training(
                    new_sparse_feature_sizes,
                    old_sparse_feature_sizes,
                    self.config[REGULARIZATION_CONSTANT],
                )

    def _compile_and_fit(
        self, data_example: Dict[Text, Dict[Text, List[FeatureArray]]]
    ) -> None:
        """Compiles modified model and fits a sample data on it.

        Args:
            data_example: a data example that is stored with the ML component.
        """
        self.compile(
            optimizer=tf.keras.optimizers.Adam(self.config[LEARNING_RATE]),
            run_eagerly=self.config[RUN_EAGERLY],
        )
        # label keys are only set when the model classifies intents
        label_key = LABEL_KEY if self.config[INTENT_CLASSIFICATION] else None
        label_sub_key = LABEL_SUB_KEY if self.config[INTENT_CLASSIFICATION] else None
        model_data = RasaModelData(
            label_key=label_key, label_sub_key=label_sub_key, data=data_example
        )
        self._update_data_signatures(model_data)
        data_generator = RasaBatchDataGenerator(model_data, batch_size=1)
        self.fit(data_generator, verbose=False)

    def _update_data_signatures(self, model_data: RasaModelData) -> None:
        """Refreshes `data_signature` and `predict_data_signature` from `model_data`.

        `predict_data_signature` keeps only the entries whose key contains `TEXT`.
        """
        self.data_signature = model_data.get_signature()
        self.predict_data_signature = {
            feature_name: features
            for feature_name, features in self.data_signature.items()
            if TEXT in feature_name
        }

    def _check_data(self) -> None:
        """Validates `self.data_signature`; must be implemented by subclasses."""
        raise NotImplementedError

    def _prepare_layers(self) -> None:
        """Creates the model's layers; must be implemented by subclasses."""
        raise NotImplementedError

    def _prepare_label_classification_layers(self, predictor_attribute: Text) -> None:
        """Prepares layers & loss for the final label prediction step."""
        self._prepare_embed_layers(predictor_attribute)
        self._prepare_embed_layers(LABEL)
        self._prepare_dot_product_loss(LABEL, self.config[SCALE_LOSS])

    def _prepare_embed_layers(self, name: Text, prefix: Text = "embed") -> None:
        """Registers an `Embed` layer under the key `f"{prefix}.{name}"`."""
        self._tf_layers[f"{prefix}.{name}"] = layers.Embed(
            self.config[EMBEDDING_DIMENSION], self.config[REGULARIZATION_CONSTANT], name
        )

    def _prepare_ffnn_layer(
        self,
        name: Text,
        layer_sizes: List[int],
        drop_rate: float,
        prefix: Text = "ffnn",
    ) -> None:
        """Registers an `Ffnn` layer under the key `f"{prefix}.{name}"`."""
        self._tf_layers[f"{prefix}.{name}"] = layers.Ffnn(
            layer_sizes,
            drop_rate,
            self.config[REGULARIZATION_CONSTANT],
            self.config[CONNECTION_DENSITY],
            layer_name_suffix=name,
        )

    def _prepare_dot_product_loss(
        self, name: Text, scale_loss: bool, prefix: Text = "loss"
    ) -> None:
        """Registers a `dot_product_loss_layer` under `f"{prefix}.{name}"`.

        All loss hyper-parameters except `scale_loss` are read from `self.config`.
        """
        self._tf_layers[f"{prefix}.{name}"] = self.dot_product_loss_layer(
            self.config[NUM_NEG],
            loss_type=self.config[LOSS_TYPE],
            mu_pos=self.config[MAX_POS_SIM],
            mu_neg=self.config[MAX_NEG_SIM],
            use_max_sim_neg=self.config[USE_MAX_NEG_SIM],
            neg_lambda=self.config[NEGATIVE_MARGIN_SCALE],
            scale_loss=scale_loss,
            similarity_type=self.config[SIMILARITY_TYPE],
            constrain_similarities=self.config[CONSTRAIN_SIMILARITIES],
            model_confidence=self.config[MODEL_CONFIDENCE],
        )

    @property
    def dot_product_loss_layer(self) -> tf.keras.layers.Layer:
        """Returns the dot-product loss layer to use.

        Returns:
            The loss layer that is used by `_prepare_dot_product_loss`.
        """
        return layers.SingleLabelDotProductLoss

    def _prepare_entity_recognition_layers(self) -> None:
        """Registers logits-embedding, CRF and tag-embedding layers per entity tag.

        Relies on `self._entity_tag_specs`, which is expected to be set by the
        subclass before this is called.
        """
        for tag_spec in self._entity_tag_specs:
            name = tag_spec.tag_name
            num_tags = tag_spec.num_tags
            self._tf_layers[f"embed.{name}.logits"] = layers.Embed(
                num_tags, self.config[REGULARIZATION_CONSTANT], f"logits.{name}"
            )
            self._tf_layers[f"crf.{name}"] = layers.CRF(
                num_tags, self.config[REGULARIZATION_CONSTANT], self.config[SCALE_LOSS]
            )
            self._tf_layers[f"embed.{name}.tags"] = layers.Embed(
                self.config[EMBEDDING_DIMENSION],
                self.config[REGULARIZATION_CONSTANT],
                f"tags.{name}",
            )

    @staticmethod
    def _last_token(x: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor:
        """Gathers `x[b, sequence_lengths[b] - 1]` for every batch row `b`.

        The index is clamped at 0, so rows with length 0 return position 0.
        """
        last_sequence_index = tf.maximum(0, sequence_lengths - 1)
        batch_index = tf.range(tf.shape(last_sequence_index)[0])

        indices = tf.stack([batch_index, last_sequence_index], axis=1)
        return tf.gather_nd(x, indices)

    def _get_mask_for(
        self,
        tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
        key: Text,
        sub_key: Text,
    ) -> Optional[tf.Tensor]:
        """Computes a mask from the lengths stored at `tf_batch_data[key][sub_key]`.

        Returns `None` if the key or sub-key is missing.
        """
        if key not in tf_batch_data or sub_key not in tf_batch_data[key]:
            return None

        sequence_lengths = tf.cast(tf_batch_data[key][sub_key][0], dtype=tf.int32)
        return rasa_layers.compute_mask(sequence_lengths)

    def _get_sequence_feature_lengths(
        self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], key: Text
    ) -> tf.Tensor:
        """Fetches the sequence lengths of real tokens per input example.

        The number of real tokens for an example is the same as the length of the
        sequence of the sequence-level (token-level) features for that input example.
        Returns zeros when no `SEQUENCE_LENGTH` entry is present for `key`.
        """
        if key in tf_batch_data and SEQUENCE_LENGTH in tf_batch_data[key]:
            return tf.cast(tf_batch_data[key][SEQUENCE_LENGTH][0], dtype=tf.int32)

        batch_dim = self._get_batch_dim(tf_batch_data[key])
        return tf.zeros([batch_dim], dtype=tf.int32)

    def _get_sentence_feature_lengths(
        self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], key: Text
    ) -> tf.Tensor:
        """Fetches the sequence lengths of sentence-level features per input example.

        This is needed because we treat sentence-level features as token-level features
        with 1 token per input example. Hence, the sequence lengths returned by this
        function are all 1s if sentence-level features are present, and 0s otherwise.
        """
        batch_dim = self._get_batch_dim(tf_batch_data[key])

        if key in tf_batch_data and SENTENCE in tf_batch_data[key]:
            return tf.ones([batch_dim], dtype=tf.int32)

        return tf.zeros([batch_dim], dtype=tf.int32)

    @staticmethod
    def _get_batch_dim(
        attribute_data: Dict[Text, List[tf.Tensor]]
    ) -> Union[int, tf.Tensor]:
        """Infers the batch size from the first non-empty tensor list.

        Returns a scalar tensor (from `tf.shape`) when data is present, else 0.
        """
        # All the values in the attribute_data dict should be lists of tensors, each
        # tensor of the shape (batch_dim, ...). So we take the first non-empty list we
        # encounter and infer the batch size from its first tensor.
        for data in attribute_data.values():
            if data:
                return tf.shape(data[0])[0]
        return 0

    def _calculate_entity_loss(
        self,
        inputs: tf.Tensor,
        tag_ids: tf.Tensor,
        mask: tf.Tensor,
        sequence_lengths: tf.Tensor,
        tag_name: Text,
        entity_tags: Optional[tf.Tensor] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Computes the CRF loss, F1 score and logits for one entity tag.

        Args:
            inputs: Input features; when `entity_tags` is given, their embedding is
                concatenated on the last axis first.
            tag_ids: Gold tag ids; only `tag_ids[:, :, 0]` is used.
            mask: Mask passed to the CRF F1 computation.
            sequence_lengths: Sequence lengths passed to the CRF layer.
            tag_name: Name of the entity tag whose layers are used.
            entity_tags: Optional tags of a previous tagging step.

        Returns:
            Tuple of (loss, f1, logits).
        """
        tag_ids = tf.cast(tag_ids[:, :, 0], tf.int32)

        if entity_tags is not None:
            _tags = self._tf_layers[f"embed.{tag_name}.tags"](entity_tags)
            inputs = tf.concat([inputs, _tags], axis=-1)

        logits = self._tf_layers[f"embed.{tag_name}.logits"](inputs)

        # should call first to build weights
        pred_ids, _ = self._tf_layers[f"crf.{tag_name}"](logits, sequence_lengths)
        loss = self._tf_layers[f"crf.{tag_name}"].loss(
            logits, tag_ids, sequence_lengths
        )
        f1 = self._tf_layers[f"crf.{tag_name}"].f1_score(tag_ids, pred_ids, mask)

        return loss, f1, logits

    def batch_loss(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> tf.Tensor:
        """Calculates the loss for the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The loss of the given batch.
        """
        raise NotImplementedError

    def batch_predict(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
        """Predicts the output of the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The output to predict.
        """
        raise NotImplementedError