# Improving the accuracy of medical diagnosis using Causal Machine Learning

Poor access to healthcare and errors in differential diagnosis represent a significant challenge to global healthcare systems. In the US alone, an estimated 5% of outpatients are misdiagnosed every year. Among patients with serious medical conditions, an estimated 20% are misdiagnosed at the level of primary care, and one third of these misdiagnoses result in serious patient harm.

If machine learning is to help overcome these challenges, it is important that we first understand how diagnosis is performed and clearly define the desired output of our algorithms. **Existing approaches**, like Bayesian model-based and Deep Learning approaches, **have conflated diagnosis with associative inference**. While the former involves determining the underlying cause of a patient’s symptoms, the latter involves learning correlations between patient data and disease occurrences, determining the most likely diseases in the population that the patient belongs to. While this approach is perhaps sufficient for simple causal scenarios involving single diseases, it places strong constraints on the accuracy of these algorithms when applied to differential diagnosis, where a clinician chooses from multiple competing disease hypotheses. Overcoming these constraints requires that we fundamentally rethink how we define diagnosis and how we design diagnostic algorithms.

**Associative Diagnosis**

Model-based diagnosis uses a model, parameterised by θ, to estimate the likelihood of a disease (D) given the findings (ε), which can include symptoms, test results and relevant medical history. Posterior probabilities are used to rank diseases for differential diagnosis. These algorithms can be either:

- Discriminative: like neural nets

- Generative: like Bayesian networks
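Either way, the associative ranking itself is simple: sort candidate diseases by their posterior probability. A minimal sketch, using hypothetical, made-up posterior values purely for illustration:

```python
# Hypothetical posteriors P(D | evidence), e.g. from a Bayesian network
# or a discriminative classifier (the numbers here are invented).
posteriors = {
    "angina": 0.32,
    "emphysema": 0.21,
    "gastritis": 0.12,
}

# An associative differential diagnosis is the diseases sorted by
# posterior probability, highest first.
ranking = sorted(posteriors, key=posteriors.get, reverse=True)
print(ranking)  # ['angina', 'emphysema', 'gastritis']
```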

*Causal definition of diagnosis*

*The identification of the diseases that are most likely to be causing the patient’s symptoms, given their medical history.*

Using only the posterior to identify causal relations can lead to spurious conclusions in all but the simplest causal scenarios, owing to *confounding*.

**Example** *(case b)*

An elderly smoker reports chest pain, nausea, and fatigue. A good doctor will present a diagnosis that is both likely and relevant given the evidence (such as angina). Although this patient belongs to a population with a high prevalence of emphysema, this disease is unlikely to have caused the symptoms presented and should not be put forward as a diagnosis. Emphysema is positively correlated with the patient’s symptoms, but this is primarily due to common causes such as age and smoking.

**Example** *(case c)*

A study found that asthmatic patients who were admitted to hospital for pneumonia were **more aggressively treated** for the infection, lowering the sub-population mortality rate. An associative model trained on this data to diagnose pneumonia will learn that asthma is a protective risk factor — a dangerous conclusion that could result in a less aggressive treatment regime being proposed for asthmatics, despite the fact that asthma increases the risk of developing pneumonia. In this example, the confounding factor is the unobserved level of care received by the patient.

**Principles of Diagnostic Reasoning**

To reason about *causal responsibility*, i.e. the probability that the occurrence of the effect S was due to the cause disease D, we require a **diagnostic measure** M(D, ε), which must satisfy the following properties:

**Consistency**

The likelihood that a disease D is causing a patient’s symptoms should be proportional to the posterior likelihood of that disease.

**Causality**

A disease D that cannot cause any of the patient’s symptoms cannot constitute a diagnosis.

**Simplicity**

Diseases that explain a greater number of the patient’s symptoms should be more likely.

The posterior only satisfies the first condition, violating the last two.

**Counterfactual Diagnosis**

Counterfactuals can test whether certain outcomes would have occurred had some precondition been different.

P(ε = e′ ∣ ε = e, do(X = x))

Given evidence ε = e, we calculate the likelihood that we would have observed a different outcome ε = e′, had some hypothetical intervention do(X = x) taken place.

Counterfactuals provide us with the language to quantify how well a disease hypothesis D = T explains symptom evidence S = T, by determining the likelihood that the symptom would not be present if we were to intervene and ‘cure’ the disease by setting do(D = F). This is given by the counterfactual probability P(S = F ∣ S = T, do(D = F)).

We define the following two counterfactual measures:

**Expected Disablement**

The number of present symptoms that we would expect to *“switch off”* if we intervened to cure D (derived from the notion of necessary cause).

**Expected Sufficiency**

The number of positively evidenced symptoms we would expect to *persist* if we intervened to “switch off” all other possible causes of the patient’s symptoms (derived from the notion of sufficient cause).

In the latter case, D is treated as a sufficient cause of S. However, when multiple competing diseases are present, the presence of S does not imply the prior occurrence of D. If we cannot assume that a disease is a sufficient cause of its symptoms, the expected disablement should be used.
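For a tiny model, the expected disablement can be computed exactly by brute force: enumerate every world (disease states plus noise terms), keep those consistent with the evidence, and check whether the symptom switches off under the cure intervention. This is a minimal sketch, not the paper’s algorithm; the two-disease, one-symptom noisy-OR network and all probabilities below are invented for illustration.

```python
from itertools import product

# Hypothetical two-disease, one-symptom noisy-OR SCM.
priors = {"D1": 0.3, "D2": 0.1}   # P(disease present)
fail = {"D1": 0.4, "D2": 0.2}     # lambda: activation failure probability

def symptom(d, u):
    # Noisy-OR structural equation: S = OR_i (d_i AND NOT u_i)
    return int(any(d[i] and not u[i] for i in d))

def expected_disablement(target):
    num = den = 0.0
    # Exact abduction: enumerate all worlds (disease and noise states).
    for d1, d2, u1, u2 in product([0, 1], repeat=4):
        d = {"D1": d1, "D2": d2}
        u = {"D1": u1, "D2": u2}
        w = 1.0
        for i in d:
            w *= priors[i] if d[i] else 1 - priors[i]
            w *= fail[i] if u[i] else 1 - fail[i]
        if symptom(d, u) != 1:          # condition on the evidence S = 1
            continue
        d_cf = dict(d, **{target: 0})   # intervention do(target = 0)
        num += w * (1 - symptom(d_cf, u))  # did the symptom switch off?
        den += w
    return num / den

print(round(expected_disablement("D1"), 3))  # 0.674
```

With a single observed symptom, the expected disablement of D1 is exactly the counterfactual probability P(S = F ∣ S = T, do(D1 = F)) from the definition above. This enumeration grows exponentially with model size, which is why the paper resorts to twin networks for large-scale inference.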

**Structural Causal Models for Diagnosis**

Bayesian Networks (BNs) are generally employed as *statistical* diagnostic models as they are interpretable and explicitly encode causal relations between variables. These models typically represent diseases, symptoms and risk factors as binary nodes, and specify the relations between them in the form of a directed acyclic graph (DAG).

In the field of causal inference, BNs are replaced by the more fundamental Structural Causal Models (SCMs). SCMs represent each variable as a deterministic function of its direct causes together with an unobserved exogenous ‘noise’ term, which represents all causes outside of the model.

Counterfactuals cannot in general be identified from data alone, and require modelling assumptions such as knowledge of the underlying structural equations.

# Noisy-OR twin diagnostic networks

Sometimes, it is necessary to make additional modelling assumptions beyond those implied by the DAG structure. Noisy-OR models are used here as they reflect basic intuitions about how diseases and symptoms are related, and allow large BNs to be described by a number of parameters that grows linearly with the size of the network.

Under the noisy-OR assumption, a parent *Dᵢ* activates its child *S* (causing *S* = 1) if (i) the parent is on, *Dᵢ* = 1, and (ii) the activation does not randomly fail. The probability of failure, λ_{Dᵢ,S}, is independent of all other model parameters. The ‘OR’ component of the noisy-OR states that the child is activated if any of its parents successfully activates it. Concretely, the state of the child is given by *s* = ∨ᵢ *f*(*dᵢ*, *uᵢ*), where the activation functions are *f*(*dᵢ*, *uᵢ*) = *dᵢ* ∧ ūᵢ, ∧ denotes the Boolean AND function, ∨ the Boolean OR, *dᵢ* ∈ {0, 1} is the state of a given parent *Dᵢ*, and *uᵢ* ∈ {0, 1} is a latent noise variable (ūᵢ := 1 − *uᵢ*) with failure probability P(*uᵢ* = 1) = λ_{Dᵢ,S}.
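Marginalising out the noise terms gives the standard noisy-OR likelihood: the symptom is absent only if every active parent fails, so P(S = 1 ∣ d₁…dₙ) = 1 − ∏ᵢ λᵢ^{dᵢ}. A minimal sketch with hypothetical failure probabilities:

```python
# Noisy-OR likelihood of a symptom given its parent diseases:
#   P(S = 1 | d_1..d_n) = 1 - prod_i lambda_i ** d_i
# The failure probabilities (lambdas) below are invented example values.
def noisy_or(states, failures):
    p_all_fail = 1.0
    for d, lam in zip(states, failures):
        if d:                 # only active parents can activate the child
            p_all_fail *= lam
    return 1.0 - p_all_fail

# Both diseases present, with failure probabilities 0.4 and 0.2:
print(round(noisy_or([1, 1], [0.4, 0.2]), 2))  # 1 - 0.4 * 0.2 = 0.92
```

Note the linear parameter growth: one failure probability per disease-symptom edge, rather than a full conditional probability table that is exponential in the number of parents.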

Twin networks represent real and counterfactual variables together in a single SCM, greatly amortising the inference cost of calculating counterfactuals compared to abduction, which is intractable for large SCMs.

# Counterfactual vs Associative Rankings

The disease ranking is computed using the posterior for the associative algorithm, and the expected disablement and expected sufficiency for the counterfactual algorithms.

For *k* = 1, returning the top ranked disease, the counterfactual algorithm achieves a 2.5% higher accuracy than the associative algorithm. For *k* > 1 the performance of the two algorithms diverges, with the counterfactual algorithm giving a large reduction in the error rate over the associative algorithm. For *k* > 5, the counterfactual algorithm reduces the number of misdiagnoses by ~30% compared to the associative algorithm. This suggests that the best candidate disease is reasonably well identified by the posterior, but the counterfactual ranking is significantly better at identifying the next most likely diseases. These secondary candidate diseases are especially important in differential diagnosis for the purposes of triage and determining optimal testing and treatment strategies.
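The top-*k* metric underlying this comparison is straightforward: a case counts as correct if the true disease appears anywhere in the first *k* entries of the algorithm's ranking. A minimal sketch with made-up vignette data:

```python
# Top-k accuracy: fraction of cases where the true disease appears in the
# first k entries of the model's ranking. The rankings and ground truths
# below are invented for illustration.
def top_k_accuracy(rankings, truths, k):
    hits = sum(truth in ranking[:k] for ranking, truth in zip(rankings, truths))
    return hits / len(truths)

rankings = [["flu", "cold", "covid"], ["angina", "flu", "cold"]]
truths = ["cold", "cold"]
print(top_k_accuracy(rankings, truths, 1))  # 0.0
print(top_k_accuracy(rankings, truths, 2))  # 0.5
print(top_k_accuracy(rankings, truths, 3))  # 1.0
```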

**Comparison to Doctors**

This experiment compares the counterfactual and associative algorithms to a cohort of 44 doctors. Each doctor is assigned a set of at least 50 vignettes (average 159), and returns an independent diagnosis for each vignette in the form of a partially ranked list of *k* diseases.

Overall, the associative algorithm performs on par with the average doctor, achieving a mean accuracy across all trials of 72.52 ± 2.97% vs 71.4 ± 3.01% for doctors. The algorithm scores higher than 21 of the doctors, draws with 2 of the doctors, and scores lower than 21 of the doctors. The counterfactual algorithm achieves a mean accuracy of 77.26 ± 2.79%, considerably higher than the average doctor and the associative algorithm, placing it in the top 25% of doctors in the cohort. The counterfactual algorithm scores higher than 32 of the doctors, draws with 1, and scores a lower accuracy than 12.

In summary, we find that the counterfactual algorithm achieves a substantially higher diagnostic accuracy than the associative algorithm. We find the improvement is particularly pronounced for rare diseases. While the associative algorithm performs on par with the average doctor, the counterfactual algorithm places in the upper quartile of doctors.