Isolating Problematic Data for Remediation and Retraining ML models
When a system has high-dimensional data, pinpointing the input regions responsible for poor model performance becomes a difficult problem. Hotspots automates the identification of regions associated with poor ML performance, significantly reducing the time and error involved in finding them.
We might have an ML model deployed in production with some monitoring in place. We might notice that performance is degrading, either from classic performance metrics or from drift monitoring combined with explainability techniques. Once we’ve identified that our model is failing, the next step is to identify why it is failing.
This process involves slicing and dicing the input data to find what caused the model degradation. That is, we want to see which particular input regions are associated with poor performance and work on a solution from there, such as finding pipeline breaks or retraining our models on those regions.
This basically boils down to a time-consuming task of finding needles in a haystack. What if we could reverse engineer the process and surface all of the needles, i.e. input regions associated with poor performance, directly to the user?
We can! The steps we’ll take are:
(1) Train a decision tree on a proper partitioning objective.
(2) Create and store hotspot tree artifact.
(3) Retrieve hotspots from the hotspot tree at query time.
In the toy example below, which we’ll use throughout this post, we have two ground truth regions separated by a parabolic function, with blue above and red below the parabola. The color of each datapoint represents its prediction. We want to isolate the hotspot regions where the prediction color does not match the region color, which we do with the pale boxes for two different accuracy thresholds.
(1) Train a decision tree on a proper partitioning objective
As soon as we think about partitioning data into regions of interest, we should think about tree models, and specifically a single decision tree. Remember that our task is ultimately an inference task and not a prediction task, so there is no need for a tree ensemble like random forests or XGBoost, because (a) we’re not trying to make predictions and (b) ensembles introduce noise and non-deterministic decision paths for splitting our data.
Recall how a decision tree chooses each split: it enumerates the input features and their possible split values, then selects the feature and value that minimize impurity, creating children that are purer with respect to the output labels.
In simple terms, let’s say the output is color and we have blue and red marbles. The marbles have varying diameters in both groups, but blue marbles are textured while red marbles are smooth. If we had to choose between diameter and texture to partition our marbles, we’d choose texture (textured or smooth), since that would perfectly separate the blue marbles into one group and the red marbles into another, reducing impurity to 0 in each group.
In reality, a dataset would need multiple splits in order to reduce impurity to 0 in the leaf nodes.
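To make split selection concrete, here is a minimal pure-Python sketch that enumerates every feature and value and scores each candidate split with Gini impurity (one common impurity measure; the post doesn’t specify which one is used). The marble data and function names are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 means the group is pure)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def split_impurity(rows, labels, goes_left):
    """Impurity of a binary split, weighting each child by its share of the rows."""
    left = [y for x, y in zip(rows, labels) if goes_left(x)]
    right = [y for x, y in zip(rows, labels) if not goes_left(x)]
    n = len(labels)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

def best_split(rows, labels):
    """Enumerate every feature and value; return (impurity, description) of the best split."""
    best = None
    for j in range(len(rows[0])):
        for value in sorted({row[j] for row in rows}):
            if isinstance(value, str):
                # Categorical feature: equal to the value vs. not
                goes_left = lambda row, j=j, v=value: row[j] == v
                description = f"feature {j} == {value!r}"
            else:
                # Numeric feature: at or below the value vs. above
                goes_left = lambda row, j=j, v=value: row[j] <= v
                description = f"feature {j} <= {value}"
            score = split_impurity(rows, labels, goes_left)
            if best is None or score < best[0]:
                best = (score, description)
    return best

# Hypothetical marbles: (diameter in mm, texture); diameters overlap across colors
marbles = [(10, "textured"), (14, "textured"), (18, "textured"),
           (11, "smooth"), (15, "smooth"), (19, "smooth")]
colors = ["blue", "blue", "blue", "red", "red", "red"]

print(best_split(marbles, colors))  # (0.0, "feature 1 == 'smooth'")
```

Splitting on texture (smooth vs. not smooth) drives impurity to 0, while the best diameter cut in this made-up data only gets down to about 0.4.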
So what exactly is the equivalent of the blue and red marbles above? We ultimately want to separate the bad predictions from the good predictions, so we need some metric as the output, i.e. the partitioning objective, of our decision tree.
For classification, we can encode correct classifications as 1s and incorrect classifications as 0s. If we want more granularity while partitioning, we could also encode classifications as 1, 2, 3, or 4 for true positive, false positive, true negative, and false negative, respectively.
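Both encodings are short. A minimal sketch, assuming binary labels with 1 as the positive class; the function names are hypothetical:

```python
def encode_correctness(y_true, y_pred):
    """1 for a correct classification, 0 for an incorrect one."""
    return [int(t == p) for t, p in zip(y_true, y_pred)]

def encode_outcome(y_true, y_pred):
    """Four-way encoding: 1 = TP, 2 = FP, 3 = TN, 4 = FN (1 is the positive class)."""
    codes = []
    for t, p in zip(y_true, y_pred):
        if p == 1:
            codes.append(1 if t == 1 else 2)  # predicted positive: TP or FP
        else:
            codes.append(3 if t == 0 else 4)  # predicted negative: TN or FN
    return codes

y_true = [1, 0, 0, 1]
y_pred = [1, 1, 0, 0]
print(encode_correctness(y_true, y_pred))  # [1, 0, 1, 0]
print(encode_outcome(y_true, y_pred))      # [1, 2, 3, 4]
```

Either list becomes the target the decision tree is trained on, with the model’s input features (or any other attributes) as the tree’s inputs.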
For regression, we need to convert the regression errors, i.e. the per-datapoint RMSE between ground truth and prediction, into classification outputs. For example, a datapoint is encoded as 1 if its RMSE is greater than the median RMSE plus 2 median absolute deviations (MAD) of RMSE, and 0 otherwise. We could also use a percentile rule, e.g. datapoints with RMSE above the 80th percentile are 1s and the rest are 0s. We do not use the mean and standard deviation because those values are skewed by high-RMSE outliers, and the entire point is to isolate datapoints whose RMSE is high compared to the typical value. This mimics the behavior we want in the classification case; after explaining the methodology for classification, we’ll dive into the mathematics of why we cannot use the regression outputs directly.
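A minimal sketch of both rules using only Python’s `statistics` module. The RMSE of a single datapoint reduces to its absolute error; the helper names are hypothetical, and the `k=2` and 80th-percentile defaults simply mirror the examples above.

```python
import statistics

def per_point_rmse(y_true, y_pred):
    # sqrt((y - y_hat)^2) for a single datapoint is just |y - y_hat|
    return [abs(t - p) for t, p in zip(y_true, y_pred)]

def encode_by_mad(errors, k=2.0):
    """1 if error > median + k * MAD (median absolute deviation), else 0."""
    med = statistics.median(errors)
    mad = statistics.median([abs(e - med) for e in errors])
    cutoff = med + k * mad
    return [int(e > cutoff) for e in errors]

def encode_by_percentile(errors, pct=80):
    """1 if error is above the pct-th percentile, else 0."""
    # quantiles(n=100) returns the 99 cut points for the 1st..99th percentiles
    cutoff = statistics.quantiles(errors, n=100)[pct - 1]
    return [int(e > cutoff) for e in errors]

errors = [1.0, 2.0, 2.0, 3.0, 2.0, 10.0]
print(encode_by_mad(errors))  # median 2.0, MAD 0.5, cutoff 3.0 -> [0, 0, 0, 0, 0, 1]
```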
(2) Create and store hotspot tree artifact
If we feed the 500 datapoints from our toy example into a decision tree using the four-class encoding discussed above, the tree looks like this:
Here, we only have two features: X₁ and X₂. At each node in a decision tree, the data is split into two child nodes based on a feature and a cutoff value. For example, at the root node, the 500 datapoints are split on a cutoff of -2.258 into a left child (feature value ≤ -2.258) and a right child (feature value > -2.258). We can accumulate the rules along any path from the root node to any child node.
We can also compute performance metrics like accuracy, precision, recall, and F1 on the data in each node.
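A minimal, self-contained sketch of the tree artifact: each node stores the rules accumulated from the root plus the performance metrics of its data. In practice you would more likely fit a library tree (e.g. scikit-learn’s `DecisionTreeClassifier`) and walk its node structure; the pure-Python builder, the `Node` class, the `max_depth`/`min_size` parameters, and the binary-label metrics here are assumptions for illustration.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    rules: list      # (feature index, operator, cutoff) filters accumulated from the root
    metrics: dict    # performance metrics computed on the data in this node
    size: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values()) if n else 0.0

def outcome(t, p):
    """Partitioning objective: 1 = TP, 2 = FP, 3 = TN, 4 = FN."""
    return (1 if t == 1 else 2) if p == 1 else (3 if t == 0 else 4)

def node_metrics(y_true, y_pred):
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

def build_tree(X, y_true, y_pred, rules=(), depth=0, max_depth=3, min_size=2):
    target = [outcome(t, p) for t, p in zip(y_true, y_pred)]
    node = Node(list(rules), node_metrics(y_true, y_pred), len(X))
    if depth == max_depth or len(X) < 2 * min_size:
        return node
    best = None
    for j in range(len(X[0])):
        for cut in sorted({x[j] for x in X})[:-1]:
            left = [i for i, x in enumerate(X) if x[j] <= cut]
            right = [i for i, x in enumerate(X) if x[j] > cut]
            if len(left) < min_size or len(right) < min_size:
                continue
            score = (len(left) * gini([target[i] for i in left])
                     + len(right) * gini([target[i] for i in right])) / len(X)
            if best is None or score < best[0]:
                best = (score, j, cut, left, right)
    # Stop when no split is allowed or no split reduces impurity
    if best is None or best[0] >= gini(target):
        return node
    _, j, cut, left, right = best

    def subset(idx):
        return [X[i] for i in idx], [y_true[i] for i in idx], [y_pred[i] for i in idx]

    node.left = build_tree(*subset(left), rules + ((j, "<=", cut),), depth + 1, max_depth, min_size)
    node.right = build_tree(*subset(right), rules + ((j, ">", cut),), depth + 1, max_depth, min_size)
    return node

X = [(0.0,), (1.0,), (2.0,), (3.0,)]
root = build_tree(X, y_true=[1, 1, 0, 0], y_pred=[0, 0, 0, 0])
print(root.left.rules, root.left.metrics["accuracy"])  # [(0, '<=', 1.0)] 0.0
```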
(3) Retrieve hotspots from the hotspot tree at query time
Now that we have our hotspot tree, let’s pick some hotspots! Notice in Fig. 2 that we have accuracy thresholds of .3 and .5. With the .5 threshold, the hotspot regions are wider and accidentally capture more correctly classified points. As the threshold decreases, we become less tolerant of wide regions that accidentally capture correct classifications. We might think that we always want lower thresholds to capture only misclassified datapoints, but that runs the risk of (a) making the regions incredibly small and hard to interpret and (b) isolating many regions that each contain few datapoints, requiring lots of manual work to investigate.
So how were those hotspots retrieved?
When an accuracy threshold is sent in a query to extract hotspots from the hotspot tree, we traverse all possible paths from the root node. If a node in the traversal violates the threshold, that node is defined as a hotspot, and all of its information is appended to the list of hotspots returned to the user.
Specifically, in our example, accuracy is our metric. If a node’s accuracy is less than the threshold, the datapoints in that node collectively violate the threshold, so that node is a hotspot. Whenever a node is identified as a hotspot, the traversal along that path stops: its descendants hold subsets of its data and are more pure, so any poorly performing descendant violates the user-provided metric and threshold even more extremely and is already contained in the hotspot.
Of course, what’s great about defining the metric and threshold at query time is that a user can requery with different metric and threshold combinations, depending on the question at hand and their tolerance level.
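The query-time traversal is a short recursive function. A minimal sketch, assuming each node stores its accumulated rules and a dictionary of metrics where lower values are worse; the `Node` class, the `find_hotspots` name, and the hand-built tree with its accuracy values are hypothetical:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    rules: list                  # filters accumulated from the root to this node
    metrics: dict                # e.g. {"accuracy": 0.45}
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def find_hotspots(node, metric="accuracy", threshold=0.5):
    """Return the highest node on each path whose metric falls below the threshold."""
    if node is None:
        return []
    if node.metrics[metric] < threshold:
        # Hotspot: report it and stop descending, since its children are subsets of it
        return [node]
    return (find_hotspots(node.left, metric, threshold)
            + find_hotspots(node.right, metric, threshold))

# Hand-built tree with made-up accuracies
tree = Node([], {"accuracy": 0.80},
            left=Node(["x1 <= 0"], {"accuracy": 0.45},
                      left=Node(["x1 <= 0", "x2 <= 1"], {"accuracy": 0.20}),
                      right=Node(["x1 <= 0", "x2 > 1"], {"accuracy": 0.90})),
            right=Node(["x1 > 0"], {"accuracy": 0.95},
                       left=Node(["x1 > 0", "x2 <= 1"], {"accuracy": 0.98}),
                       right=Node(["x1 > 0", "x2 > 1"], {"accuracy": 0.40})))

for threshold in (0.5, 0.3):
    print(threshold, [h.rules for h in find_hotspots(tree, "accuracy", threshold)])
```

Requerying with a different metric or threshold only reruns this traversal; the tree is built once. Raising the threshold from .3 to .5 here returns a wider region (the whole `x1 <= 0` branch instead of just one of its children), matching the behavior described above.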
Using our example with the .5 threshold, our three hotspots are
Each hotspot contains the filters that define its input data region, which can be applied when fetching the entire dataset for further examination, research, and model development.
As promised, here are answers to the open question about regression above, plus some deeper aspects!
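Turning a hotspot’s rules into a filter is straightforward. A minimal sketch, assuming rules are stored as `(feature, operator, value)` tuples and rows as dictionaries; the rule values and helper names are hypothetical. The same rules can also be rendered as a boolean expression string of the kind pandas’ `DataFrame.query` accepts:

```python
import operator

OPERATORS = {"<=": operator.le, ">": operator.gt}

def hotspot_predicate(rules):
    """Build a row filter that is True only if the row satisfies every rule."""
    def predicate(row):
        return all(OPERATORS[op](row[feature], value) for feature, op, value in rules)
    return predicate

def to_query_string(rules):
    """Render the rules as a boolean expression, e.g. for pandas DataFrame.query."""
    return " and ".join(f"{feature} {op} {value}" for feature, op, value in rules)

hotspot_rules = [("x1", "<=", -1.5), ("x2", "<=", 3.0)]  # hypothetical hotspot
rows = [{"x1": -3.0, "x2": 0.5}, {"x1": -1.0, "x2": 2.0}, {"x1": -2.5, "x2": 4.0}]

in_hotspot = hotspot_predicate(hotspot_rules)
print([row for row in rows if in_hotspot(row)])  # [{'x1': -3.0, 'x2': 0.5}]
print(to_query_string(hotspot_rules))            # x1 <= -1.5 and x2 <= 3.0
```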
Deep Dive: (1) Train a decision tree on a proper partitioning objective
So why not regression?
The naive setup would be to take the RMSE between ground truths and predictions in our regression task and use it as our partitioning objective. However, a regression tree splits the data in a node based on a distance metric, e.g. MSE, to the mean output value of the data in the node, which in this case is the mean of the RMSE values themselves. When we traverse the hotspot tree, we’d then flag nodes with MSE above a certain threshold as hotspots.
Let’s say a majority of RMSE values are around .05 and a few are around .10, and we want to flag and isolate the latter. It’s entirely possible that a regression tree groups many of the .10 RMSE datapoints lower in the tree, where we’d never find them because traversal stops once a hotspot is detected. For example, one path might have nodes with .06 MSE, .09 MSE, and .04 MSE. Given a user threshold between .06 and .09 MSE, the second node is flagged and the traversal stops there, so we would never reach the third node, even though it could contain a lot of .10 RMSE datapoints. (Its low .04 value comes from the fact that it contains many .10 RMSE datapoints and only a few .05 RMSE datapoints, so the node is “pure” in the sense that its RMSE values are close together.)
We could fix this with a custom regression criterion that measures distance against 0, i.e. uses the raw RMSE values, instead of the MSE distance to the average RMSE in a node, but that introduces an issue of standardization across models and data rollups. What if an RMSE value of .10 really isn’t bad for one model but is for another? As such, it’s more robust to convert the regression setting into a classification setting.
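A few made-up numbers show the effect. A regression tree’s MSE criterion is the mean squared distance to the node’s mean, i.e. the population variance, so a node dominated by high errors can look purer than a mixed node with lower errors on average. This sketch uses Python’s `statistics` module; the error values are hypothetical and on a different scale from the MSE figures above:

```python
import statistics

def node_impurity(errors):
    """Regression-tree MSE criterion: mean squared distance to the node's mean."""
    return statistics.pvariance(errors)

mixed = [0.05] * 5 + [0.10] * 5     # half typical errors, half high errors
mostly_high = [0.10] * 9 + [0.05]   # dominated by the errors we want to find
all_high = [0.10] * 10              # only high errors

# mostly_high has the higher mean error but the lower impurity; all_high has impurity 0
for name, errors in [("mixed", mixed), ("mostly_high", mostly_high), ("all_high", all_high)]:
    print(name, round(statistics.fmean(errors), 3), round(node_impurity(errors), 6))
```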
If we have a streaming (or batch) model, how much data do we put into creating a hotspot tree at one time? If we created a hotspot tree on last week’s (or a previous batch’s) data and now see incoming data, do we append those datapoints and retrain the previous tree or do we create a new hotspot tree for this week’s (or current batch’s) data?
There’s no single right way to implement this, but at Warrior AI, we take the latter approach. If we have an orange tree and now have some apples that arrive in a box, we’ll probably be interested in the apple tree that produced those apples to figure out why we received some rotten ones, not the orange tree we already inspected last week (or batch).
Notice how the inputs to the decision tree can be anything, including metadata that the original model never used as inputs! This means surfacing insights via hotspots is not constrained to model inputs, which is useful if we are tracking sensitive attributes that are not model inputs, like race or gender.
Deep Dive: (2) Create and store hotspot tree artifact
Why all the metrics?
Accuracy is not king, and users need different metrics depending on the task their ML model is solving. This is a good introduction to other performance metrics like precision and recall. Once you’ve given it a read, here are two practical examples that helped me when I first learned about precision and recall.
(1) Consider an anomaly detection system for a bot that scrapes data from financial reports. It would be terrible to have many false positives from an ML model, meaning the model predicts that wrong information (positives) exists in many documents that are actually correct (false). A remediation team would then waste time on documents that contain no errors, while a few false negatives slipping through the cracks wouldn’t be highly problematic. In this case, high precision is important.
(2) Consider a cancer detection model. It would be terrible to have many false negatives from an ML model, meaning the model predicts no cancer (negatives) for many people who do have cancer (false), leading to a lack of recommended treatment and further health complications for those individuals. In this case, high recall is important.
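For reference, the standard definitions in terms of true positives (TP), false positives (FP), and false negatives (FN): precision drops as false positives grow (the financial-report case), recall drops as false negatives grow (the cancer case), and F1 is their harmonic mean.

```latex
\text{precision} = \frac{TP}{TP + FP}, \qquad
\text{recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}
```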
Multiclassification and Micrometrics
How do we generalize to multiclassification tasks? Accuracy is still the same, and we can use precision, recall, and F1 score weighted by the number of ground truth instances in each class, shown under global metrics in the toy example below. Notice that instead of just macrometrics, we can now also define micrometrics for precision, recall, and F1 score with respect to each ground truth class (this per-class sense of “micro” differs from scikit-learn’s micro-averaging, which pools counts across all classes). This can be powerful, as hotspot trees can now be traversed on specific classes if, e.g., the bird class is causing a lot of model failures for an object detection model and we want to figure out what’s going on for bird images in particular.
That’s it for today! I hope you gained some insights into how to implement hotspot surfacing for your particular use case. We implement these kinds of systems at Warrior, and automation is an important product category in ML monitoring for the customer experience.
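A minimal sketch of both kinds of metric in plain Python: per-class precision, recall, and F1 (the micrometrics above), plus global versions weighted by each class’s ground truth count. This is similar to what scikit-learn’s `precision_recall_fscore_support` reports with `average=None` and `average="weighted"`; the function names and labels here are hypothetical:

```python
from collections import Counter

def per_class_metrics(y_true, y_pred):
    """Precision, recall, and F1 for each class, treating that class as 'positive'."""
    results = {}
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results[c] = {"precision": precision, "recall": recall, "f1": f1}
    return results

def weighted_metrics(y_true, y_pred):
    """Global metrics: per-class values averaged with weights = ground truth class counts."""
    support = Counter(y_true)
    per_class = per_class_metrics(y_true, y_pred)
    n = len(y_true)
    return {m: sum(support[c] / n * per_class[c][m] for c in support)
            for m in ("precision", "recall", "f1")}

y_true = ["bird", "bird", "cat", "cat"]
y_pred = ["cat", "bird", "cat", "cat"]
print(per_class_metrics(y_true, y_pred)["bird"])   # bird recall is 0.5
print(weighted_metrics(y_true, y_pred)["recall"])  # 0.75
```

Storing per-class entries like these in each node’s metrics is what would let a query target, for example, bird recall alone.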