BERT Fake News Classifier

Continuing on from yesterday, today we move on to training our classifier and saving the fine-tuned BERT for future inference!

The main output of this pipeline is a trained BERT fake news classifier!

BERT Class

In [57]:
import torch
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
  def __init__(self, freeze_bert = False):
    super(BertClassifier, self).__init__()

    # Input size (BERT hidden size), hidden size of the classifier head, number of output labels
    D_in, H, D_out = 768, 50, 2

    # Bert layer
    self.bert = BertModel.from_pretrained('bert-base-uncased')

    # Two-layer classification head with a ReLU activation in between
    self.classifier = nn.Sequential(
        nn.Linear(D_in, H),
        nn.ReLU(),
        nn.Linear(H, D_out)
    )

    # Optionally freeze BERT so that only the classification head is trained
    if freeze_bert:
      for param in self.bert.parameters():
        param.requires_grad = False
  
  def forward(self, input_ids, attention_mask):
    outputs = self.bert(input_ids = input_ids, attention_mask = attention_mask)

    # Take the final hidden state of the [CLS] token as the sequence representation
    last_hidden_state_cls = outputs[0][:, 0, :]

    # Pass the [CLS] representation through the classification head
    logits = self.classifier(last_hidden_state_cls)

    return logits
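
Since bert-base-uncased has roughly 110 million parameters while the classification head is tiny, a quick optional sanity check is to instantiate the model with freeze_bert=True and count how many parameters still require gradients (just a sketch, not part of the training run below):

# Optional sanity check: with BERT frozen, only the classification head
# (768*50 + 50 + 50*2 + 2 = 38,552 parameters) remains trainable
frozen = BertClassifier(freeze_bert=True)
trainable = sum(p.numel() for p in frozen.parameters() if p.requires_grad)
print("Trainable parameters with frozen BERT: %s" % trainable)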

Initialise BERT, Optimiser, Scheduler, and Loss Function

In [58]:
from transformers import AdamW, get_linear_schedule_with_warmup

# bert classifier
bert_classifier = BertClassifier()
bert_classifier.to(device)

# optimiser
optimizer = AdamW(bert_classifier.parameters(), lr = 5e-5, eps=1e-8)

epochs = 4
# scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = len(train_dataloader) * epochs)

# loss function
loss_fn = nn.CrossEntropyLoss()
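
As a quick check on the setup (assuming train_dataloader is the loader built in yesterday's post), the scheduler needs the total number of optimisation steps, which is simply batches per epoch times epochs; with num_warmup_steps = 0 the learning rate starts at 5e-5 and decays linearly to zero over those steps:

# Total number of optimisation steps seen by the linear schedule
total_steps = len(train_dataloader) * epochs
print("Total training steps: %s" % total_steps)
print("Initial learning rate: %s" % optimizer.param_groups[0]['lr'])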

Define the training and evaluation functions

In [60]:
def train(model, train_dataloader, val_dataloader=None, epochs = 4, evaluation = False):
  for epoch_i in range(epochs):
    print("EPOCH: %s" % epoch_i)
    t_epoch, t_batch = time.time(), time.time()

    total_loss = 0

    # Put the model into training mode
    model.train()

    for step, batch in enumerate(train_dataloader):
      # Move the batch tensors to the same device as the model
      b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

      # Clear any previously accumulated gradients
      model.zero_grad()

      # Forward pass
      logits = model(b_input_ids, b_attn_mask)

      # Compute the loss and accumulate it for the epoch average
      loss = loss_fn(logits, b_labels)
      total_loss += loss.item()

      # Backward pass
      loss.backward()

      # Clip gradients to avoid exploding gradients
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

      # Update the parameters and the learning rate schedule
      optimizer.step()
      scheduler.step()

    avg_train_loss = total_loss / len(train_dataloader)

    # Evaluate on the validation set at the end of each epoch; the returned
    # metrics can be printed or logged here if desired
    if evaluation:
      val_loss, val_accuracy = evaluate(model, val_dataloader)

  print('Training Completed!')
In [59]:
import random
import time

def evaluate(model, val_dataloader):
  # Put the model into evaluation mode
  model.eval()

  val_accuracy = []
  val_loss = []

  for batch in val_dataloader:
    b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

    # Disable gradient tracking for validation
    with torch.no_grad():
      logits = model(b_input_ids, b_attn_mask)

      loss = loss_fn(logits, b_labels)
      val_loss.append(loss.item())

      # Get the predictions
      preds = torch.argmax(logits, dim=1).flatten()

      # Calculate the accuracy rate
      accuracy = (preds == b_labels).cpu().numpy().mean() * 100
      val_accuracy.append(accuracy)

  val_loss = np.mean(val_loss)
  val_accuracy = np.mean(val_accuracy)

  return val_loss, val_accuracy

Training

  • Set the seed value
In [61]:
seed_value = 21
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
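If fully reproducible runs are needed, cuDNN can optionally be pinned down as well (this may slow training slightly and is not required for the results below):

# Optional: make cuDNN deterministic for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False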
In [62]:
train(bert_classifier, train_dataloader, val_dataloader, epochs = 4, evaluation = True)
EPOCH: 0
EPOCH: 1
EPOCH: 2
EPOCH: 3
Training Completed!
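
Because the train function above does not print the validation metrics it computes, it can be worth running one explicit evaluation pass after training (this assumes the val_dataloader built in yesterday's post):

# Final evaluation pass on the validation set
final_val_loss, final_val_acc = evaluate(bert_classifier, val_dataloader)
print("Final validation loss: %.4f, accuracy: %.2f%%" % (final_val_loss, final_val_acc))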

Inference on test dataset

In [66]:
import torch.nn.functional as F

def bert_predict(model, dataloader):
  # Put the model into evaluation mode
  model.eval()

  all_logits = []

  for batch in dataloader:
    # The test DataLoader yields only input ids and attention masks (no labels)
    b_input_ids, b_attn_mask = tuple(t.to(device) for t in batch)[:2]

    # Compute logits
    with torch.no_grad():
        logits = model(b_input_ids, b_attn_mask)
    all_logits.append(logits)
  
  all_logits = torch.cat(all_logits, dim = 0)

  probs = F.softmax(all_logits, dim = 1).cpu().numpy()

  return probs
In [64]:
test_df['combined'] = test_df['title'] + ' ' + test_df['text']
In [65]:
# Preparing the test data
test_inputs, test_masks = bert_preprocessing(test_df['combined'])

# Create the DataLoader for our test set
test_dataset = TensorDataset(test_inputs, test_masks)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=32)
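Note that the SequentialSampler keeps the test rows in their original order, which matters because the predictions are written straight back into test_df below. A small optional check (a sketch, assuming test_df from yesterday's post) confirms the dataset and dataframe line up:

# Optional check: one row of tensors per row of the test dataframe
print(len(test_dataset), len(test_df))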
In [67]:
# Compute predicted probabilities on the test set
test_probs = bert_predict(bert_classifier, test_dataloader)

# Get predictions from the probabilities: assign label 1 only when the model
# is at least 90% confident
threshold = 0.9
preds = np.where(test_probs[:, 1] > threshold, 1, 0)

# Number of articles predicted as fake news (label 1)
print("Number of articles predicted as fake news: ", preds.sum())
Number of articles predicted as fake news:  2589
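The 0.9 threshold deliberately makes the classifier conservative about assigning label 1; for comparison, the usual argmax decision over the two classes is equivalent to a 0.5 threshold:

# For comparison: plain argmax (equivalent to a 0.5 threshold on two classes)
argmax_preds = test_probs.argmax(axis=1)
print("Articles assigned label 1 by argmax: ", argmax_preds.sum())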
In [69]:
test_df['label'] = preds
test_df['label'].value_counts()
Out[69]:
0    2611
1    2589
Name: label, dtype: int64
In [70]:
test_df[['label', 'id']].to_csv('bert_fake_news_output.csv', index = False)

Save the tokenizer and trained BERT

In [80]:
tokenizer.save_pretrained('./bert-fake-news')
Out[80]:
('./bert-fake-news/vocab.txt',
 './bert-fake-news/special_tokens_map.json',
 './bert-fake-news/added_tokens.json')
In [73]:
torch.save(bert_classifier.state_dict(), './bert-fake-news/bert-fake-news-classifier.pt')
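
To reuse the fine-tuned classifier later, the tokenizer and the saved weights can be restored like this (a sketch, assuming the tokenizer from yesterday's post is a BertTokenizer and using the paths saved above):

from transformers import BertTokenizer

# Reload the tokenizer and the fine-tuned classifier weights for inference
loaded_tokenizer = BertTokenizer.from_pretrained('./bert-fake-news')
loaded_model = BertClassifier()
loaded_model.load_state_dict(torch.load('./bert-fake-news/bert-fake-news-classifier.pt', map_location=device))
loaded_model.to(device)
loaded_model.eval()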