Baxter Eaves
2020-01-21
[news]

Announcing rv 0.8.0

A Rust crate for building probabilistic programming tools



RV is a statistical toolbox for the Rust programming language, designed for our work in probabilistic AI and machine learning. In this post I'll discuss the history of RV and the Rust programming language at Redpoll, lay out RV's design principles, and walk through a few simple examples.

Rust at Redpoll

At Redpoll, we use Rust for everything in the back end, including our AI and services. Safety is our selling point, and what's not to love about a safe, performant, modern language with a great type system? Well, there's the common problem with all young languages: lack of libraries. We're a research organization, so we don't use much off-the-shelf machine learning code in our product, but still, coming from a Python background, it sure hurts not to be able to drop SciPy and NumPy into a project for easy access to scientific computing. What hurt most for the type of machine learning we do was the lack of statistical tools. If I was going to use Rust, I'd have to write my own.

RV

RV, which stands for Random Variables, is a crate for random variables in Rust. It's designed for developing probabilistic programming tools. This means that it needs to be flexible and general, and it needs to simplify common statistical patterns.

Probability distributions

The first ingredient of any probabilistic programming toolbox is probability distributions. In rv 0.8.0 we've implemented 30 distributions. What a distribution can do depends on which traits it implements. At a minimum, each distribution must implement the Rv trait, which provides the bare essentials for Monte Carlo methods. Let's create a standard Gaussian (Normal) distribution.

use rv::traits::Rv;
use rv::dist::Gaussian;

// Standard Normal distribution with mean 0 and stddev 1
let gauss = Gaussian::standard();

We can draw values with the draw method,

let mut rng = rand::thread_rng();

let x: f64 = gauss.draw(&mut rng);

and we can compute likelihoods with the f method.

let fx = gauss.f(&x);
println!("f({}) = {}", x, fx);

// -- Output --
// f(-0.41559754609242056) = 0.3659351330603808

The Rv trait also has a sample method, which draws many values into a vector, and an ln_f method, which computes the log likelihood.
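To make f and ln_f concrete, here is a minimal sketch of the Gaussian density and log density using only the standard library. The function names are illustrative and are not part of rv, and rv's own implementation may differ. The last function shows why a log likelihood is handy: the likelihood of many observations is a product of small numbers, which is safer to compute as a sum of logs.

```rust
use std::f64::consts::PI;

/// Log density of a Normal(mu, sigma) at x. Assumes sigma > 0.
fn gaussian_ln_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
    let z = (x - mu) / sigma;
    -sigma.ln() - 0.5 * (2.0 * PI).ln() - 0.5 * z * z
}

/// Density, recovered by exponentiating the log density.
fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
    gaussian_ln_pdf(x, mu, sigma).exp()
}

/// Log likelihood of independent observations: a sum of log densities
/// rather than a product of densities, which avoids underflow.
fn ln_likelihood(xs: &[f64], mu: f64, sigma: f64) -> f64 {
    xs.iter().map(|&x| gaussian_ln_pdf(x, mu, sigma)).sum()
}
```

Calling gaussian_pdf with the value drawn above and the standard Normal parameters (mean 0, standard deviation 1) should agree with the printed likelihood up to floating-point rounding.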

We can also compute the distribution's descriptive statistics like mean, median, mode, etc:

use rv::traits::{Mean, Median, Mode, Variance};

let mean: f64 = gauss.mean().unwrap();
let median: f64 = gauss.median().unwrap();
let mode: f64 = gauss.mode().unwrap();
let variance: f64 = gauss.variance().unwrap();

println!("{} mean: {}", gauss, mean);
println!("{} median: {}", gauss, median);
println!("{} mode: {}", gauss, mode);
println!("{} variance: {}", gauss, variance);

// -- Output --
// N(μ: 0, σ: 1) mean: 0
// N(μ: 0, σ: 1) median: 0
// N(μ: 0, σ: 1) mode: 0
// N(μ: 0, σ: 1) variance: 1

You'll notice a lot of unwrapping. These methods return Option because some of these quantities are not always defined. For example, the variance of an inverse gamma distribution is only defined for some parameter values.

use rv::dist::InvGamma;

let ig_var = InvGamma::new(3.0, 1.0).unwrap();

// The inverse Gamma variance is undefined if the first parameter, shape, is
// less than or equal to 2.
let ig_no_var = InvGamma::new(1.0, 1.0).unwrap();

assert!(ig_no_var.variance().is_none());
assert!(ig_var.variance().is_some());
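The Option return type follows directly from the math. As a standalone sketch (not rv's code, with illustrative function names, and assuming a shape and scale parameterization), the inverse gamma mean exists only for shape > 1 and the variance only for shape > 2:

```rust
/// Mean of an inverse gamma distribution: scale / (shape - 1).
/// Defined only when shape > 1.
fn inv_gamma_mean(shape: f64, scale: f64) -> Option<f64> {
    if shape > 1.0 {
        Some(scale / (shape - 1.0))
    } else {
        None
    }
}

/// Variance of an inverse gamma distribution:
/// scale^2 / ((shape - 1)^2 * (shape - 2)).
/// Defined only when shape > 2.
fn inv_gamma_variance(shape: f64, scale: f64) -> Option<f64> {
    if shape > 2.0 {
        let d = shape - 1.0;
        Some(scale * scale / (d * d * (shape - 2.0)))
    } else {
        None
    }
}
```

Returning None rather than NaN or infinity forces the caller to decide what an undefined moment should mean in their model.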

Conjugate analysis

We have built in a few examples of conjugate priors to enable simple conjugate analysis. For example, if we wanted to infer the unknown weight of a coin given a set of observations (flips) we could model the flipping of the coin as a Bernoulli process and our prior belief about the weights of coins as a Beta distribution. Beta is conjugate to Bernoulli, which means that the posterior distribution is also Beta. So we can make posterior inferences about the coin weight without fussing with a Bernoulli distribution at all.

In rv, this relationship is captured by the ConjugatePrior trait, which is implemented for Beta. The ConjugatePrior trait looks like this:

pub trait ConjugatePrior<X, Fx>: Rv<Fx>
where
    Fx: Rv<X> + HasSuffStat<X>,
{
    // ...
}

where X is the type of data and Fx is the likelihood we're implementing the conjugate prior for. We see that the prior must implement Rv<Fx>, and Fx must itself implement Rv<X> (be a likelihood of X), and implement HasSuffStat<X>, which we'll discuss more later.

ConjugatePrior<bool, Bernoulli> and Rv<Bernoulli> are implemented for Beta. That's right, you can sample Bernoulli distributions from a Beta distribution.

// Beta(1, 1), which is uniform over (0, 1)
let beta = Beta::uniform();
let berns: Vec<Bernoulli> = beta.sample(10, &mut rng);

In a conjugate analysis, we pack the data into the DataOrSuffStat enum because, for many distributions, we don't want or need to store the raw data; we can instead store a summary, or sufficient statistic, that is much more compact. This is why we need to implement HasSuffStat<X> for Fx. The DataOrSuffStat enum also tells the prior what type of likelihood these data are meant for. In this case, we have bool data for a Bernoulli likelihood.

use rv::dist::{Beta, Bernoulli};
use rv::data::DataOrSuffStat;
use rv::traits::ConjugatePrior;

let flips = vec![true, false, false, false];
let xs: DataOrSuffStat<bool, Bernoulli> = DataOrSuffStat::Data(&flips);

Now we initialize the Beta prior and simply ask it for the posterior distribution given the observations.

let prior = Beta::jeffreys();
let posterior: Beta = prior.posterior(&xs);
println!("Coin weight Posterior: {}", posterior);

// The posterior mode is the maximum a posteriori (MAP) estimate
let map_weight: f64 = posterior.mode().unwrap();
println!("Coin weight MAP: {}", map_weight);

let var_weight: f64 = posterior.variance().unwrap();
println!("Coin weight variance: {}", var_weight);

// -- Output --
// Coin weight Posterior: Beta(α: 1.5, β: 3.5)
// Coin weight MAP: 0.16666666666666666
// Coin weight variance: 0.035
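The arithmetic behind this posterior is short enough to write out by hand. The sketch below uses only the standard library, and its type and function names are illustrative rather than rv's API. It reduces the flips to a sufficient statistic (the number of trials and the number of heads), then applies the textbook Beta-Bernoulli update: add the heads to alpha and the tails to beta.

```rust
/// Sufficient statistic for Bernoulli data: number of trials and successes.
#[derive(Debug, Clone, Copy, PartialEq)]
struct BernoulliStat {
    n: u64,
    k: u64,
}

impl BernoulliStat {
    /// Summarize raw flips; the raw data can be discarded afterwards.
    fn from_flips(flips: &[bool]) -> Self {
        let k = flips.iter().filter(|&&x| x).count() as u64;
        BernoulliStat { n: flips.len() as u64, k }
    }
}

/// Posterior Beta parameters given a Beta(alpha, beta) prior.
fn beta_posterior(alpha: f64, beta: f64, stat: BernoulliStat) -> (f64, f64) {
    let heads = stat.k as f64;
    let tails = (stat.n - stat.k) as f64;
    (alpha + heads, beta + tails)
}

/// Mode of Beta(alpha, beta); only an interior mode (alpha, beta > 1) is handled.
fn beta_mode(alpha: f64, beta: f64) -> Option<f64> {
    if alpha > 1.0 && beta > 1.0 {
        Some((alpha - 1.0) / (alpha + beta - 2.0))
    } else {
        None
    }
}

/// Variance of Beta(alpha, beta).
fn beta_variance(alpha: f64, beta: f64) -> f64 {
    let s = alpha + beta;
    alpha * beta / (s * s * (s + 1.0))
}
```

With the Jeffreys prior Beta(0.5, 0.5) and one head in four flips, beta_posterior gives (1.5, 3.5), and the mode and variance match the output above up to floating-point rounding.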

Inference on Custom Types

Some distributions, namely those with support for categorical data, support user-defined data types by way of traits that convert those types to and from a primitive data type. For example, instead of representing coin flips as bool, we can go ahead and define a Coin enum. We will need to implement the Booleable trait, which requires that our type implements Copy.

use rv::data::Booleable;

// Explicitly represent each variant as a u8
#[repr(u8)]
#[derive(Clone, Copy, PartialEq)]
enum Coin {
    Tails,
    Heads
}

impl Booleable for Coin {
    fn from_bool(heads: bool) -> Coin {
        if heads {
            Coin::Heads
        } else {
            Coin::Tails
        }
    }

    fn try_into_bool(self) -> Option<bool> {
        Some(self == Coin::Heads)
    }
}

Now we can rewrite the above example more explicitly using Coin:

let flips = vec![Coin::Heads, Coin::Tails, Coin::Tails, Coin::Tails];

let prior = Beta::jeffreys();

let xs: DataOrSuffStat<Coin, Bernoulli> = DataOrSuffStat::Data(&flips);

let posterior: Beta = prior.posterior(&xs);
println!("Coin weight Posterior: {}", posterior);

// The posterior mode is the maximum a posteriori (MAP) estimate
let map_weight: f64 = posterior.mode().unwrap();
println!("Coin weight MAP: {}", map_weight);

let var_weight: f64 = posterior.variance().unwrap();
println!("Coin weight variance: {}", var_weight);

// -- Output --
// Coin weight Posterior: Beta(α: 1.5, β: 3.5)
// Coin weight MAP: 0.16666666666666666
// Coin weight variance: 0.035
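To see the other side of this arrangement, here is a rough sketch of how generic code can consume any type that knows how to turn itself into a bool. The BoolLike trait and count_flips function are hypothetical stand-ins written for this illustration; they are not rv's Booleable trait or its internals.

```rust
/// Hypothetical conversion trait, standing in for something like Booleable.
trait BoolLike: Copy {
    /// Returns None when the value has no boolean interpretation.
    fn try_into_bool(self) -> Option<bool>;
}

#[derive(Clone, Copy, PartialEq, Debug)]
enum Coin {
    Tails,
    Heads,
}

impl BoolLike for Coin {
    fn try_into_bool(self) -> Option<bool> {
        Some(self == Coin::Heads)
    }
}

impl BoolLike for u8 {
    fn try_into_bool(self) -> Option<bool> {
        match self {
            0 => Some(false),
            1 => Some(true),
            _ => None,
        }
    }
}

/// Counts (heads, tails) for any convertible type.
/// Returns None if any value cannot be converted.
fn count_flips<T: BoolLike>(xs: &[T]) -> Option<(usize, usize)> {
    let mut heads = 0;
    let mut tails = 0;
    for &x in xs {
        if x.try_into_bool()? {
            heads += 1
        } else {
            tails += 1
        }
    }
    Some((heads, tails))
}
```

The counting code is written once against the trait, so a Coin, a u8, or any other type with a sensible conversion can feed the same inference machinery.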

Other Stuff

Information theory

rv also provides information-theoretic quantities:

use rv::traits::{Entropy, KlDivergence};

// This is continuous (differential) entropy
let h = gauss.entropy();
println!("{} entropy: {}", gauss, h);

// KL Divergence
let other = Gaussian::new(2.0, 1.0).unwrap();
let kl = gauss.kl(&other);
println!("KL({} | {}): {}", gauss, other, kl);

// -- Output --
// N(μ: 0, σ: 1) entropy: 1.4189385332046727
// KL(N(μ: 0, σ: 1) | N(μ: 2, σ: 1)): 2
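For Gaussians, both quantities have closed forms, so the numbers above are easy to check by hand. The following is a standalone sketch with illustrative function names, not rv's implementation. Note that KL divergence is not symmetric, so the order of the two distributions matters.

```rust
use std::f64::consts::{E, PI};

/// Differential entropy of Normal(mu, sigma): 0.5 * ln(2 * pi * e * sigma^2).
/// It does not depend on the mean.
fn gaussian_entropy(sigma: f64) -> f64 {
    0.5 * (2.0 * PI * E * sigma * sigma).ln()
}

/// KL divergence from p = Normal(mu_p, sigma_p) to q = Normal(mu_q, sigma_q):
/// ln(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 0.5
fn gaussian_kl(mu_p: f64, sigma_p: f64, mu_q: f64, sigma_q: f64) -> f64 {
    let dm = mu_p - mu_q;
    (sigma_q / sigma_p).ln()
        + (sigma_p * sigma_p + dm * dm) / (2.0 * sigma_q * sigma_q)
        - 0.5
}
```

For the two distributions above, these formulas give an entropy of about 1.4189 and a KL divergence of 2, in line with the printed output.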

Errors

Constructors validate their inputs unless their names end in _unchecked.

use rv::dist::GaussianError;

match Gaussian::new(0.0, -3.0) {
    Err(GaussianError::SigmaTooLowError) => println!("Expected error"),
    Err(_) => panic!("Unexpected error"),
    Ok(_) => panic!("There should have been an error"),
}

// -- Output --
// Expected error

Unchecked constructors skip validation, so they are fast but will let anything through. If you can't guarantee that the inputs are valid, this can lead to panics or nonsensical results.

let gauss = Gaussian::new_unchecked(0.0, -3.0);

// this breaks because the log PDF needs ln(-3.0), and f64::ln returns NaN for
// negative inputs
let ln_fx = gauss.ln_f(&0.0);
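The checked and unchecked pattern is easy to reproduce in your own types. Below is a minimal sketch with hypothetical names (Normal and NormalError are made up for this example and are not rv's types). In this sketch, the unchecked path does not panic; it silently produces NaN, which is arguably worse because the failure shows up far from its cause.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum NormalError {
    MuNotFinite,
    SigmaNotFinite,
    SigmaTooLow,
}

#[derive(Debug, Clone, Copy, PartialEq)]
struct Normal {
    mu: f64,
    sigma: f64,
}

impl Normal {
    /// Validating constructor: rejects non-finite or non-positive parameters.
    fn new(mu: f64, sigma: f64) -> Result<Self, NormalError> {
        if !mu.is_finite() {
            Err(NormalError::MuNotFinite)
        } else if !sigma.is_finite() {
            Err(NormalError::SigmaNotFinite)
        } else if sigma <= 0.0 {
            Err(NormalError::SigmaTooLow)
        } else {
            Ok(Normal { mu, sigma })
        }
    }

    /// Skips validation; the caller is responsible for passing valid inputs.
    fn new_unchecked(mu: f64, sigma: f64) -> Self {
        Normal { mu, sigma }
    }

    /// Log density at x. With a negative sigma, sigma.ln() is NaN,
    /// so the whole result is NaN.
    fn ln_f(&self, x: f64) -> f64 {
        let z = (x - self.mu) / self.sigma;
        -self.sigma.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln() - 0.5 * z * z
    }
}
```

A reasonable default is to use the checked constructor at the boundaries of a program, where inputs come from users or files, and reserve the unchecked one for inner loops where the parameters are valid by construction.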

Building bigger things

These are all very simple examples, but they hopefully give you a sense of the building blocks of RV. In the examples directory, we have built an infinite Gaussian mixture model (an unsupervised clustering and density estimation algorithm) in about 100 lines of code.
