Machine learning models can fail when they try to make predictions for people who were underrepresented in the data sets they were trained on.
For example, a model that predicts the best treatment option for someone with a chronic disease might be trained using a data set containing mostly male patients. That model could then make incorrect predictions for female patients when it is deployed in a hospital.
To improve results, engineers can try to balance the training data set by removing data points until all subgroups are represented equally. While data set balancing is promising, it often requires removing a large amount of data, which hurts the model's overall performance.
MIT researchers developed a new technique that identifies and removes specific points in a training data set that contribute most to a model's failures in minority subgroups. By removing many fewer data points than other approaches, this technique maintains the overall accuracy of the model while improving its performance with respect to underrepresented groups.
Additionally, the technique can identify hidden sources of bias in a training data set that lacks labels. Unlabeled data is much more common than labeled data for many applications.
This method could also be combined with other approaches to improve the fairness of machine learning models deployed in high-risk situations. For example, it could one day help ensure that underrepresented patients are not misdiagnosed due to a biased AI model.
“Many other algorithms that try to address this problem assume that each data point matters as much as every other data point. In this paper, we show that assumption is not true. There are specific points in our data set that contribute to this bias, and we can find those data points, remove them, and get better performance,” says Kimia Hamidieh, a graduate student in electrical engineering and computer science (EECS) at MIT and co-lead author of a paper on this technique.
Hamidieh wrote the paper with co-lead authors Saachi Jain PhD '24 and fellow EECS graduate student Kristian Georgiev; Andrew Ilyas MEng '18, PhD '23, a Stein Fellow at Stanford University; and senior authors Marzyeh Ghassemi, an associate professor in EECS and a member of the Institute for Medical Engineering and Science and the Laboratory for Information and Decision Systems, and Aleksander Madry, the Cadence Design Systems Professor at MIT. The research will be presented at the Conference on Neural Information Processing Systems.
Eliminating bad examples
Machine learning models are often trained using huge data sets collected from many sources on the Internet. These data sets are too large to be carefully hand-picked, so they may contain bad examples that hurt model performance.
Scientists also know that some data points affect a model's performance on certain downstream tasks more than others.
The MIT researchers combined these two ideas into an approach that identifies and eliminates these problematic data points. They seek to solve a problem known as worst-group error, which occurs when a model underperforms on minority subgroups in a training data set.
The researchers' new technique is driven by previous work in which they introduced a method, called TRAK, that identifies the most important training examples for a specific model output.
For this new technique, they take incorrect predictions the model made about minority subgroups and use TRAK to identify which training examples contributed the most to that incorrect prediction.
“By aggregating this information from incorrect test predictions in the right way, we can find the specific parts of the training that are driving down worst-group accuracy overall,” Ilyas explains.
They then remove those specific samples and retrain the model with the remaining data.
Since having more data generally produces better overall performance, removing only the samples that drive worst-group failures preserves the model's overall accuracy while improving its performance on minority subgroups.
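As an illustration only, the sketch below walks through that attribute-remove-retrain recipe on synthetic data. It is not the researchers' code: the attribution matrix is a random placeholder standing in for the per-example influence scores a method such as TRAK would compute, and the model, toy data, and number of removed points (k) are all assumptions made for the example.

```python
# Minimal sketch of the "attribute, remove, retrain" recipe on a toy setup.
# The attribution scores here are random placeholders, so no real improvement
# is expected; the point is only the shape of the pipeline.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Toy data: features, labels, and (for evaluation) a subgroup tag on the
# validation set; group 1 plays the role of the minority subgroup.
n_train, n_val, d = 2000, 400, 10
X_train = rng.normal(size=(n_train, d))
y_train = (X_train[:, 0] + 0.1 * rng.normal(size=n_train) > 0).astype(int)
X_val = rng.normal(size=(n_val, d))
y_val = (X_val[:, 0] > 0).astype(int)
val_group = rng.integers(0, 2, size=n_val)

model = LogisticRegression().fit(X_train, y_train)
val_pred = model.predict(X_val)

# Step 1: collect the validation examples the model gets wrong on the
# minority subgroup (the "worst group" in this toy setup).
wrong_minority = np.where((val_pred != y_val) & (val_group == 1))[0]

# Step 2: scores[i, j] estimates how much training example j contributed to
# the prediction on validation example i (random placeholder here; a real
# pipeline would compute this with a data-attribution method such as TRAK).
scores = rng.normal(size=(n_val, n_train))

# Step 3: aggregate attributions over the erroneous minority predictions and
# flag the k training examples that contributed most to those errors.
k = 50
total_blame = scores[wrong_minority].sum(axis=0)
drop = np.argsort(total_blame)[-k:]

# Step 4: retrain on the remaining data -- far fewer points are removed than
# full subgroup balancing would require.
keep = np.setdiff1d(np.arange(n_train), drop)
model_fixed = LogisticRegression().fit(X_train[keep], y_train[keep])

def worst_group_accuracy(m):
    preds = m.predict(X_val)
    return min((preds[val_group == g] == y_val[val_group == g]).mean()
               for g in (0, 1))

print("worst-group accuracy before:", worst_group_accuracy(model))
print("worst-group accuracy after: ", worst_group_accuracy(model_fixed))
```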
A more accessible approach
On three machine learning data sets, their method outperformed multiple techniques. In one case, it improved worst-group accuracy while removing about 20,000 fewer training samples than a conventional data-balancing method. Their technique also achieved higher accuracy than methods that require making changes to a model's internal workings.
Because the MIT method involves changing the data set rather than the model itself, it would be easier for a practitioner to use and can be applied to many types of models.
It can also be used when bias is unknown because the subgroups in a training data set are not labeled. By identifying the data points that contribute most to a feature the model is learning, researchers can understand the variables it uses to make a prediction.
“This is a tool that anyone can use when training a machine learning model. They can look at those data points and see if they are aligned with the capability they are trying to teach the model,” says Hamidieh.
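As a similarly hedged sketch of that kind of audit: with no subgroup labels, one could rank training examples by an aggregate attribution score over all misclassified validation predictions and surface the most-implicated points for manual review. Every quantity below is a placeholder, not the paper's method.

```python
# Sketch of an unlabeled-bias audit, assuming placeholder inputs: a mask of
# misclassified validation examples and an attribution matrix scores[i, j]
# (e.g., from a method such as TRAK) estimating how much training example j
# contributed to validation prediction i.
import numpy as np

rng = np.random.default_rng(1)
n_val, n_train = 400, 2000
wrong = rng.random(n_val) < 0.1                 # placeholder error mask
scores = rng.normal(size=(n_val, n_train))      # placeholder attributions

blame = scores[wrong].sum(axis=0)               # aggregate blame per training point
suspects = np.argsort(blame)[::-1][:20]         # 20 most-implicated examples
print("Training examples to inspect by hand:", suspects.tolist())
```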
Using the technique to detect bias from unknown subgroups would require intuition about which groups to look for, so the researchers hope to validate it and explore it more fully through future human studies.
They also want to improve the performance and reliability of their technique and ensure that the method is accessible and easy to use for professionals who might one day implement it in real-world settings.
“When you have tools that allow you to look critically at the data and figure out which data points will lead to biases or other undesirable behaviors, it gives you a first step toward building models that will be fairer and more reliable,” says Ilyas.
This work is funded, in part, by the National Science Foundation and the U.S. Defense Advanced Research Projects Agency.