Objective and Contribution

Proposed zero-shot learning for text classification, where we train models on a large corpus of sentences to learn the relationship between a sentence and its tags. The goal of zero-shot learning is to train models that can predict unseen classes. Therefore, the model learns to predict whether a given sentence is related to a tag or not, rather than classifying the sentence into one of the training classes. This is a step towards general intelligence in NLP.

The contributions of the paper are as follows:

  1. Proposed a zero-shot learning framework for text classification, cast as a binary classification task that determines the relatedness between text and categories. We show that this framework can adapt to any number of categories and across datasets without the need to retrain or fine-tune.

  2. Proposed three neural network architectures that can be used for zero-shot classification.

  3. Compared the accuracy of our zero-shot classifiers on different datasets with SOTA results obtained from models trained specifically on those datasets. Our models performed competitively.

Dataset

We crawled over 4 million news headlines together with their SEO tags. Our dataset has more than 300K unique SEO tags, and each news article can have more than one tag. The figure below showcases samples of our training data.

Note that zero-shot learning is different from multi-class or multi-label classification. Our goal is to train our model to predict whether a sentence is related to a given tag or not. The difference in classification type is illustrated in the figure below.
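To make the binary setup concrete, here is a minimal sketch of how (sentence, tag, label) training pairs could be built from the crawled data. The negative-sampling scheme, pairing a headline with randomly drawn unrelated tags, is our assumption for illustration; the post does not specify how negative examples are produced.

```python
import random

def make_training_pairs(headlines, all_tags, neg_per_pos=1, seed=0):
    """Build binary (sentence, tag, label) pairs from tagged headlines.

    headlines: list of (sentence, set_of_tags) tuples.
    all_tags:  full tag vocabulary to sample negatives from (assumed).
    """
    rng = random.Random(seed)
    pairs = []
    for sentence, tags in headlines:
        for tag in tags:
            pairs.append((sentence, tag, 1))        # related pair
            for _ in range(neg_per_pos):
                neg = rng.choice(all_tags)
                while neg in tags:                  # avoid false negatives
                    neg = rng.choice(all_tags)
                pairs.append((sentence, neg, 0))    # unrelated pair
    rng.shuffle(pairs)
    return pairs
```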

Test datasets

We split our source dataset above into training and test sets and evaluated our trained model on the test set. Additionally, to assess our model's ability to generalise (zero-shot learning), we tested it on two other test sets: the UCI News Aggregator and a tweet classification dataset. The UCI News Aggregator contains over 400K sentences covering four categories: technology, business, medicine, and entertainment. The tweet classification dataset contains six categories: business, health, politics, sports, technology, and entertainment.
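Since the model only outputs a relatedness score for a single (sentence, tag) pair, evaluating it as a classifier on these datasets presumably means scoring the sentence against every candidate category and picking the best one. A minimal sketch of that assumed procedure, where `relatedness` stands in for any of the trained models:

```python
def zero_shot_classify(sentence, categories, relatedness):
    """Pick the category with the highest relatedness score.

    relatedness(sentence, tag) -> probability that the sentence is
    related to the tag, produced by any of the trained binary models.
    """
    scores = {c: relatedness(sentence, c) for c in categories}
    return max(scores, key=scores.get)

# e.g. on the UCI News Aggregator's four categories:
# zero_shot_classify("Apple unveils new iPhone",
#                    ["technology", "business", "medicine", "entertainment"],
#                    relatedness=model_score)
```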

Methodology

We experimented with three different architectures for zero-shot classification. Architecture 1 is a simple fully connected layer that takes in the concatenation of the mean sentence embedding and the tag embedding. Architecture 2 uses an LSTM to encode the input sentence into a context vector. The context vector is concatenated with the tag embedding and fed into a fully connected layer. Architecture 3 also uses an LSTM, except at each time step it takes in the concatenation of the tag embedding and the word embedding, so that the model can learn whether each word is related to the tag. The final hidden state is fed into a fully connected layer for the final prediction. The figures below showcase the three architectures in order.
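A minimal PyTorch sketch of the three architectures as described above; the embedding dimension, hidden size, and sigmoid output layer are assumptions for illustration, not details taken from the paper:

```python
import torch
import torch.nn as nn

EMB_DIM, HIDDEN = 300, 256  # assumed sizes

class Arch1(nn.Module):
    """Fully connected layer over [mean sentence embedding ; tag embedding]."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2 * EMB_DIM, 1)

    def forward(self, sent_emb, tag_emb):
        # sent_emb: (batch, seq_len, EMB_DIM) word embeddings
        mean_sent = sent_emb.mean(dim=1)                    # mean pooling
        x = torch.cat([mean_sent, tag_emb], dim=-1)
        return torch.sigmoid(self.fc(x))                    # relatedness prob

class Arch2(nn.Module):
    """LSTM sentence encoder; context vector concatenated with tag embedding."""
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(EMB_DIM, HIDDEN, batch_first=True)
        self.fc = nn.Linear(HIDDEN + EMB_DIM, 1)

    def forward(self, sent_emb, tag_emb):
        _, (h_n, _) = self.lstm(sent_emb)                   # final hidden state
        x = torch.cat([h_n[-1], tag_emb], dim=-1)
        return torch.sigmoid(self.fc(x))

class Arch3(nn.Module):
    """LSTM over [word embedding ; tag embedding] at every time step."""
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(2 * EMB_DIM, HIDDEN, batch_first=True)
        self.fc = nn.Linear(HIDDEN, 1)

    def forward(self, sent_emb, tag_emb):
        # Repeat the tag embedding across the sequence, pairing it with each word.
        tag_rep = tag_emb.unsqueeze(1).expand(-1, sent_emb.size(1), -1)
        _, (h_n, _) = self.lstm(torch.cat([sent_emb, tag_rep], dim=-1))
        return torch.sigmoid(self.fc(h_n[-1]))
```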

Results

As mentioned, we evaluated our three architectures on the source test set, the UCI News Aggregator, and the tweet classification dataset. Firstly, on the source test set, architectures 1, 2, and 3 achieved 72%, 72.6%, and 74% accuracy respectively. We also evaluated the architectures on tags that appear only in the source test set and not in the training set, where they scored 78%, 76%, and 81% respectively.

Secondly, on the UCI News Aggregator, our models achieved 61.73%, 63%, and 64.21% accuracy respectively. This is well below the SOTA result of 94.75% on this dataset, but given that our models had not seen a single sample from the UCI News Aggregator, the results are encouraging.

Lastly, on tweet classification, our models achieved 64%, 53%, and 64.5% accuracy respectively. The best results were achieved by SVC and Naïve Bayes classifiers trained directly on the dataset, which scored 74% and 78%.

Conclusion and Future Work

The accuracy of zero-shot classification is still far from that of supervised models, but our models performed better than random classification on datasets without seeing any related examples, showcasing their ability to generalise.

Ryan

Data Scientist
