Text Generation with Recurrent Neural Networks (RNNs): A Step-by-Step Guide
This article explains how Recurrent Neural Networks (RNNs) work for text generation. We’ll explore the problems with earlier approaches (Markov Chains and n-gram models), how RNNs solve those problems, how RNNs generate text, a simple Python implementation, and the limitations of RNNs.
1. Before RNNs: Problems with Markov Chains and n-grams
- Markov Chains: These models predict the next word based on the current word, but they don’t remember long-term context. For example, generating a coherent paragraph with Markov Chains is difficult because they only consider one or two preceding words.
- n-grams: These models improve on Markov Chains by conditioning on a fixed window of preceding words (e.g., a trigram model predicts the next word from the previous two words). However, like Markov Chains, they struggle with long sequences because they cannot remember anything beyond that fixed window.
- Main Limitation: Both methods lack the ability to remember long-term dependencies in a sentence. This is especially problematic for tasks like generating grammatically correct or meaningful text over several sentences.
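To make this concrete, here is a minimal sketch of a bigram (first-order Markov chain) text generator; the tiny corpus is made up purely for the demo. Because each next word is chosen using only the single preceding word, anything said earlier in the sentence has no influence on the output:
import random
from collections import defaultdict

corpus = "the cat sat on the mat and the cat slept on the sofa".split()

# Build a table mapping each word to the words observed right after it
transitions = defaultdict(list)
for current_word, next_word in zip(corpus, corpus[1:]):
    transitions[current_word].append(next_word)

# Generate text: every choice depends only on the single previous word
word = "the"
generated = [word]
for _ in range(8):
    if word not in transitions:
        break
    word = random.choice(transitions[word])
    generated.append(word)
print(" ".join(generated))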
2. What is an RNN?
- Recurrent Neural Networks (RNNs): RNNs are neural networks designed to work with sequential data, such as text, time series, or audio. The key feature of RNNs is that they have loops in their architecture, allowing them to retain a “memory” of previous inputs and to learn dependencies that reach beyond any fixed window.
- How RNN Works: In an RNN, each word in a sequence is processed one by one. The network passes information from one step to the next, allowing it to remember the entire sequence. At each step, the hidden state is updated based on the current input and the previous hidden state, which makes it possible for the RNN to generate text by considering not only the current word but also the words that came before it.
- ANN vs RNN: A standard Artificial Neural Network (ANN) treats each input independently and keeps no memory of previous inputs, while an RNN feeds its hidden state from one step into the next, so its output depends on the sequence seen so far.
Let us now walk through a simple RNN language model.
Components of an RNN Cell
An RNN cell is the core building block of a Recurrent Neural Network. It processes one word or element at a time while retaining information from previous steps (memory). The key components of an RNN cell are:
- Input (xₜ): The data point or word that the RNN processes at the current time step. For example, in text generation, this could be a single word.
- Hidden State (hₜ): This represents the “memory” or context from previous time steps. It is updated at each step and helps the RNN remember information about previous inputs.
- bh: The bias added when computing the hidden state (hₜ).
- Weights (Wxh, Whh, Why):
- Wxh: The weight matrix applied to the current input (xₜ).
- Whh: The weight matrix applied to the previous hidden state (hₜ₋₁).
- Why: The weight matrix used to compute the output.
- Activation Function: Typically tanh or ReLU, which introduces non-linearity and allows the RNN to learn complex patterns.
- by: The bias added when computing the output (ŷₜ).
- Output (ŷₜ): The predicted value or output for the current time step based on the input and hidden state.
What Happens Inside the RNN Cell?
At each step:
- The cell takes the current input (xₜ) and the previous hidden state (hₜ₋₁).
- It computes the new hidden state (hₜ) by combining them with the weights and biases and applying the activation function: hₜ = tanh(Wxh·xₜ + Whh·hₜ₋₁ + bh).
- The new hidden state is passed to the next step.
- Optionally, an output is computed at each step: ŷₜ = softmax(Why·hₜ + by). The short NumPy sketch below puts these pieces together.
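To make the update concrete, here is a minimal NumPy sketch of a single RNN cell step (illustrative only; the sizes and the random initialisation below are assumptions for the demo, not values taken from the Keras model later in this article):
import numpy as np

vocab_size, hidden_size = 13, 8  # illustrative sizes for the demo

# Randomly initialised parameters of one RNN cell
Wxh = np.random.randn(hidden_size, vocab_size) * 0.01   # input -> hidden
Whh = np.random.randn(hidden_size, hidden_size) * 0.01  # hidden -> hidden
Why = np.random.randn(vocab_size, hidden_size) * 0.01   # hidden -> output
bh = np.zeros(hidden_size)
by = np.zeros(vocab_size)

def rnn_step(x_t, h_prev):
    # New hidden state: mix the current input with the previous memory
    h_t = np.tanh(Wxh @ x_t + Whh @ h_prev + bh)
    # Output: scores over the vocabulary turned into probabilities (softmax)
    logits = Why @ h_t + by
    y_t = np.exp(logits) / np.sum(np.exp(logits))
    return h_t, y_t

# One step: a one-hot word vector and an all-zero initial hidden state
x_t = np.eye(vocab_size)[3]
h_t, y_t = rnn_step(x_t, np.zeros(hidden_size))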
3. Types of RNNs
RNNs come in various types, each designed to handle different types of sequence data:
- Vanilla RNN: The basic version of an RNN where the same function is applied at each step of the sequence. This is the one we’ll focus on for text generation.
- LSTM (Long Short-Term Memory): An advanced RNN variant that uses gates to mitigate the vanishing gradient problem, allowing it to retain information over much longer spans.
- GRU (Gated Recurrent Unit): A simpler alternative to the LSTM that also addresses the vanishing gradient problem but with fewer parameters.
RNNs can also be classified by their input-output structure into four types: one-to-one, one-to-many (e.g., image captioning), many-to-one (e.g., sentiment classification), and many-to-many (e.g., machine translation).
In this article, we will focus on Vanilla RNNs, which are the simplest form of RNN.
4. How RNN Solves Earlier Problems
- Long-Term Dependencies: RNNs address the limitations of Markov Chains and n-grams by keeping track of long-term dependencies. Unlike fixed-window models, RNNs can remember information from earlier in the sequence, which allows them to generate more coherent text.
- Sequence Flexibility: RNNs can handle sequences of varying lengths. They don’t rely on a fixed window size like n-grams, so they can adjust to the length of the input text dynamically.
5. How RNN Generates Text
RNNs generate text by predicting one word at a time based on previous words. Here’s a breakdown of how it works:
- Training Phase: First, you train the RNN on a large text dataset. The RNN learns to predict the next word in a sequence based on the context provided by the previous words. For example, from the sentence “This is a simple example”, the model is trained on pairs such as (“This is” → “a”) and (“This is a” → “simple”).
- Text Generation: After training, the RNN can generate text by starting with a seed word. It predicts the next word, adds it to the sequence, and continues predicting the next word based on the previous sequence, gradually building the text.
6. Implementation: Text Generation Using Vanilla RNN in Python
Let’s implement a basic RNN to generate text. We’ll use TensorFlow/Keras and train the model to predict the next word in a sentence.
Step-by-Step Implementation
1. Import Libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense, Embedding
2. Prepare Data: Load and preprocess the dataset.
# Example text
data = "This is a simple example for text generation using RNN in Python. "
# Tokenize the text
tokenizer = Tokenizer()
tokenizer.fit_on_texts([data])
# inspect the learned vocabulary
tokenizer.word_index
# vocabulary size (+1 because index 0 is reserved for padding)
total_words = len(tokenizer.word_index) + 1
total_words
# output
13
# Create n-gram input sequences: for each position i, take tokens [0..i]
token_list = tokenizer.texts_to_sequences([data])[0]
input_sequences = []
for i in range(1, len(token_list)):
    input_sequences.append(token_list[:i+1])
input_sequences
# Pad sequences
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(tf.keras.preprocessing.sequence.pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
input_sequences
# Create predictors and labels
X, y = input_sequences[:,:-1], input_sequences[:,-1]
# check input
X
# check target
y
# one-hot encode the targets so they match the softmax output
y = tf.keras.utils.to_categorical(y, num_classes = total_words)
y
3. Build the RNN Model: Use a simple RNN layer.
model = Sequential()
model.add(Embedding(total_words, 100, input_length=max_sequence_len-1))
model.add(SimpleRNN(150, return_sequences=False))
model.add(Dense(total_words, activation='softmax'))
# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Train the model (a toy corpus like this one may need many more epochs to fit well)
model.fit(X, y, epochs=10, verbose=1)
# check model summary
model.summary()
4. Generate Text: Use the trained model to generate new text.
def generate_text(seed_text, next_words, max_sequence_len):
    for _ in range(next_words):
        # Encode the current text and pad it to the model's input length
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        token_list = tf.keras.preprocessing.sequence.pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
        # Predict a distribution over the vocabulary and take the most likely word
        predicted = model.predict(token_list, verbose=0)
        predicted_word = tokenizer.index_word[np.argmax(predicted)]
        seed_text += " " + predicted_word
    return seed_text
seed_text = "This is"
print(generate_text(seed_text, 10, max_sequence_len))
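As an optional variation (not part of the walkthrough above), the argmax step in generate_text can be replaced by sampling from the predicted distribution, which tends to produce less repetitive text. A minimal sketch of such a helper, with an assumed temperature parameter:
def sample_next_word(probabilities, temperature=1.0):
    # Rescale the distribution: lower temperature -> more conservative choices
    probs = np.log(probabilities + 1e-9) / temperature
    probs = np.exp(probs) / np.sum(np.exp(probs))
    return int(np.random.choice(len(probs), p=probs))
Inside the generation loop you would then look up tokenizer.index_word.get(sample_next_word(predicted[0]), "") instead of using argmax, with .get guarding against index 0, which is reserved for padding.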
7. Limitations of RNNs
- Vanishing Gradient Problem: When dealing with very long sequences, RNNs can suffer from the vanishing gradient problem, where the gradients become too small during backpropagation, making it difficult for the model to learn long-term dependencies effectively (a small numeric illustration follows this list).
- Slow Training: RNNs process data sequentially, which makes them slower to train compared to models like Transformers, which can process data in parallel.
- Limited Memory: Although RNNs are better than n-grams, they still have trouble with extremely long-term dependencies (e.g., remembering something that happened 100 words ago in a novel).
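As a rough, simplified illustration of why gradients vanish (not a derivation from the model above): during backpropagation through time the gradient is scaled by roughly one factor per time step, so a factor slightly below 1 shrinks it exponentially, while a factor slightly above 1 blows it up.
steps = 100
print(0.9 ** steps)  # ~2.7e-05  -> gradient effectively vanishes
print(1.1 ** steps)  # ~13780.6  -> gradient explodes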
8. Conclusion
- RNNs vs. Earlier Models: RNNs are a major improvement over Markov Chains and n-grams because they can remember the long-term context and generate more meaningful text. They do this by maintaining a hidden state that carries information about the entire sequence, not just the last few words.
- Limitations: While RNNs are powerful, they struggle with very long sequences due to the vanishing gradient problem and slow training speeds. More advanced architectures like LSTMs and GRUs address these issues, but basic RNNs still offer a simple and effective way to generate text.
- Future of Text Generation: More advanced models like Transformers (e.g., GPT-3) have now surpassed RNNs for tasks like text generation, but understanding RNNs is still crucial for grasping the evolution of sequential models.
This hands-on guide helps you understand the basics of text generation with RNNs and provides an easy-to-follow implementation in Python.