Skip to content

Commit

Permalink
Made use of drop-small-last-batch logic only possible in DIET and con…
Browse files Browse the repository at this point in the history
…figurable
  • Loading branch information
twerkmeister committed Nov 13, 2023
1 parent fa00c49 commit cc26633
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 18 deletions.
5 changes: 5 additions & 0 deletions rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
from rasa.utils.tensorflow.constants import (
DROP_SMALL_LAST_BATCH,
LABEL,
IDS,
HIDDEN_LAYERS_SIZES,
Expand Down Expand Up @@ -288,6 +289,9 @@ def get_default_config() -> Dict[Text, Any]:
# a few steps, as the compilation of the graph tends to take more time than
# running it. It is recommended to not adjust the optimization parameter.
RUN_EAGERLY: False,
# Determines whether the last batch should be dropped if it contains fewer
# than half a batch size of examples
DROP_SMALL_LAST_BATCH: False,
}

def __init__(
Expand Down Expand Up @@ -931,6 +935,7 @@ def train(self, training_data: TrainingData) -> Resource:
self.component_config[BATCH_STRATEGY],
self.component_config[EVAL_NUM_EXAMPLES],
self.component_config[RANDOM_SEED],
drop_small_last_batch=self.component_config[DROP_SMALL_LAST_BATCH],
)
callbacks = train_utils.create_common_callbacks(
self.component_config[EPOCHS],
Expand Down
1 change: 1 addition & 0 deletions rasa/utils/tensorflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@

USE_GPU = "use_gpu"
RUN_EAGERLY = "run_eagerly"
DROP_SMALL_LAST_BATCH = "drop_small_last_batch"
19 changes: 14 additions & 5 deletions rasa/utils/tensorflow/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def __init__(
epochs: int = 1,
batch_strategy: Text = SEQUENCE,
shuffle: bool = True,
drop_small_last_batch: bool = False,
):
"""Initializes the increasing batch size data generator.
Expand All @@ -353,6 +354,8 @@ def __init__(
epochs: The total number of epochs.
batch_strategy: The batch strategy.
shuffle: If 'True', data will be shuffled.
drop_small_last_batch: if 'True', the last batch in an epoch will be dropped
if it has less examples than half the batch size
"""
super().__init__(model_data, batch_size, batch_strategy, shuffle)

Expand All @@ -370,6 +373,7 @@ def __init__(
self._current_batch_size = 0
# create separate data variable that will store modified data for each batch
self._data: Data = {}
self.drop_small_last_batch = drop_small_last_batch
self.on_epoch_end()

def __len__(self) -> int:
Expand All @@ -381,11 +385,16 @@ def __len__(self) -> int:
# data was rebalanced, so need to recalculate number of examples
num_examples = self.model_data.number_of_examples(self._data)
batch_size = self._current_batch_size
# keep last batch only if it has at least half a batch size of examples
last_batch_half_full = num_examples % batch_size >= math.ceil(batch_size / 2)
num_batches = num_examples // batch_size + int(last_batch_half_full)
# Return at least 1 if there is an example
return max(num_batches, int(num_examples > 0))
if self.drop_small_last_batch:
# keep last batch only if it has at least half a batch size of examples
last_batch_half_full = num_examples % batch_size >= math.ceil(
batch_size / 2
)
num_batches = num_examples // batch_size + int(last_batch_half_full)
# Return at least 1 if there is an example
return max(num_batches, int(num_examples > 0))
else:
return num_examples // batch_size + int(num_examples % batch_size > 0)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Gets batch at position `index`.
Expand Down
5 changes: 5 additions & 0 deletions rasa/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def create_data_generators(
eval_num_examples: int = 0,
random_seed: Optional[int] = None,
shuffle: bool = True,
drop_small_last_batch: bool = False,
) -> Tuple[RasaBatchDataGenerator, Optional[RasaBatchDataGenerator]]:
"""Create data generators for train and optional validation data.
Expand All @@ -313,6 +314,8 @@ def create_data_generators(
eval_num_examples: Number of examples to use for validation data.
random_seed: The random seed.
shuffle: Whether to shuffle data inside the data generator.
drop_small_last_batch: whether to drop the last batch if it has fewer than half
a batch size of examples
Returns:
The training data generator and optional validation data generator.
Expand All @@ -328,6 +331,7 @@ def create_data_generators(
epochs=epochs,
batch_strategy=batch_strategy,
shuffle=shuffle,
drop_small_last_batch=drop_small_last_batch,
)

data_generator = RasaBatchDataGenerator(
Expand All @@ -336,6 +340,7 @@ def create_data_generators(
epochs=epochs,
batch_strategy=batch_strategy,
shuffle=shuffle,
drop_small_last_batch=drop_small_last_batch,
)

return data_generator, validation_data_generator
Expand Down
41 changes: 28 additions & 13 deletions tests/nlu/classifiers/test_diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,24 +971,35 @@ async def test_no_bilou_when_entity_recognition_off(

@pytest.mark.timeout(120, func_only=True)
@pytest.mark.parametrize(
"batch_size, expected_num_batches",
"batch_size, expected_num_batches, drop_small_last_batch",
# the training dataset has 48 NLU examples
[
(1, 48),
(8, 6),
(15, 3),
(16, 3),
(18, 3),
(20, 2),
(32, 2),
(64, 1),
(128, 1),
(256, 1),
(1, 48, True),
(8, 6, True),
(15, 3, True),
(16, 3, True),
(18, 3, True),
(20, 2, True),
(32, 2, True),
(64, 1, True),
(128, 1, True),
(256, 1, True),
(1, 48, False),
(8, 6, False),
(15, 4, False),
(16, 3, False),
(18, 3, False),
(20, 3, False),
(32, 2, False),
(64, 1, False),
(128, 1, False),
(256, 1, False),
],
)
async def test_dropping_of_last_partial_batch(
batch_size: int,
expected_num_batches: int,
drop_small_last_batch: bool,
create_diet: Callable[..., DIETClassifier],
train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
):
Expand All @@ -1012,7 +1023,9 @@ async def test_dropping_of_last_partial_batch(
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)
data_generator, _ = train_utils.create_data_generators(
model_data, batch_size, 1, drop_small_last_batch=drop_small_last_batch
)

assert len(data_generator) == expected_num_batches

Expand Down Expand Up @@ -1041,6 +1054,8 @@ async def test_dropping_of_last_partial_batch_empty_data(
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, 64, 1)
data_generator, _ = train_utils.create_data_generators(
model_data, 64, 1, drop_small_last_batch=True
)

assert len(data_generator) == 0

0 comments on commit cc26633

Please sign in to comment.