Fine-tuning a T5 text-classification model on Colab

One of the most interesting recent developments in natural language processing is the T5 family of language models. These are transformer-based sequence-to-sequence models trained on multiple different tasks. The main insight is that different tasks provide a broader context for tokens and help the model statistically approximate the natural-language context of words.

Tasks can be specified by providing an input context, often with a “command” prefix, and then predicting an output sequence. This Google AI blog article contains an excellent description of how multiple tasks are specified.

Naturally, these models are attractive to anyone working on NLP. The drawback, however, is that they are large and require very significant computational resources, beyond those of most developers' desktop machines. Fortunately, Google makes its Colaboratory platform free to use (with some resource limitations). It is an ideal platform for experimenting with this family of models to see how they perform on a specific application.

When I decided to do so, for a text-classification problem, I had to piece together information from a few different sources and try a few different approaches. I decided to put together a mini tutorial on how to fine-tune a T5 model for text classification. Running the tutorial requires a Google Cloud account and a Cloud Storage bucket. Cloud Storage also has a free tier, which should be sufficient to run the examples.

For the purposes of the tutorial, I used a set of labeled Stack Overflow questions. This is a small dataset (which poses over-fitting challenges) of Stack Overflow questions assigned to 3 classes according to the perceived quality of the question.

The classification problem was modelled by supplying the question Title and Body as the input sequence, along with the prefix “quality:”, and then predicting a token for each of the expected classes: “none”, “low”, “high”. This is somewhat similar to the typical classification head applied to a BERT model's [CLS] token, except that the full decoder block serves as the classifier rather than a single Dense or CRF layer.
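As a concrete illustration of this framing, here is a minimal sketch of how a question might be turned into an (input, target) pair. The field names and helper are my own assumptions for illustration, not the tutorial's actual code; tokenization (T5's SentencePiece model) is elided.

```python
def to_t5_example(title: str, body: str, quality: str) -> dict:
    # Task prefix "quality:" followed by the question text; the target is
    # simply the class token the decoder should produce.
    return {"inputs": f"quality: {title} {body}", "targets": quality}

example = to_t5_example("How do I sort a list?", "I have a list of ints...", "high")
```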

In my experiment, the model achieves a 0.936 classification accuracy, both on the validation set and on a held-out set not used for tuning. The confusion matrix on the full validation set is:

        none    low   high
None    4336      0    664
Low        6   4994      0
High     284      0   4716

Rows show the true labels and columns the predictions; e.g., 664 examples were misclassified as High quality when they should have been classified as None.

I initially expected the predominant source of error to be the fact that only 512 tokens of each question are considered: the first 256 (which include the title) and the last 256. Manually examining some of the misclassified examples, I get the impression that some require quite a bit of domain knowledge to judge whether the question is reasonable; i.e. questions that are rooted in a basic misunderstanding but are otherwise well formulated.
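The 512-token window just described (first 256 plus last 256) can be sketched as a simple truncation helper; the function name and operating on a plain token list are my simplifications:

```python
def truncate_middle(tokens: list, max_len: int = 512) -> list:
    # Keep the first and last halves of the sequence, dropping the middle;
    # the title sits at the start, so it always survives truncation.
    if len(tokens) <= max_len:
        return tokens
    half = max_len // 2
    return tokens[:half] + tokens[-half:]
```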

In terms of executing the fine-tuning process on TPUs, my most important conclusion is that the best approach is to use a dataset encoded as .tfrecord files. This allows the central CPU, which is restricted in terms of memory footprint, to coordinate the iteration over the data, while the actual tensor values are read directly from Google Cloud Storage by the TPUs.
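A minimal sketch of the .tfrecord round trip. A local path is used here so the snippet is self-contained; in the tutorial setting the files would live under a gs:// bucket so the TPU workers can fetch the bytes directly. The feature names are my own.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "questions.tfrecord")

def make_example(text: str, label: str) -> bytes:
    # One question/label pair serialized as a tf.train.Example.
    feature = {
        "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode()])),
        "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example("quality: How do I sort a list? ...", "high"))

# tf.data streams the records; with a gs:// path the tensor bytes are read
# from Cloud Storage by the TPU workers, not by the coordinating CPU.
spec = {
    "text": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.string),
}
dataset = tf.data.TFRecordDataset(path).map(
    lambda record: tf.io.parse_single_example(record, spec))
labels = [ex["label"].numpy() for ex in dataset]
```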

My second important learning is that the TensorFlow distributed training support built into the Keras fit, evaluate and predict APIs does the job of coordinating distributed training well. Attempting to manually replicate this support with PyTorch XLA would be a non-trivial effort.
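The shape of that support can be sketched as follows; the default strategy stands in for a TPU so the snippet runs anywhere, and the commented-out resolver lines show what one would use on a Colab TPU runtime. The toy model is purely illustrative.

```python
import numpy as np
import tensorflow as tf

# On a Colab TPU runtime one would instead build a TPUStrategy:
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
strategy = tf.distribute.get_strategy()  # default CPU/GPU stand-in

with strategy.scope():
    # Variable creation must happen inside the strategy scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

# fit/evaluate/predict then handle the per-replica distribution themselves.
history = model.fit(np.zeros((8, 4)), np.zeros((8, 1)), epochs=1, verbose=0)
```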

It is interesting to see the fragmentation in the ML ecosystem around different ways to express the DNN graphs themselves, which to my untrained eye look very similar (i.e. Keras, Torch, Flax/JAX); for these large models, much of the non-trivial effort lies instead in the dataset and trainer API support for distributed training.


Microservices workflow orchestration

A recurring pattern in software architecture is the need to trigger a process or workflow that is implemented across multiple microservices and then report to the user the results when the process completes.

In a previous project, I faced this issue when building a SaaS application in the Intelligent Document Processing (IDP) space. The application was supposed to take a collection of scanned pages, split it into documents, and for each document perform several document understanding tasks. There is a mix of per-page-bundle, per-page and per-document processing steps.

Given the desire to develop each step independently and to scale the processing steps independently (e.g. page OCR consumes more resources than other tasks), I designed a system around a message bus (RabbitMQ) with individual workers that pull requests from message queues.

Unfortunately there aren't a whole lot of easy-to-use solutions available for this type of design. Googling for “rabbitmq workflow orchestration”, the most helpful link I found was an article recommending BPMN for this type of design, which is rather centered on the Java ecosystem. For my use case I needed something that worked well in Python and was preferably language agnostic. I ended up building a custom solution for that company.

However, as I have design conversations around system architecture topics, I often see scenarios where a similar tool would be desirable. That motivated me to start working on a new workflow orchestration project. It is still a work in progress, but it can already execute the kinds of workflows I've encountered in the past.

Workflows can be declared in a simple YAML syntax as a set of steps with dependencies; sub-tasks can be triggered (as in my original requirement of processing a collection of pages, with per-page and per-document steps) and workflows can mix services written in different programming languages.
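As a sketch of what such a declaration might look like for the IDP scenario described earlier; the step names, queue names and exact keys are invented for illustration, not the project's actual schema:

```yaml
# Hypothetical workflow declaration; all names are illustrative only.
workflow: process-page-bundle
steps:
  - name: split-bundle          # per-bundle: split the scan into documents
    queue: splitter
  - name: ocr-page              # per-page sub-tasks fanned out by split-bundle
    queue: ocr
    depends_on: [split-bundle]
    for_each: page
  - name: classify-document     # per-document step, waits on its pages
    queue: classifier
    depends_on: [ocr-page]
    for_each: document
```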

There are existing open source orchestration tools to manage batch workflows such as Apache Airflow and Argo Workflows. In these systems, for each batch job, a new process instance is created and passed command line arguments that specify the workflow parameters.

This project provides similar functionality for online micro-services. In machine learning use cases, for instance, it is common that loading an inference model takes on the order of 30s–1m, while processing a single user request takes on the order of 10–100ms. A system designed to process tens or hundreds of requests per minute can't afford to pay that start-up cost on every request, as a batch approach would.

Instead, all micro-services are pre-loaded and managed as standard online services, except that instead of receiving REST operations from a load balancer, they receive requests and post responses from/to an AMQP message queue. This keeps the logic of determining the next step in the workflow out of the individual services. Debugging is also simplified because the workflow manager tracks the state of each user request.
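The centralisation of routing can be sketched as a pure function inside the workflow manager; the workflow shape (step name mapped to its dependencies) and the step names are invented for illustration:

```python
def next_steps(workflow: dict, completed: set) -> list:
    # Given a workflow (step name -> list of dependency step names) and the
    # set of steps already completed, return the steps now runnable. The
    # services never see this logic; the manager publishes the resulting
    # requests to each step's AMQP queue.
    return sorted(
        step for step, deps in workflow.items()
        if step not in completed and all(d in completed for d in deps)
    )

workflow = {"ocr": [], "split": ["ocr"], "classify": ["split"]}
```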

I’m looking forward to getting some feedback. Drop me a line if you think it can be useful to any problem you are working on or have feature requests.

Playing games with Tensorflow

As a fun project, I recently built a web app to play checkers online against the computer. This post outlines the methodology I used. If you want to check out the results, I would encourage you to try the web link above, change the difficulty level to ‘hard’ and play a round against the computer. You will be playing against a very simple neural network model that is, as far as I can tell, reasonably effective.

The standard approach to developing a game AI for something like board games is the “MiniMax” algorithm. Implementing “MiniMax” for a game like checkers is a relatively simple task; one needs two components:

  • A method of generating valid moves for a player given a board position;
  • A scorer function that evaluates the “goodness” of a given board position for a player.

There are multiple sets of possible rules for the game of checkers. I used the “Spanish draughts” rule set popular in Portugal: men move forward only, flying kings and mandatory captures, on an 8×8 board. The minimax algorithm is independent of the particular rule set used.

The scorer function must be able to look at a given player position and determine a score. The first approach I used was to simply count the difference in pieces between the two players, giving a 10x score to kings vs men. For a rule set with non-flying kings one would probably want to adjust these score factors.

The minimax algorithm implementation is rather straightforward. The algorithm builds a tree of depth N of all possible player moves and counter-moves. It then assumes that each player selects the position that maximises its score. I used a depth of 4 moves by default, which requires evaluating on the order of 10-30k positions, assuming an average of 7-8 moves for each position. An important feature to note is that memoization reduces the number of required computations by approximately half.

Memoization is very effective because in this game a set of moves can often happen in multiple orders. E.g. assume that player 1 initially has the set of valid moves {m1, m2, …}; player 2 will often have a common subset of available moves after m1 and m2, let's call them {m2a, …}. The sequences {m1, m2a, m2} and {m2, m2a, m1} result in the same board contents; the minimax sub-tree need only be evaluated once.
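The idea can be sketched as a generic memoized minimax over an abstract game; the move generator and scorer below are toy stand-ins, not the checkers rules:

```python
def minimax(pos, depth, maximizing, moves, score, cache=None):
    # cache maps (position, depth, maximizing) to a score, so transposed
    # move orders that reach the same position are evaluated only once.
    if cache is None:
        cache = {}
    key = (pos, depth, maximizing)
    if key in cache:
        return cache[key]
    children = moves(pos)
    if depth == 0 or not children:
        result = score(pos)
    else:
        pick = max if maximizing else min
        result = pick(minimax(c, depth - 1, not maximizing, moves, score, cache)
                      for c in children)
    cache[key] = result
    return result

# Toy game: a position is an integer; each move adds 1 or 2 to it, and the
# score of a position is its value.
best = minimax(0, 2, True, lambda p: [p + 1, p + 2], lambda p: p)
```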

The minimax algorithm provides a reasonable initial solution for playing a game like checkers. It has issues, however. For instance, at the start of the game, exploring the minimax tree with a naive scorer like the one described above results in most moves having the same score. One can select randomly among these; but while the algorithm avoids playing really bad moves, it doesn't feel very challenging to a human opponent.

My first approach to address this issue was to get the computer to play 100,000 games against itself and then analyse the results.

This is the reason the board management, valid-move generation and scorer were implemented in C++. The C++ implementation is 10x faster than my initial Python prototype and integrates easily with the rest of the Python code base using Cython wrappers. I find that after algorithmic optimisation (e.g. memoization in this case), converting the most performance-sensitive code to C++ is very effective. Cython translates Python lists directly to C++ vectors or lists and makes it easy to interface between the two languages.

Given a large number of games played, one can analyse the logs and determine that some initial move selections that appear equally good to the minimax algorithm actually tend to yield statistically rather different outcomes.

The trick here is to determine what is statistically significant. A game of checkers has 3 outcomes (tie, player 1 wins, player 2 wins). Given a specific board position and the next player to move, I wrote code that aggregates all the moves from that position and the final outcomes. For some of these positions there is sufficient data to say, with statistical relevance, that choosing one specific move increases the probability of success.

One can start determining what is useful statistical information by establishing the probability of each outcome for a given board configuration across all games played (i.e. across all possible moves from that position), which I assumed to be statistically significant. These probabilities can then be used to model a multinomial distribution. For each move, one can then look at the probability mass function of its observed outcome counts. A high probability mass means that the tie/win statistics for the move could easily be explained by the position itself; a low probability mass means we should consider the probabilities given by the move's own statistics. An example implementation is here.
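A sketch of that test using scipy; the baseline probabilities and per-move counts below are made-up numbers, not statistics from the actual game logs:

```python
from scipy.stats import multinomial

# Baseline outcome probabilities (tie, player-1 win, player-2 win) for a
# board position, estimated across all games through it; numbers invented.
baseline = [0.2, 0.5, 0.3]

# Outcome counts observed over the games where one particular move was
# chosen from that position.
counts = [1, 17, 2]

# Probability mass of these counts under the baseline multinomial: a low
# value means the move's statistics are unlikely to be explained by the
# position alone, so the move itself matters.
p = multinomial.pmf(counts, n=sum(counts), p=baseline)
```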

This statistical knowledge is useful in the opening moves of the game; later on, the number of possible board positions is too high for it to be really useful. A move-selection implementation that uses these game statistics to choose between moves that the minimax algorithm considers equally preferable statistically beats the baseline algorithm. It basically learns what moves make for a good opening strategy.

The problem with this statistics-based approach is that it only works for board positions for which there is sufficient information; this can be addressed by using a neural network. Neural networks seem to be very good at capturing statistical information and at generalising it to data points that have not been seen exactly.

The simplest way to apply a neural network is to train a scorer algorithm that, given a specific board position, can tell how close the player is to a given outcome (tie, win, loss). To do that I used the same 100k game logs and assigned, for each board position, a value for each of the outcomes: 0 if the game did not end in that outcome, and a value in [0, 1] on a sigmoid that starts near 0 at the initial move and reaches 1 at the move that results in victory.

This assumes that we don’t quite know what the correct strategy is; but clearly a move that results in immediate victory is very very good; and board positions that get one closer to this victory goal are increasingly good.

Note that we are getting the network to extrapolate values. The training examples do not have exact labels; we do not know for sure that a given board position will always result in victory (except for the position immediately before a winning move).
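The target assignment described above could be sketched as follows; the exact ramp shape and steepness are my assumptions, as the post does not specify them:

```python
import math

OUTCOMES = ("tie", "p1_win", "p2_win")

def outcome_targets(num_positions: int, final_outcome: str) -> list:
    # Outcomes that did not happen get 0 everywhere; the realised outcome
    # ramps along a sigmoid from near 0 (opening) to near 1 (final move).
    targets = []
    for i in range(num_positions):
        t = i / (num_positions - 1)                   # game progress in [0, 1]
        ramp = 1.0 / (1.0 + math.exp(-10 * (t - 0.5)))  # steepness is a guess
        targets.append({o: ramp if o == final_outcome else 0.0
                        for o in OUTCOMES})
    return targets

targets = outcome_targets(40, "p1_win")
```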

I ended up using a very simple 2-layer transformer model for this scorer. The model converges in about 10 epochs to a set of weights that provide an approximate evaluation of the “goodness” of a board position, even for positions that were never seen before. This is something the statistical model is not capable of.

From a software engineering perspective, some of the interesting challenges came after having a model that performs reasonably well. One still wants to evaluate as many board positions as possible in a minimax tree, except that the goal is to use the neural network as the scorer rather than the naive piece-counting algorithm. And of course, calling a TensorFlow evaluation function once per board position is prohibitively expensive.

To get decent performance I ended up writing a different implementation of the minimax algorithm. This implementation batches the positions it wants evaluated and calls the neural network once per batch of 32 examples. The trick is that instead of using the call stack, as in the initial implementation, one has to build a tree structure that is populated with values asynchronously from the perspective of the minimax tree construction.

A couple of details that really surprised me:

  • Calling keras.Model.predict() has a very significant per-call overhead (TensorFlow 2.4.1); I had to define a function wrapped with @tf.function that calls the model directly in order to escape a bunch of library overhead.
  • Using accessor/slicing operations on an eager tensor in Python is extremely slow; converting the tensor to a numpy array immediately after invoking the model and then slicing the array makes a very significant difference.
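Both points combined look roughly like this; the Dense model is a trivial stand-in for the 2-layer transformer scorer, and the batch contents are dummy data:

```python
import tensorflow as tf

# Stand-in for the real board scorer: maps a board encoding to 3 outcome scores.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])

# Wrapping the direct model call in tf.function avoids the per-call
# overhead that keras.Model.predict carries for small batches.
@tf.function
def score_batch(batch):
    return model(batch, training=False)

batch = tf.zeros((32, 8))
# Convert to numpy once, immediately; slicing eager tensors is very slow.
scores = score_batch(batch).numpy()
win_scores = scores[:, 0]
```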

The neural-network-scorer-based minimax algorithm wins about 2/3 of the games against the baseline minimax plus move statistics. The main issue with the minimax algorithm as a player is the “horizon effect”: the algorithm can't see the consequences of a position beyond its search window, and does silly things like sacrificing pieces to delay an outcome (such as the human player getting a king) past its visible window. The neural-network-based scorer doesn't have that issue.

In summary, I continue to be surprised by the ability of neural networks to gather statistical information and extrapolate to similar but unseen examples.

This project was developed on my laptop, without the use of any computation service to either play games or train the neural network. The game is deployed on a free-tier cloud instance with no GPU support. Not so many years ago, this sort of project would have required a much more significant investment in both computational and human resources. It is now day-to-day engineering work.

Authenticated access to Kubernetes pods

When running a micro-services style application in a public cloud, one of the problems to solve is how to provide access to debug information. At Laserlike, we run our application stack on GKE. Most of the stack consists of golang Pods that run an HTTP listener that serves /debug and /metrics handlers.

For metrics scraping we use Prometheus, and Grafana for visualization. Our Grafana server is a NodePort service behind a GCE load balancer that uses OAuth2-based authentication for access. This still leaves a gap in terms of access to pod debug information such as /debug/vars or /debug/pprof.

In order to address this gap, we created a simple HTTP proxy for Kubernetes services and endpoints. We deploy this proxy behind an OAuth2 authenticator, which is then exposed via an external load balancer.

The service proxy uses the Kubernetes client library to consume annotations on Service objects. For example, the following annotation instructs the service proxy to expose the debug port of the endpoints of the specified service:

    k8s-svc-proxy.local/endpoint-port: "8080"
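The proxy side of this can be sketched as a small helper; in the real proxy the annotations arrive on Service objects via the Kubernetes client library, which I elide here, and the helper name is my own:

```python
from typing import Optional

ANNOTATION = "k8s-svc-proxy.local/endpoint-port"

def proxied_port(annotations: Optional[dict]) -> Optional[int]:
    # Returns the debug port to expose for a service, or None when the
    # service has not opted in via the annotation.
    value = (annotations or {}).get(ANNOTATION)
    return int(value) if value is not None else None
```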

The landing page on the proxy then displays a set of endpoints: