Introducing Lace: Bayesian tabular data analysis in Rust and Python

A machine learning tool optimized for human learning

Image: 'Fireworks on Dark Sky' by Ryan Klaus

2023-09-21

Author

Baxter Eaves

[data science, open source]

Recently we released the source code of one of our core internal tools, Lace. Lace is a Bayesian tabular data analysis tool designed to make learning from your data as easy as possible. Learn more at lace.dev.

A tour of Lace

Lace is a tool for analyzing tabular data (think pandas DataFrames). It is both generative and discriminative: it learns a generative model from which you can simulate data and predict unobserved values. It natively handles continuous, categorical, and missing data, and can even do missing-not-at-random inference when the absence or presence of data is itself potentially significant.

Our goal with Lace is to optimize for human learning by optimizing for transparency and speed of asking questions.

Whereas typical machine learning models are designed to learn a function linking inputs to outputs, Lace learns a joint probability distribution over the entire dataset. Once you have the joint distribution, you can trivially create conditional distributions, through which you can ask questions and discover new ones to ask.

For most use cases, using Lace is as simple as dropping a pandas or polars dataframe into the Engine constructor and running update.

import pandas as pd
import lace

df = pd.read_csv("my-data.csv")

engine = lace.Engine.from_df(df)
engine.update(1000)
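
Once update has run, the joint distribution can be sliced into conditionals on demand. Here is a minimal sketch; the column names "x" and "y" are hypothetical placeholders for columns in my-data.csv:

# Evaluate the conditional density p(x | y = 2.0) by passing a
# `given` condition to logp. "x" and "y" are hypothetical column
# names; substitute columns from your own data.
xs = pd.Series([0.5, 1.0, 1.5], name="x")
engine.logp(xs, given={"y": 2.0})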

Understanding statistical structure

But before we ask any questions, we like to know which questions we can answer. So, we'll ask which features are statistically dependent, i.e., which features are predictive of each other, using depprob.

# lace comes with an Animals and Satellites example dataset
from lace.examples import Satellites

sats = Satellites()
sats.clustermap("depprob", zmin=0, zmax=1)

Above: Dependence probability matrix. Each cell shows the probability that a dependence path exists between two features.

Above, each cell tells us the probability that two variables are statistically dependent (though that dependence might flow through one or more intermediate variables).
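
If you only care about a handful of pairs, you can query them directly instead of rendering the full matrix. A minimal sketch; the list-of-pairs signature for depprob is an assumption here, so check the lace docs for the exact API:

# Dependence probability for specific column pairs (signature
# assumed; see the lace documentation for specifics)
sats.depprob([
    ("Class_of_Orbit", "Period_minutes"),
    ("Users", "Purpose"),
])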

Prediction and likelihood evaluation

We can of course do prediction (regression or classification) using the predict command.

# Marginal distribution of orbital period
sats.predict("Period_minutes")
# (100.59185703181058, 1.4439934663361522)

# Add conditions (add as many as you like)
sats.predict("Period_minutes", given={"Class_of_Orbit": "GEO"})
# (1436.0404065183673, 0.8641390940629012)

# Condition on missing values
sats.predict("Class_of_Orbit", given={"longitude_radians_of_geo": None})
# ('LEO', 0.002252910143782927)

Note that calls to predict return two values: the prediction and a second number describing uncertainty (Jensen-Shannon divergence among the posterior samples' predictive distributions).
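
Because predict returns a (value, uncertainty) pair, you can unpack it directly:

# Unpack the prediction and its uncertainty
pred, unc = sats.predict("Period_minutes", given={"Class_of_Orbit": "GEO"})
print(f"Predicted period: {pred:.1f} minutes (uncertainty: {unc:.2f})")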

If you'd like to view the entire predictive distribution rather than just the most likely value (prediction), you can ask about the likelihood of values.

import numpy as np
import pandas as pd

xs = pd.Series(np.linspace(0, 1500, 20), name="Period_minutes")
sats.logp(xs)
        logp
0   -10.5922
1    -6.48785
2   -10.8964
3   -10.8551
4   -10.782
...
15  -10.8064
16  -10.8447
17  -10.5111
18   -8.90361
19   -9.86572
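
As with predict, logp accepts conditions, so we can evaluate the same grid of values under a conditional distribution:

# Likelihood of the same period values, conditioned on orbit class
sats.logp(xs, given={"Class_of_Orbit": "GEO"})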

Visualizing the distribution for each posterior sample (often referred to as a "state") gives us a nice view of uncertainty.

from lace.plot import prediction_uncertainty

prediction_uncertainty(
    sats, 
    "Period_minutes",
    given={"Class_of_Orbit": "GEO"},
    xs=pd.Series(np.linspace(1400, 1480, 500), name="Period_minutes")
)

Above: A visualization of uncertainty when predicting Period_minutes given a geosynchronous orbit class. The red line is the most likely value, the black line is the likelihood of Period_minutes over a range of values, and the gray lines represent the likelihoods emitted by a number of posterior samples.

Above, the red line is the prediction (the most likely value), the black line is the probability distribution for this prediction, and the gray lines are the distributions for the individual posterior samples, which tell you how certain the model is that it has captured the distribution (learn more about uncertainty quantification here).

Simulation

We can simulate values:

sats.simulate(["Users"], n=5)
        Users
0  Commercial
1    Military
2  Commercial
3    Military
4  Government

Just like with logp and predict, we can add conditions for our simulations.

sats.simulate(
    ["Users", "Purpose"],
    given={
        "Class_of_Orbit": "LEO",
        "Launch_Site": "Taiyuan Launch Center"
    },
    n=5
)
        Users                 Purpose
0  Government  Technology Development
1  Government           Earth Science
2       Civil          Communications
3  Government  Technology Development
4  Commercial          Communications

We can easily re-simulate the entire dataset. Lace is generally very good at generating tabular synthetic data, outperforming deep-learning-based approaches using Generative Adversarial Networks (GANs) and Transformers (manuscript under review).

sats.simulate(sats.columns, n=sats.shape[0])

Row/Record similarity

We can also ask which records (rows) are similar in terms of model space. This frees us from having to come up with a distance metric that works well for mixed data types and missing data. It also provides more nuanced information than just looking at the values. To make this a bit more intuitive, we'll switch to an animals example, since people generally have a better sense of what makes animals similar than they do satellites.

from lace.examples import Animals

animals = Animals()
animals.clustermap("rowsim", zmin=0, zmax=1)

Above: Row similarity of animals. Higher row similarity means the animals are closer in model space, meaning their features are modeled similarly.

This essentially generates a data-driven taxonomy of animals for us. More similar animals will be modeled more similarly. If two animals have a row similarity of 1, it means their features are modeled identically.
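
You can also query the similarity of specific pairs of rows. A minimal sketch; the list-of-pairs signature and the row labels "otter", "dolphin", and "bat" are assumptions here, so check the lace docs and the dataset's index for specifics:

# Row similarity for specific pairs of animals (signature and row
# labels assumed; see the lace documentation for specifics)
animals.rowsim([
    ("otter", "dolphin"),
    ("otter", "bat"),
])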

We can also ask for similarity given a specific context. Say that we only care about similarity with respect to whether an animal swims.

animals.clustermap("rowsim", zmin=0, zmax=1, fn_kwargs={"wrt": ["swims"]})

Above: Row similarity with respect to how the 'swimming' feature is modeled.

Notice that there are two main clusters of animals: those that swim and those that do not. If we were just looking at similarity in the raw data, the values would all be either 0 or 1 because the swims feature is binary, but here we get more nuanced information. For example, within the animals that swim there are two similarity clusters: groups of animals that all swim, but that the model predicts swim for different reasons.
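
The same context applies to pairwise queries. Under the same assumptions as the sketch above, we can restrict a pairwise similarity query to how swims is modeled:

# Row similarity within the context of the "swims" feature
animals.rowsim([("otter", "dolphin")], wrt=["swims"])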

Learn more

We've put a lot of love into Lace, and there's a lot more you can do with it than we've covered here. To learn more, visit lace.dev or check out the GitHub repository.

Get in touch

Redpoll is deliberate about the organizations with whom we choose to partner. If you’re interested in working with us, please fill out the brief form below and we’ll set up a time to connect!