-
Hello, I am trying to train an LSTM network for a text classification task using the AG NEWS dataset available from Torchtext. The network is simple, consisting of only a single LSTM layer (code below). The network's memory usage keeps increasing during training, after the completion of each batch and epoch. I see this issue whenever I use LSTM layers with Flax; I have also built models using Haiku, and the issue is much smaller there. Please let me know if I am doing anything wrong below or whether this memory increase is expected. I have tried a few different LSTM models (text classification, text generation, time series, etc.) and face this issue everywhere. I am running these models in Kaggle kernels that have 16 GB RAM, and they still use all of it and then fail. I am running all models on the CPU.
Problem you have encountered: Flax LSTM layer memory usage increases over time; it eventually runs out of memory.
What you expected to happen: Memory usage should not keep growing.
Logs, error messages, etc.: Out of RAM. The Kaggle kernel stops as memory usage crosses 16 GB.
Steps to reproduce: Please find below the simple model (a single LSTM layer) that I am training with Flax on the AG NEWS dataset from Torchtext.
######### Model ############################
import jax
from flax import linen
embed_len = 50
hidden_dim = 75

class LSTMClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, embed_len, name="Word Embeddings")
        LSTMLayer = linen.scan(linen.OptimizedLSTMCell,
                               variable_broadcast="params",
                               split_rngs={"params": False},
                               in_axes=1, out_axes=1,
                               length=max_tokens,
                               reverse=False)
        self.lstm = LSTMLayer(name="LSTM")
        self.linear1 = linen.Dense(len(target_classes), name="Dense1")

    @linen.remat
    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        carry, hidden = linen.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(X_batch),), size=hidden_dim)
        (carry, hidden), x = self.lstm((carry, hidden), x)
        return self.linear1(x[:, -1])

####### LOSS Function #############
import optax

def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(target_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()

########## Training Function ###############
import gc
import jax.numpy as jnp
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1)  ### Batch Indices
        losses = []  ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None
            X_batch, Y_batch = X[start:end], Y[start:end]  ## Single batch of data
            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)
            ## Update Network Parameters
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)
            losses.append(loss)  ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
        gc.collect()
        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation Accuracy : {:.3f}".format(val_acc))
        gc.collect()
    return params
-
Hey @sunny2309! It's not totally clear that it's the LSTM layer that is at fault here. I would suggest the following: jit your train_step function as shown in the Annotated MNIST example; this can help lower memory consumption. A rough sketch is included below.
I'll be moving this into a discussion as it's still unclear where the bug is. If the error still persists and you gather more information, feel free to open another issue.
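As an illustration (not the exact code from the Annotated MNIST example), a jitted train_step for the model above could look roughly like this, assuming the model, optimizer and CrossEntropyLoss objects defined in the original post:

import jax
import optax

@jax.jit
def train_step(params, optimizer_state, X_batch, Y_batch):
    # One optimization step: compute loss and gradients for a single batch,
    # then apply the optax update to the parameters.
    loss, gradients = jax.value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates)
    return params, optimizer_state, loss

# Inside the batch loop, the per-batch update then becomes:
# params, optimizer_state, loss = train_step(params, optimizer_state, X_batch, Y_batch)

Note that jax.jit retraces for every new input shape, so the smaller final batch and any varying sequence lengths will each trigger an extra compilation.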
-
Hey @sunny2309! I was running into the same issue: my LSTM was running incredibly slowly and eventually getting killed by the oom-kill handler of the cluster I run my experiments on, because of its memory consumption. My model is as follows:
I can't make the code snippet formatting work properly to show the module the right way, but basically I have a similar architecture to yours, except that I don't set the length in the nn.scan(OptimizedLSTMCell) explicitly. Since the maximum sequence length changes across my data, I decided to let the module infer it automatically. This is the main issue. I had read something about JAX recompiling the model every time, so: I believe JAX is indeed compiling the LSTM each time you get a different max_length if you don't set it explicitly. It will create one compiled LSTM for each distinct max_length we have (e.g. 23, 45, 156, 208, etc.), those values being the number of time steps in a batch. If a new training point has the same max length, JAX will look up the cached LSTM for that specific length and not compile it again; the problem for us is that we have so many different lengths in the dataset that we are basically building a whole LSTM for each one.
This was confirmed once I looked into the performance metrics for the CPU. The lines going up are from before I discovered this; they would rise until the process was killed off. The two lines below them are the same models with the issue fixed.
You can either set the length attribute explicitly to your desired max_length, in which case you compile the model only once; you then just have to pad the rest of your sequences so that they all have the same length. In my case, I left the module to infer the length automatically but set the max length from the dataloader directly, so every batch has the same number of time steps (see the sketch below). After fixing this, my models were training way, way faster (from ~3 hrs/epoch initially to ~20 min/epoch) and I had no memory leak anymore.
I apologize if the way I submitted this answer is not ideal; this is the first time I'm posting here. Hopefully this will help anyone working with LSTMs in Flax, as there are not that many examples out there. Kudos!
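To make the fixed-length option concrete, here is a minimal sketch; MAX_LEN (a maximum sequence length chosen for the whole dataset), PAD_ID (the padding token id) and the pad_batch helper are illustrative names, not code from either post:

import numpy as np
from flax import linen

MAX_LEN = 256   # assumed fixed maximum sequence length for the dataset
PAD_ID = 0      # assumed padding token id

def pad_batch(token_ids):
    # Pad (or truncate) every sequence in the batch to MAX_LEN so the scanned
    # LSTM always sees the same number of time steps.
    out = np.full((len(token_ids), MAX_LEN), PAD_ID, dtype=np.int32)
    for i, seq in enumerate(token_ids):
        seq = seq[:MAX_LEN]
        out[i, :len(seq)] = seq
    return out

# Alternatively, fix the number of steps in the scan itself instead of letting
# it be inferred from each batch:
LSTMLayer = linen.scan(linen.OptimizedLSTMCell,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       in_axes=1, out_axes=1,
                       length=MAX_LEN,
                       reverse=False)

Either way, the input shape (and therefore the compiled computation) stays constant, so XLA compiles the scanned LSTM once instead of once per distinct sequence length.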