Introducing Lace: Bayesian tabular data analysis in Rust and Python
A machine learning tool optimized for human learning
2023-09-21 by Baxter Eaves in [data science, open source]
Recently we released the source code of one of our core internal tools, Lace. Lace is a Bayesian tabular data analysis tool designed to make learning from your data as easy as possible. Learn more at lace.dev.
A tour of Lace
Lace is a tool for analyzing tabular data — think pandas DataFrames. It is both generative and discriminative in that it learns a generative model from which you can simulate data and predict unobserved values. It natively handles continuous, categorical, and missing data, and can even do missing-not-at-random inference when the absence or presence of data is itself significant.
Our goal with Lace is to optimize for human learning by optimizing for transparency and for the speed with which you can ask questions.
Whereas typical machine learning models are designed to learn a function linking inputs to outputs, Lace learns a joint probability distribution over the entire dataset. Once you have the joint distribution, you can trivially derive conditional distributions, through which you can ask questions and discover new ones.
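To make the joint-to-conditional step concrete, here is a toy sketch in plain Python (a hypothetical two-variable example, nothing to do with Lace's internals): marginalizing sums the joint over the variables you don't care about, and conditioning restricts the joint to the observed value and renormalizes.

```python
# Toy joint distribution p(weather, activity) as a lookup table
joint = {
    ("sunny", "hike"): 0.4, ("sunny", "read"): 0.1,
    ("rainy", "hike"): 0.1, ("rainy", "read"): 0.4,
}

# Marginal p(weather): sum the joint over activity
p_weather = {}
for (weather, activity), p in joint.items():
    p_weather[weather] = p_weather.get(weather, 0.0) + p

# Conditional p(activity | weather = "sunny"): restrict, then renormalize
sunny = {a: p for (w, a), p in joint.items() if w == "sunny"}
z = sum(sunny.values())
p_act_given_sunny = {a: p / z for a, p in sunny.items()}

print(p_weather)          # {'sunny': 0.5, 'rainy': 0.5}
print(p_act_given_sunny)  # {'hike': 0.8, 'read': 0.2}
```

Every conditional query below is this same operation, just over a far richer learned joint.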
For most use cases, using Lace is as simple as dropping a pandas or polars dataframe into the Engine constructor and running update.
import pandas as pd
import lace
df = pd.read_csv("my-data.csv")
engine = lace.Engine.from_df(df)
engine.update(1000)
Understanding statistical structure
But before we ask any questions, we like to know which questions we can answer. So, we'll ask which features are statistically dependent, i.e., which features are predictive of each other, using depprob.
# Lace comes with Animals and Satellites example datasets
from lace.examples import Satellites
sats = Satellites()
sats.clustermap("depprob", zmin=0, zmax=1)
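For intuition about where a dependence probability can come from: in a CrossCat-style model, each posterior sample assigns columns to views, and columns that land in the same view share statistical structure. The sketch below illustrates that idea with made-up view assignments; it is not Lace's actual code, and the column names are only placeholders.

```python
# Hypothetical view assignments for three columns across four posterior samples
samples = [
    {"Period_minutes": 0, "Class_of_Orbit": 0, "Country_of_Operator": 1},
    {"Period_minutes": 0, "Class_of_Orbit": 0, "Country_of_Operator": 2},
    {"Period_minutes": 1, "Class_of_Orbit": 1, "Country_of_Operator": 1},
    {"Period_minutes": 0, "Class_of_Orbit": 1, "Country_of_Operator": 2},
]

def depprob(a, b):
    # Fraction of posterior samples in which columns a and b share a view
    return sum(s[a] == s[b] for s in samples) / len(samples)

print(depprob("Period_minutes", "Class_of_Orbit"))  # 0.75
```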
Prediction and likelihood evaluation
We can of course do prediction (regression or classification) using the predict command.
# Marginal distribution of orbital period
sats.predict("Period_minutes")
# (100.59185703181058, 1.4439934663361522)
# Add conditions (add as many as you like)
sats.predict("Period_minutes", given={"Class_of_Orbit": "GEO"})
# (1436.0404065183673, 0.8641390940629012)
# Condition on missing values
sats.predict("Class_of_Orbit", given={"longitude_radians_of_geo": None})
# ('LEO', 0.002252910143782927)
Note that calls to predict return two values: the prediction and a second number describing uncertainty (the Jensen-Shannon divergence among the posterior samples' predictive distributions).
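To build intuition for that second number, here is a small plain-Python sketch of the Jensen-Shannon divergence among a few hypothetical categorical predictive distributions (illustrative only, not Lace's implementation): when the posterior samples agree, the divergence is near zero; when they disagree, it grows.

```python
from math import log

def entropy(p):
    # Shannon entropy in nats, skipping zero-probability entries
    return -sum(x * log(x) for x in p if x > 0)

def jensen_shannon(dists):
    # JSD = H(average distribution) - average of H(each distribution)
    mean = [sum(col) / len(dists) for col in zip(*dists)]
    return entropy(mean) - sum(entropy(d) for d in dists) / len(dists)

# Posterior samples that agree closely -> low uncertainty
agree = [[0.90, 0.10], [0.88, 0.12], [0.91, 0.09]]
# Posterior samples that disagree -> high uncertainty
disagree = [[0.90, 0.10], [0.50, 0.50], [0.10, 0.90]]

print(jensen_shannon(agree) < jensen_shannon(disagree))  # True
```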
If you'd like to view the entire predictive distribution rather than just the most likely value (prediction), you can ask about the likelihood of values.
import numpy as np
import pandas as pd
xs = pd.Series(np.linspace(0, 1500, 20), name="Period_minutes")
sats.logp(xs)
|    | logp |
|---|---|
| 0 | -10.5922 |
| 1 | -6.48785 |
| 2 | -10.8964 |
| 3 | -10.8551 |
| 4 | -10.782 |
| ⋮ | ⋮ |
| 15 | -10.8064 |
| 16 | -10.8447 |
| 17 | -10.5111 |
| 18 | -8.90361 |
| 19 | -9.86572 |
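Log likelihoods are easier to compare on a probability scale. As a generic numerical aside (not a Lace API), you can exponentiate and renormalize a handful of log values, subtracting the maximum first for numerical stability; here using the first three values from the table above:

```python
from math import exp

logps = [-10.5922, -6.48785, -10.8964]  # first three rows of the table above

# Softmax-style normalization: subtract the max before exponentiating
# to avoid underflow, then renormalize to get relative weights
m = max(logps)
weights = [exp(lp - m) for lp in logps]
z = sum(weights)
rel = [w / z for w in weights]

# The second candidate value dominates the other two
print(rel.index(max(rel)))  # 1
```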
Visualizing the distributions for each posterior sample (samples are often referred to as "states") gives us a nice view of the uncertainty.
from lace.plot import prediction_uncertainty

prediction_uncertainty(
    sats,
    "Period_minutes",
    given={"Class_of_Orbit": "GEO"},
    xs=pd.Series(np.linspace(1400, 1480, 500), name="Period_minutes"),
)
Simulation
We can also simulate values.
sats.simulate(["Users"], n=5)
|   | Users |
|---|---|
| 0 | Commercial |
| 1 | Military |
| 2 | Commercial |
| 3 | Military |
| 4 | Government |
Just like with logp and predict, we can add conditions for our simulations.
sats.simulate(
    ["Users", "Purpose"],
    given={
        "Class_of_Orbit": "LEO",
        "Launch_Site": "Taiyuan Launch Center",
    },
    n=5,
)
|   | Users | Purpose |
|---|---|---|
| 0 | Government | Technology Development |
| 1 | Government | Earth Science |
| 2 | Civil | Communications |
| 3 | Government | Technology Development |
| 4 | Commercial | Communications |
We can easily re-simulate the entire dataset. Lace is generally very good at generating tabular synthetic data, outperforming deep-learning approaches based on Generative Adversarial Networks (GANs) and Transformers (manuscript under review).
sats.simulate(sats.columns, n=sats.shape[0])
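One simple way to sanity-check synthetic data is to compare per-column summaries between the real and simulated values. The sketch below is a generic, standalone illustration with hypothetical values (not a Lace API), using total variation distance between two categorical marginals:

```python
def category_freqs(values):
    # Empirical frequency of each category in a column
    counts = {}
    for v in values:
        counts[v] = counts.get(v, 0) + 1
    return {k: c / len(values) for k, c in counts.items()}

# Hypothetical real and synthetic draws of a categorical column
real = ["LEO", "LEO", "GEO", "MEO", "LEO"]
synth = ["LEO", "GEO", "GEO", "LEO", "MEO"]

f_real, f_synth = category_freqs(real), category_freqs(synth)

# Total variation distance: 0 means identical marginals, 1 means disjoint
keys = set(f_real) | set(f_synth)
tvd = 0.5 * sum(abs(f_real.get(k, 0.0) - f_synth.get(k, 0.0)) for k in keys)
print(round(tvd, 3))  # 0.2
```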
Row/Record similarity
We can also ask which records (rows) are similar in terms of model space. This frees us from having to come up with a distance metric that works well for mixed data types and missing data. It also provides more nuanced information than just looking at the values. To make this a bit more intuitive, we'll switch to an animals example, since people generally have a better sense of what makes animals similar.
from lace.examples import Animals
animals = Animals()
animals.clustermap("rowsim", zmin=0, zmax=1)
We can ask for similarity given a specific context. Say that we only cared about similarity with respect to whether an animal swims.
animals.clustermap("rowsim", zmin=0, zmax=1, fn_kwargs={"wrt": ["swims"]})
Notice that there are two main clusters of animals: those that swim and those that do not. If we were just looking at similarity in the data values, everything would be either 0 or 1 because the swims feature is binary, but here we get more nuanced information. For example, within the animals that swim there are two similarity clusters: clusters of animals that swim, but that we predict swim for different reasons.
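The idea behind "similar in model space" can be sketched in the same spirit as the earlier uncertainty discussion: each posterior sample assigns rows to clusters, and rows that are frequently clustered together are similar. Below is a toy illustration with made-up cluster labels for a few animals; it is not the actual Lace computation, which is more nuanced:

```python
# Hypothetical cluster assignments for four animals across three posterior samples
samples = [
    {"otter": 0, "beaver": 0, "bat": 1, "hamster": 1},
    {"otter": 0, "beaver": 0, "bat": 1, "hamster": 0},
    {"otter": 2, "beaver": 2, "bat": 1, "hamster": 1},
]

def rowsim(a, b):
    # Fraction of posterior samples in which rows a and b share a cluster
    return sum(s[a] == s[b] for s in samples) / len(samples)

print(rowsim("otter", "beaver"))   # 1.0: always modeled together
print(rowsim("otter", "hamster"))  # they share a cluster in only one sample
```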
Learn more
We've put a lot of love into Lace, and there's a lot more that you can do with it than we've gone over here. To learn more, visit lace.dev or check out the GitHub repository.