Causal AI explores the integration of causal reasoning into machine learning.
This article offers a practical introduction to the potential of causal graphs.
It is aimed at anyone who wants to understand more about:
- What are causal graphs and how do they work?
- A case study worked in Python illustrating how to build causal graphs
- How causal graphs compare to traditional ML
- Key challenges and future considerations
The complete notebook can be found here:
Causal graphs help us separate causes from correlations. They are a key part of the causal inference/causal ML/causal AI toolbox and can be used to answer causal questions.
Often called a DAG (directed acyclic graph), a causal graph contains nodes and edges: edges link nodes that are causally related, with each edge pointing from cause to effect.
There are two ways to determine a causal graph:
- Expert domain knowledge
- Causal discovery algorithms
For now, we will assume that we have domain expert knowledge to determine the causal graph (we will cover causal discovery algorithms later).
The goal of ML is to classify or predict as accurately as possible given some training data. There is no incentive for an ML algorithm to ensure that the features it uses are causally linked to the target. There is no guarantee that the direction (positive/negative effect) and strength of each feature aligns with the true data generation process. ML will not take into account the following situations:
- Spurious correlations: Two variables have a spurious correlation when they share a common cause; for example, high temperatures increase both the number of ice cream sales and the number of shark attacks.
- Confounding factors: A variable that affects both the treatment and the outcome; for example, demand affects how much we spend on marketing and how many new customers sign up.
- Colliders: A variable that is affected by two independent variables, for example, Customer Service Quality -> User Satisfaction <- Company Size.
- Mediators: Two variables linked (indirectly) through a mediator; for example, regular exercise -> cardiovascular fitness (the mediator) -> general health.
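To make the spurious-correlation point concrete, here is a minimal simulation (the variable names and numbers are hypothetical, not taken from the case study): temperature drives both ice cream sales and shark attacks, and a strong correlation appears between two variables that never influence each other.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 5_000

# Common cause: temperature drives both variables; neither causes the other.
temperature = rng.normal(25, 5, n)
ice_cream_sales = 2.0 * temperature + rng.normal(0, 5, n)
shark_attacks = 0.5 * temperature + rng.normal(0, 2, n)

# A strong correlation appears despite no causal link between the two effects.
corr = np.corrcoef(ice_cream_sales, shark_attacks)[0, 1]
print(round(corr, 2))
```

A purely predictive model shown only sales and attacks would happily "learn" this relationship, which is exactly the trap a causal graph helps us avoid.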
Because of these complexities and the black-box nature of ML, we cannot trust its ability to answer causal questions.
Given a known causal graph and observed data, we can train a structural causal model (SCM). An SCM can be thought of as a series of causal models, one per node. Each model uses a node as a target and its direct parents as features. If the relationships in our observed data are linear, an SCM will be a series of linear equations. This could be modeled using a series of linear regression models. If the relationships in our observed data are nonlinear, this could be modeled with a series of boosted trees.
The key difference from traditional ML is that an SCM models causal relationships and takes into account spurious correlations, confounders, colliders, and mediators.
It is common to use an additive noise model (ANM) for each non-root node (meaning it has at least one parent). This allows us to use a variety of machine learning algorithms (plus a noise term) to estimate each non-root node.
Y := f(X) + n
Root nodes can be modeled using a stochastic model to describe the distribution.
An SCM can be seen as a generative model that can generate new samples of data, allowing it to answer a variety of causal questions. It generates new data by sampling the root nodes and then propagating the data through the graph.
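As a sketch of the idea, the following toy example (a hypothetical linear graph X -> Y -> Z, not the case-study graph) fits one regression per non-root node using only its direct parents, then uses the fitted models generatively: sample the root from its empirical distribution and propagate through the graph, adding noise at each non-root node.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 2_000

# Observed data from a (hypothetical) linear SCM over the graph X -> Y -> Z.
X = rng.normal(0, 1, n)
Y = 2.0 * X + rng.normal(0, 0.1, n)   # Y := f(X) + noise
Z = -1.5 * Y + rng.normal(0, 0.1, n)  # Z := g(Y) + noise

# One model per non-root node, each using only its direct parents as features.
model_y = LinearRegression().fit(X.reshape(-1, 1), Y)
model_z = LinearRegression().fit(Y.reshape(-1, 1), Z)

# Generative use: sample the root node, then propagate through the graph.
X_new = rng.choice(X, size=n)  # root node: empirical distribution
Y_new = model_y.predict(X_new.reshape(-1, 1)) + rng.normal(0, 0.1, n)
Z_new = model_z.predict(Y_new.reshape(-1, 1)) + rng.normal(0, 0.1, n)
```

The fitted coefficients recover the true mechanisms (approximately 2.0 and -1.5), and the propagated samples follow the same joint distribution as the observed data.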
The value of an SCM is that it allows us to answer causal questions by calculating counterfactuals and simulating interventions:
- Counterfactuals: Using historically observed data to calculate what would have happened to y if we had changed x. For example, what would have happened to the number of churns if we had reduced call wait times by 20% last month?
- Interventions: Very similar to counterfactuals (the two terms are often used interchangeably), but interventions simulate what would happen in the future. For example, what will happen to the number of customers who churn if we reduce call wait time by 20% next year?
There are several KPIs that the customer service team monitors. One of them is call waiting times. Increasing the number of call center staff will reduce call wait times.
But how will decreasing call wait times affect customer churn levels? And will this offset the cost of additional call center staff?
The data science team is asked to create and evaluate the business case.
The population of interest is customers who make an incoming call. The following time series data is collected daily:
In this example, we use time series data, but causal graphs can also work with customer-level data.
In this example, we use expert domain knowledge to determine the causal graph.
import numpy as np
import pandas as pd
import networkx as nx
import seaborn as sns
from dowhy import gcm
from sklearn.linear_model import RidgeCV

# Create node lookup for channels
node_lookup = {0: 'Demand',
               1: 'Call waiting time',
               2: 'Call abandoned',
               3: 'Reported problems',
               4: 'Discount sent',
               5: 'Churn'
               }

total_nodes = len(node_lookup)
# Create adjacency matrix - this is the base for our graph
graph_actual = np.zeros((total_nodes, total_nodes))
# Create graph using expert domain knowledge
graph_actual[0, 1] = 1.0  # Demand -> Call waiting time
graph_actual[0, 2] = 1.0  # Demand -> Call abandoned
graph_actual[0, 3] = 1.0  # Demand -> Reported problems
graph_actual[1, 2] = 1.0  # Call waiting time -> Call abandoned
graph_actual[1, 5] = 1.0  # Call waiting time -> Churn
graph_actual[2, 3] = 1.0  # Call abandoned -> Reported problems
graph_actual[2, 5] = 1.0  # Call abandoned -> Churn
graph_actual[3, 4] = 1.0  # Reported problems -> Discount sent
graph_actual[3, 5] = 1.0  # Reported problems -> Churn
graph_actual[4, 5] = 1.0  # Discount sent -> Churn
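Before fitting anything to this graph, a quick sanity check (my addition, not part of the original walkthrough) confirms that the adjacency matrix really encodes a DAG, which an SCM requires:

```python
import numpy as np
import networkx as nx

# Rebuild the case-study adjacency matrix and confirm it encodes a DAG.
node_lookup = {0: 'Demand', 1: 'Call waiting time', 2: 'Call abandoned',
               3: 'Reported problems', 4: 'Discount sent', 5: 'Churn'}
graph_actual = np.zeros((len(node_lookup), len(node_lookup)))
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 5),
         (2, 3), (2, 5), (3, 4), (3, 5), (4, 5)]
for cause, effect in edges:
    graph_actual[cause, effect] = 1.0

graph = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
graph = nx.relabel_nodes(graph, node_lookup)

# An SCM requires an acyclic graph; 'Demand' should be the only root node.
assert nx.is_directed_acyclic_graph(graph)
roots = [n for n, deg in graph.in_degree() if deg == 0]
```

If a cycle had slipped in (say, an accidental Churn -> Demand edge), the assertion would fail before any modeling time was wasted.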
Next, we need to generate data for our case study.
We want to generate some data that will allow us to compare calculating counterfactuals using causal graphs versus ML (to keep things simple, ridge regression).
Since we identified the causal graph in the last section, we can use this knowledge to create a data generation process.
def data_generator(max_call_waiting, inbound_calls, call_reduction):
    '''
    A data-generating function with the flexibility to reduce the value of node 1 (Call waiting time) - this enables us to calculate ground-truth counterfactuals

    Args:
        max_call_waiting (int): Maximum call waiting time in seconds
        inbound_calls (int): Total number of inbound calls (observations in data)
        call_reduction (float): Reduction to apply to call waiting time

    Returns:
        DataFrame: Generated data
    '''
    df = pd.DataFrame(columns=node_lookup.values())
    df[node_lookup[0]] = np.random.randint(low=10, high=max_call_waiting, size=(inbound_calls))  # Demand
    df[node_lookup[1]] = (df[node_lookup[0]] * 0.5) * (call_reduction) + np.random.normal(loc=0, scale=40, size=inbound_calls)  # Call waiting time
    df[node_lookup[2]] = (df[node_lookup[1]] * 0.5) + (df[node_lookup[0]] * 0.2) + np.random.normal(loc=0, scale=30, size=inbound_calls)  # Call abandoned
    df[node_lookup[3]] = (df[node_lookup[2]] * 0.6) + (df[node_lookup[0]] * 0.3) + np.random.normal(loc=0, scale=20, size=inbound_calls)  # Reported problems
    df[node_lookup[4]] = (df[node_lookup[3]] * 0.7) + np.random.normal(loc=0, scale=10, size=inbound_calls)  # Discount sent
    df[node_lookup[5]] = (0.10 * df[node_lookup[1]]) + (0.30 * df[node_lookup[2]]) + (0.15 * df[node_lookup[3]]) + (-0.20 * df[node_lookup[4]])  # Churn

    return df
# Generate data
np.random.seed(999)
df = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=1.00)

sns.pairplot(df)
We now have an adjacency matrix representing our causal graph and some data. We use the gcm module of the Dowhy Python package to train an SCM.
It is important to think about what causal mechanism to use for root and non-root nodes. If you look at our data generation function, you will see that all the relationships are linear. Therefore, choosing ridge regression should be enough.
# Setup graph
graph = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
graph = nx.relabel_nodes(graph, node_lookup)

# Create SCM
causal_model = gcm.InvertibleStructuralCausalModel(graph)
causal_model.set_causal_mechanism('Demand', gcm.EmpiricalDistribution()) # Root node
causal_model.set_causal_mechanism('Call waiting time', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Call abandoned', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Reported problems', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Discount sent', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Churn', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
gcm.fit(causal_model, df)
You can also use the auto-assign feature to automatically assign causal mechanisms instead of manually assigning them.
For more information about the gcm package, see the docs:
We also fit a ridge regression to create a baseline comparison. Looking back at the data generator, we can see that ridge regression correctly estimates the coefficient of each feature. However, in addition to directly influencing churn, call waiting time also influences churn indirectly through calls abandoned, reported problems, and discounts sent.
When it comes to estimating counterfactuals, it will be interesting to see how the SCM compares to ridge regression.
# Ridge regression
y = df['Churn'].copy()
X = df.iloc[:, 1:-1].copy()

model = RidgeCV()
model = model.fit(X, y)
y_pred = model.predict(X)

print(f'Intercept: {model.intercept_}')
print(f'Coefficient: {model.coef_}')
# Ground truth: [0.10, 0.30, 0.15, -0.20]
Before we move on to calculating counterfactuals using causal graphs and ridge regression, we need a ground truth benchmark. We can use our data generator to create counterfactual samples after we have reduced call wait time by 20%.
We wouldn't be able to do this with real-world problems, but this method allows us to evaluate how effective the causal graph and ridge regression are.
# Set call reduction to 20%
reduce = 0.20
call_reduction = 1 - reduce

# Generate counterfactual data
np.random.seed(999)
df_cf = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=call_reduction)
Now we can estimate what would have happened if we had decreased the call wait time by 20% using our 3 methods:
- Ground truth (from data generator)
- Ridge regression
- Causal graph
We see that ridge regression significantly underestimates the impact on churn, while the causal graph estimate is very close to the ground truth.
# Ground truth counterfactual
ground_truth = round((df['Churn'].sum() - df_cf['Churn'].sum()) / df['Churn'].sum(), 2)

# Causal graph counterfactual
df_counterfactual = gcm.counterfactual_samples(causal_model, {'Call waiting time': lambda x: x * call_reduction}, observed_data=df)
causal_graph = round((df['Churn'].sum() - df_counterfactual['Churn'].sum()) / df['Churn'].sum(), 3)

# Ridge regression counterfactual
ridge_regression = round((df['Call waiting time'].sum() * 1.0 * model.coef_[0] - (df['Call waiting time'].sum() * call_reduction * model.coef_[0])) / df['Churn'].sum(), 3)
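One way to see why ridge regression falls short here: because the mediators are included as features, its coefficient on call waiting time captures only the direct effect (about 0.10). In a linear SCM, the total causal effect is the sum, over every directed path from cause to outcome, of the products of the edge coefficients. Reading those coefficients off the data generator above:

```python
# Path-trace the total effect of 'Call waiting time' on 'Churn' using the
# data generator's coefficients (linear SCM => sum of path products).
direct = 0.10                           # waiting -> churn
via_abandoned = 0.5 * 0.30              # waiting -> abandoned -> churn
via_problems = 0.5 * 0.6 * 0.15         # waiting -> abandoned -> problems -> churn
via_discount = 0.5 * 0.6 * 0.7 * -0.20  # ... -> problems -> discount -> churn

total_effect = direct + via_abandoned + via_problems + via_discount
print(round(total_effect, 3))  # 0.253
```

The true total effect is roughly 0.25 per unit of waiting time, more than double the 0.10 direct effect that ridge regression attributes to it, which matches the underestimate we observe.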
This was a simple example to start thinking about the power of causal graphs.
For more complex situations, there are several challenges that would need consideration:
- What assumptions are made and what is the impact of their violation?
- What if we don't have the expert knowledge to identify the causal graph?
- What happens if non-linear relationships exist?
- How harmful is multicollinearity?
- What if some variables have lagged effects?
- How can we deal with high dimensional data sets (many variables)?
All these points will be covered in future blogs.
If you are interested in learning more about causal AI, I recommend the following resources:
Meet Ryan, an experienced lead data scientist with a specialized focus on employing causal techniques within business contexts, spanning marketing, operations and customer service. His competency lies in unraveling the complexities of cause and effect relationships to drive informed decision making and strategic improvements across various organizational functions.