A powerful but little-known method for explainable AI through data summaries
Despite being a powerful tool for data summarization, the MMD-Critic method has a surprising lack of both usage and “coverage”. Perhaps this is because there are simpler and more established methods for data summarization (e.g., K-medoids; see (1) or, more simply, the Wikipedia page), or perhaps it's because there was no Python package for the method (until now). In any case, the results presented in the original paper (2) justify more use than MMD-Critic currently gets, so I'll explain the method here as clearly as possible. I've also published an open-source Python package with an implementation of the technique so you can use it easily.
Before we dive into the MMD-Critic method itself, it's worth looking at what exactly we're trying to accomplish. Ultimately, we want to take a dataset and find examples that are representative of the data (prototypes), as well as edge-case examples that might confuse our machine learning models (criticisms).
There are many reasons why this can be useful:
- We can get a nice summary view of our dataset by looking at both its stereotypical and atypical examples.
- We can test models on the criticisms to see how they handle edge cases (which is, for obvious reasons, very important).
- While perhaps not as useful, we can use the prototypes to build a naturally explainable K-means-like classifier, where the prototype closest to a new data point is used to label it. This makes explanations simple, since we just show the user the most similar labeled data point; see the sketch after this list.
- And more.
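To illustrate the third point, here is a minimal sketch of what such a nearest-prototype classifier could look like (this is my own illustrative code with hypothetical names, not part of any package):

import numpy as np

def nearest_prototype_label(x, prototypes, proto_labels):
    # Hypothetical helper: label a new point with the label of its
    # nearest prototype; the prototype itself serves as the explanation.
    dists = np.linalg.norm(np.asarray(prototypes) - np.asarray(x), axis=1)
    nearest = int(np.argmin(dists))
    return proto_labels[nearest], prototypes[nearest]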
You can see section 6.3 of this book for more information on the applications of this (and for a decent explanation of MMD-Critic), but suffice it to say that finding these examples is useful for a wide variety of reasons. MMD-Critic lets us do exactly that.
Unfortunately, I can't claim to have a hyper-rigorous understanding of Maximum Mean Discrepancy (MMD), as such an understanding would require a strong background in functional analysis. If you do have such a background, you can find the paper that introduced the measure here.
In simple terms, MMD is a way of measuring the distance between two probability distributions. Formally, for two probability distributions P and Q, we define the MMD of the two as
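$$\mathrm{MMD}(\mathcal{F}, P, Q) = \sup_{f \in \mathcal{F}} \left( \mathbb{E}_{x \sim P}\left[ f(x) \right] - \mathbb{E}_{y \sim Q}\left[ f(y) \right] \right)$$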
Here, F is any function space, that is, any set of functions with the same domain and codomain. Note also that the notation x ~ P means that we are treating x as a random variable drawn from the distribution P, i.e., that x is described by P. This formula finds the largest difference in the expected values of x and y when they are transformed by some function f in our space F.
It may be a little difficult to parse, so here is an example. Suppose that P is Uniform(0, 1) (i.e., the distribution given by choosing a random number between 0 and 1) and Q is Uniform(-1, 1). Let us also let F be a fairly simple family containing three functions: f(x) = 0, f(x) = x, and f(x) = x². Iterating over each function in our space, we get:
- In the f(x) = 0 case, E[f(x)] when x ~ P is 0, since no matter which x we draw, f(x) will be 0. The same holds when x ~ Q, so we get a discrepancy of 0 for this function.
- In the f(x) = x case, we have E[f(x)] = 0.5 for the P case and 0 for the Q case, so our discrepancy is 0.5.
- In the f(x) = x² case, we observe that E[f(x)] = E[x²]. So in the P case, we get $\int_0^1 x^2 \, dx = \frac{1}{3}$, and in the Q case, we get $\int_{-1}^{1} \frac{x^2}{2} \, dx = \frac{1}{3}$, so the discrepancy in this case is 0.

The supremum over our function space is then 0.5, so that's our MMD.
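We can sanity-check this toy example numerically by replacing the expectations with sample means (a quick sketch of my own, not anything from the paper):

import numpy as np

rng = np.random.default_rng(0)
xs = rng.uniform(0, 1, 100_000)   # samples from P = Uniform(0, 1)
ys = rng.uniform(-1, 1, 100_000)  # samples from Q = Uniform(-1, 1)

# Our tiny function space F = {f(x) = 0, f(x) = x, f(x) = x^2}
fs = [lambda t: np.zeros_like(t), lambda t: t, lambda t: t ** 2]

# Empirical MMD: the largest gap between the mean of f under P and under Q
mmd = max(f(xs).mean() - f(ys).mean() for f in fs)
print(mmd)  # ~0.5, achieved by f(x) = x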
Now, you may notice some problems with our MMD. It seems to depend heavily on our choice of function space, and it seems very expensive (or even impossible) to compute for a large or infinite function space. Not only that, but it also requires us to actually know our distributions P and Q, which is not realistic.
The last problem is easily solved, since we can rewrite our MMD metric to use empirical estimates of P and Q based on our dataset:
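$$\mathrm{MMD}(\mathcal{F}, X, Y) = \sup_{f \in \mathcal{F}} \left( \frac{1}{|X|} \sum_{x \in X} f(x) - \frac{1}{|Y|} \sum_{y \in Y} f(y) \right)$$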
Here, the x's are our samples from the dataset drawn from P, and the y's are the samples drawn from Q.
The first two problems can be solved with a little additional math. Without going into too much detail, it turns out that if F is something called a Reproducing Kernel Hilbert Space (RKHS), we know in advance which function achieves our MMD's supremum. Namely, it is the following function, called the witness function:
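$$\mathrm{witness}(x) = \mathbb{E}_{x' \sim P}\left[ k(x, x') \right] - \mathbb{E}_{y \sim Q}\left[ k(x, y) \right]$$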
where k is the kernel (inner product) associated with the RKHS¹. Intuitively, this function “witnesses” the discrepancy between P and Q at the point x.
Therefore, we only need to choose a sufficiently expressive RKHS/kernel. Usually, the RBF kernel is used, which has the kernel function
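$$k(x, x') = \exp\left( -\frac{\lVert x - x' \rVert^2}{2\sigma^2} \right)$$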
Overall, this produces fairly intuitive results. Here, for example, is the plot of the witness function with the RBF kernel, estimated (in the same way as mentioned above, i.e., by replacing expectations with sums) on two datasets drawn from Uniform(-0.5, 0.5) and Uniform(-1, 1):
The code to generate the above graph is here:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def rbf(v1, v2, sigma=0.5):
    # RBF kernel between two scalar values
    return np.exp(-((v2 - v1) ** 2) / (2 * sigma ** 2))

def comp_wit_fn(x, d1, d2):
    # Empirical witness function: mean similarity to d1 minus mean similarity to d2
    return 1 / len(d1) * sum(rbf(x, dp) for dp in d1) - 1 / len(d2) * sum(rbf(x, dp) for dp in d2)

low1, high1 = -0.5, 0.5  # Range for the first uniform distribution
low2, high2 = -1, 1      # Range for the second uniform distribution

# Generate data for the uniform distributions
data1 = np.random.uniform(low1, high1, 10000)
data2 = np.random.uniform(low2, high2, 10000)

# Generate a range of x values at which to evaluate the witness function
x_values = np.linspace(min(low1 * 2, low2 * 2), max(high1 * 2, high2 * 2), 100)
comp_wit_values = [comp_wit_fn(x, data1, data2) for x in x_values]

sns.kdeplot(data1, label=f'Uniform({low1}, {high1})', color='blue', fill=True)
sns.kdeplot(data2, label=f'Uniform({low2}, {high2})', color='red', fill=True)
plt.plot(x_values, comp_wit_values, label='Witness Function', color='green')
plt.xlabel('Value')
plt.ylabel('Density / Witness Fn')
plt.legend()
plt.show()
The idea behind MMD-Critic is now quite simple: if we want to find k prototypes, we need to find the set of prototypes that best matches the distribution of the original dataset, as measured by its squared MMD. In other words, we wish to find a subset P of cardinality k from our dataset that minimizes MMD²(F, X, P). Without going into too much detail about why, the squared MMD is given by
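$$\mathrm{MMD}^2(\mathcal{F}, X, P) = \frac{1}{|X|^2} \sum_{x, x' \in X} k(x, x') - \frac{2}{|X||P|} \sum_{x \in X} \sum_{p \in P} k(x, p) + \frac{1}{|P|^2} \sum_{p, p' \in P} k(p, p')$$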
After finding these prototypes, we select as criticisms the points where the distribution hypothesized by our prototypes differs most from the distribution of our dataset. As we saw before, the difference between two distributions at a point can be measured by our witness function, so we simply look for the points that maximize its absolute value in the context of X and P. In other words, we define our criticism “score” as
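$$\mathrm{score}(c) = \left| \mathrm{witness}(c) \right| = \left| \mathbb{E}_{x \sim X}\left[ k(c, x) \right] - \mathbb{E}_{p \sim P}\left[ k(c, p) \right] \right|$$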
Or, in its more usable, approximate form,
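$$\mathrm{score}(c) = \left| \frac{1}{|X|} \sum_{x \in X} k(c, x) - \frac{1}{|P|} \sum_{p \in P} k(c, p) \right|$$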
Then, to find the desired number of criticisms, say m of them, we simply want to find the set C of size m that maximizes
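$$\sum_{c \in C} \mathrm{score}(c)$$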
To promote the selection of more varied criticisms, the paper also suggests adding a regularizer term that encourages the selected criticisms to be as spread out as possible. The regularizer suggested in the paper is the log-determinant regularizer, although it is not mandatory. I won't go into much detail here because it is not essential, but the paper suggests reading (6)².
With all this, we can implement an extremely naive version of MMD-Critic, without criticism regularization, as follows (do NOT use this):
import math
import itertools

def euc_distance(p1, p2):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))

def rbf(v1, v2, sigma=0.5):
    # RBF kernel between two points
    return math.exp(-euc_distance(v1, v2) ** 2 / (2 * sigma ** 2))

def mmd_sq(X, Y, sigma=0.5):
    # Empirical (biased) estimate of MMD^2 between the samples X and Y
    sm_xx = 0
    for x in X:
        for x2 in X:
            sm_xx += rbf(x, x2, sigma)
    sm_xy = 0
    for x in X:
        for y in Y:
            sm_xy += rbf(x, y, sigma)
    sm_yy = 0
    for y in Y:
        for y2 in Y:
            sm_yy += rbf(y, y2, sigma)
    return 1 / (len(X) ** 2) * sm_xx \
        - 2 / (len(X) * len(Y)) * sm_xy \
        + 1 / (len(Y) ** 2) * sm_yy

def select_protos(X, n, sigma=0.5):
    # Exhaustive search for the size-n subset that minimizes MMD^2
    min_score, min_sub = math.inf, None
    for subset in itertools.combinations(X, n):
        new_mmd = mmd_sq(X, subset, sigma)
        if new_mmd < min_score:
            min_score = new_mmd
            min_sub = subset
    return min_sub

def criticism_score(criticism, prototypes, X, sigma=0.5):
    # Absolute value of the empirical witness function at `criticism`
    return abs(1 / len(X) * sum(rbf(criticism, x, sigma) for x in X)
               - 1 / len(prototypes) * sum(rbf(criticism, p, sigma) for p in prototypes))

def select_criticisms(X, P, n, sigma=0.5):
    # Exhaustive search for the size-n subset that maximizes the total criticism score
    candidates = [c for c in X if c not in P]
    max_score, crits = -math.inf, ()
    for subset in itertools.combinations(candidates, n):
        new_score = sum(criticism_score(c, P, X, sigma) for c in subset)
        if new_score > max_score:
            max_score = new_score
            crits = subset
    return crits
The above implementation is so impractical that when I ran it, I was unable to find 5 prototypes in a dataset of 25 points in a reasonable amount of time. This is because our MMD calculation is O(max(|X|, |Y|)²), and iterating over every subset of length n is O(C(|X|, n)) (where C is the choose function), which gives us a horrendous runtime complexity.
Aside from using more efficient computational methods (e.g., pure NumPy/numexpr/matrix calculations instead of Python loops) and caching repeated calculations, there are some optimizations we can make at a theoretical level. First, the most obvious slowdown is the loop over all C(|X|, n) subsets in our prototype and criticism methods. Instead, we can use an approximation that loops n times, greedily selecting the best prototype on each iteration. This lets us change our prototype selection code to
def select_protos(X, n, sigma=0.5):
    # Greedily add the prototype that most reduces MMD^2, n times
    protos = []
    for _ in range(n):
        min_score, min_proto = math.inf, None
        for cand in X:
            if cand in protos:
                continue
            new_score = mmd_sq(X, protos + [cand], sigma)
            if new_score < min_score:
                min_score = new_score
                min_proto = cand
        protos.append(min_proto)
    return protos
and similarly for criticisms.
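For completeness, here is what the analogous greedy criticism selection might look like under the same naive setup (my own sketch mirroring the functions above; note that without a diversity regularizer this simply picks the top-scoring points):

def select_criticisms(X, P, n, sigma=0.5):
    # Greedily add the highest-scoring remaining point, n times
    crits = []
    for _ in range(n):
        max_score, max_crit = -math.inf, None
        for cand in X:
            if cand in P or cand in crits:
                continue
            new_score = criticism_score(cand, P, X, sigma)
            if new_score > max_score:
                max_score = new_score
                max_crit = cand
        crits.append(max_crit)
    return crits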
There is another important lemma that makes this problem much more optimizable: it turns out that by framing our prototype selection as a minimization problem and adding a regularization term to the cost, the cost function can be computed very efficiently with matrix operations. I won't go into much detail here, but you can check out the original paper for the specifics.
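I won't reproduce the paper's regularized formulation here, but to give a flavor of what “matrix operations instead of loops” buys us, here is a vectorized NumPy version of the mmd_sq function above (my own sketch, equivalent to the loop version, not the paper's method):

import numpy as np

def rbf_matrix(A, B, sigma=0.5):
    # Pairwise RBF kernel matrix between the rows of A (n, d) and B (m, d)
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def mmd_sq_vectorized(X, Y, sigma=0.5):
    # Same quantity as mmd_sq above, computed with kernel matrices
    return (rbf_matrix(X, X, sigma).mean()
            - 2 * rbf_matrix(X, Y, sigma).mean()
            + rbf_matrix(Y, Y, sigma).mean())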
Now that we understand the MMD-Critic method, we can finally play with it! You can install it by running
pip install mmd-critic
The implementation in the package itself is much faster than the one presented here, so don't worry.
We can run a fairly simple example using blobs like this:
from sklearn.datasets import make_blobs
from mmd_critic import MMDCritic
from mmd_critic.kernels import RBFKernel

n_samples = 50  # Total number of samples
centers = 4 # Number of clusters
cluster_std = 1 # Standard deviation of the clusters
x, _ = make_blobs(n_samples=n_samples, centers=centers, cluster_std=cluster_std, n_features=2, random_state=42)
x = x.tolist()
# MMD critic with the kernel used for the prototypes being an RBF with sigma=1,
# for the criticisms one with sigma=0.025
critic = MMDCritic(x, RBFKernel(1), RBFKernel(0.025))
protos, _ = critic.select_prototypes(centers)
criticisms, _ = critic.select_criticisms(10, protos)
Then, plotting the points, prototypes, and criticisms leads us to…
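The figure itself is just a scatter plot of the three groups of points. A rough sketch of the plotting code, assuming the protos and criticisms returned above are lists of 2-D points, is:

import matplotlib.pyplot as plt

plt.scatter([p[0] for p in x], [p[1] for p in x], color='gray', label='Data')
plt.scatter([p[0] for p in protos], [p[1] for p in protos],
            color='blue', marker='*', s=150, label='Prototypes')
plt.scatter([c[0] for c in criticisms], [c[1] for c in criticisms],
            color='red', marker='x', s=100, label='Criticisms')
plt.legend()
plt.show()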
You'll notice that I provided the option to use a separate kernel for prototype and criticism selection. This is because I found that the criticism results in particular can be extremely sensitive to the sigma hyperparameter. This is an unfortunate limitation of the MMD-Critic method and of kernel methods in general. I've generally had good results using a large sigma for prototypes and a smaller one for criticisms.
Of course, we can also use a more complex dataset. Here, for example, is the method applied to MNIST³:
from sklearn.datasets import fetch_openml
import numpy as np
from mmd_critic import MMDCritic
from mmd_critic.kernels import RBFKernel

# Load MNIST data
mnist = fetch_openml('mnist_784', version=1)
images = mnist['data'].astype(np.float32).to_numpy() / 255.0
labels = mnist['target'].astype(np.int64)

critic = MMDCritic(images[:15000], RBFKernel(2.5), RBFKernel(0.025))
protos, _ = critic.select_prototypes(40)
criticisms, _ = critic.select_criticisms(40, protos)
This gives us the following prototypes
and the following criticisms
Pretty neat, huh?
And that's about all there is to the MMD-Critic method! It's pretty simple at its core and nice to use, apart from having to fiddle with the sigma hyperparameter. I hope the newly released Python package gives it more use.
For any inquiries, please contact [email protected]. All images are the property of the author unless otherwise noted.