TRANSCRIPT
Diamond: Mixed Effects Models in Python
Timothy Sweetser
Stitch Fix
http://github.com/stitchfix/diamond
@hacktuarial
November 27, 2017
Timothy Sweetser (Stitch Fix) Diamond November 27, 2017 1 / 32
Overview
1 context and motivation
2 what is the mixed effects model
3 application to recommender systems
4 computation
5 diamond
6 appendix
context and motivation
Stitch Fix
what is the mixed effects model
Refresher: Linear Model
y ∼ N(Xβ, σ²I)
y is n × 1
X is n × p
β is an unknown vector of length p
σ² is an unknown, nonnegative constant
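As a concrete illustration (all values invented, not from the talk), this model can be simulated and fit with plain NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# design matrix: intercept column plus one covariate
X = np.column_stack([np.ones(n), rng.normal(size=n)])
beta_true = np.array([5.0, 0.4])  # unknown in practice
sigma = 1.0                       # unknown, nonnegative

# y ~ N(X beta, sigma^2 I)
y = X @ beta_true + rng.normal(0.0, sigma, size=n)

# ordinary least squares recovers beta
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
```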
Mixed Effects Model
y | b ∼ N(Xβ + Zb, σ²I)
We have a second set of features, Z, which is n × q
the coefficients on Z are b ∼ N(0, Σ)
Σ is q × q
simple example of a mixed effects model
You think there is some relationship between a woman’s height and the ideal length of jeans for her:
length = α + β · height + ε
But you think the length might need to be shorter or longer, depending on the silhouette of the jeans. In other words, you want α to vary by silhouette.
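A minimal simulation of this jeans example, assuming made-up values for α, β, and the variance components. Each silhouette gets its own intercept deviation, and (given the true slope) the group mean of length − β · height recovers that silhouette's intercept:

```python
import numpy as np

rng = np.random.default_rng(1)
silhouettes = ["bootcut", "skinny", "straight", "wide"]
alpha, beta = 20.0, 0.15     # shared intercept and slope (made up)
sigma_b, sigma_e = 2.0, 0.5  # random-intercept and noise scales

# one intercept deviation per silhouette, drawn from N(0, sigma_b^2)
b = {s: rng.normal(0.0, sigma_b) for s in silhouettes}

n = 2000
height = rng.normal(65.0, 3.0, size=n)
group = rng.choice(silhouettes, size=n)
length = (alpha
          + np.array([b[g] for g in group])
          + beta * height
          + rng.normal(0.0, sigma_e, size=n))

# with the true slope known, each silhouette's intercept is just the
# group mean of (length - beta * height)
est = {s: float(np.mean(length[group == s] - beta * height[group == s]))
       for s in silhouettes}
```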
why might silhouette affect length ∼ height?
[Figure: skinny vs. bootcut jeans]
linear model: formula
Linear models can be expressed in formula notation, used by patsy, statsmodels, and R:

import statsmodels.formula.api as smf
lm = smf.ols('length ~ 1 + height', data=train_df).fit()

in math, this means length = Xβ + ε
X_i = [1.0, 64.0]
β is what we want to learn, using (customer, item) data from jeans that fit well
linear model: illustration
mixed effects: formula
Now, allow the intercept to vary by silhouette:

mix = smf.mixedlm('length ~ 1 + height',
                  data=train_df,
                  re_formula='1',
                  groups='silhouette',
                  use_sparse=True).fit()
illustration
mixed effects regularization
y | b ∼ N(Xβ + Zb, σ²I)
Sort the rows by silhouette:

    Z = [ 1_bootcut   0          0            0
          0           1_skinny   0            0
          0           0          1_straight   0
          0           0          0            1_wide ]

where 1_s denotes a column of ones for the observations in silhouette s.
X is n × 2
Z is n × 4
matrices and formulas - mixed effects
Zb = Z · [µ_bootcut, µ_skinny, µ_straight, µ_wide]ᵀ
with Z the indicator matrix from the previous slide, so each observation picks up the µ for its silhouette.
Each µ_silhouette is drawn from N(0, σ²)
This allows for deviations from the average effects, µ and β, by silhouette, to the extent that the data support it
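Under the same four-silhouette setup, the indicator matrix Z can be built like this (toy labels, not Stitch Fix data):

```python
import numpy as np

# toy silhouette labels for six observations
groups = np.array(["bootcut", "skinny", "skinny",
                   "straight", "wide", "bootcut"])
levels, idx = np.unique(groups, return_inverse=True)

# Z[i, j] = 1 if observation i belongs to silhouette j, else 0
Z = np.zeros((len(groups), len(levels)))
Z[np.arange(len(groups)), idx] = 1.0
```

Each row of Z has exactly one 1, and each column sums to the number of observations in that silhouette.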
application to recommender systems
a basic model
rating ∼ 1 + (1|user id) + (1|item id)
In math, this means r_ui = µ + α_u + β_i + ε_ui
where
µ is an unknown constant
α_u ∼ N(0, σ²_user)
β_i ∼ N(0, σ²_item)
some items are more popular than others
some users are more picky than others
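A quick NumPy sketch of generating ratings from this model (all sizes and variances invented):

```python
import numpy as np

rng = np.random.default_rng(2)
n_users, n_items, n_obs = 200, 100, 20_000
mu = 3.5                                   # unknown constant
sigma_user, sigma_item, sigma = 0.5, 0.8, 0.3

alpha = rng.normal(0.0, sigma_user, size=n_users)  # user effects
beta = rng.normal(0.0, sigma_item, size=n_items)   # item effects

# each observation is a (user, item) pair with noise
u = rng.integers(0, n_users, size=n_obs)
i = rng.integers(0, n_items, size=n_obs)
r = mu + alpha[u] + beta[i] + rng.normal(0.0, sigma, size=n_obs)
```

Because the user and item effects are mean-zero, the grand mean of the ratings estimates µ.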
add features
rating ∼ 1 + (1 + item_feature1 + item_feature2 | user_id) +
         (1 + user_feature1 + user_feature2 | item_id)
Now,
α_u ∼ N(0, Σ_user)
β_i ∼ N(0, Σ_item)
the good: we’re using features! learn individual and shared preferences
helps with new items, new users
the bad: scales as O(p²)
comments
rating ∼ 1 + (1 + item_feature1 + item_feature2 | user_id) +
         (1 + user_feature1 + user_feature2 | item_id)
this is a parametric model, and much less flexible than trees, neural networks, or matrix factorization
but you don’t have to choose!
you can use an ensemble, or use this as a feature in another model
computation
computation
How can you fit models like this? We were using R’s lme4 package.
Maximum likelihood computation works like this:
1 estimate the covariance structure of the random effects, Σ
2 given Σ, estimate the coefficients β and b
3 with these, compute the log-likelihood
4 repeat until convergence
Doesn’t scale well with the number of observations, n
lme4 supports a variety of generalized linear models, but is not optimized for any one in particular
Is it really necessary to update the hyperparameters Σ every time you estimate the coefficients?
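For intuition: in the linear-Gaussian case, the middle step — "given Σ, estimate β and b" — reduces to a single penalized least-squares solve. A sketch of that step (my own function, not lme4's or diamond's internals):

```python
import numpy as np

def fit_given_sigma(X, Z, y, Sigma, sigma2):
    """With Sigma fixed, (beta, b) jointly minimize
    ||y - X beta - Z b||^2 + sigma2 * b' Sigma^{-1} b,
    so one linear solve gives both coefficient sets."""
    p, q = X.shape[1], Z.shape[1]
    W = np.hstack([X, Z])
    P = np.zeros((p + q, p + q))
    P[p:, p:] = sigma2 * np.linalg.inv(Sigma)  # penalize only b
    theta = np.linalg.solve(W.T @ W + P, W.T @ y)
    return theta[:p], theta[p:]
```

The outer loop then updates Σ from the fitted b and repeats; the question on the slide is whether that update is needed on every pass.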
diamond
Diamond solves a similar problem using these tricks:
Input Σ. Conditional on Σ, the optimization problem is convex
Use the Hessian of the L2-penalized log-likelihood function (pencil + paper) for:
logistic regression
cumulative logistic regression, for ordinal responses: if Y ∈ {1, 2, 3, . . . , J},

    log( Pr(Y ≤ j) / (1 − Pr(Y ≤ j)) ) = α_j + βᵀx,  for j = 1, 2, . . . , J − 1
quasi-Newton optimization techniques from Minka 2003
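A toy implementation of that cumulative-logistic link (illustrative only, not diamond's internal code); the cutpoints α_j must be increasing so the implied CDF is monotone:

```python
import numpy as np

def ordinal_probs(x, alphas, beta):
    """Cumulative-logit model from the formula above:
    log odds of Pr(Y <= j) is alpha_j + beta'x, for j = 1..J-1.
    Returns the J cell probabilities Pr(Y = 1), ..., Pr(Y = J)."""
    eta = np.asarray(alphas) + np.dot(x, beta)  # one cutpoint per j
    cdf = 1.0 / (1.0 + np.exp(-eta))            # Pr(Y <= j)
    cdf = np.concatenate([[0.0], cdf, [1.0]])
    return np.diff(cdf)                         # Pr(Y = j)
```

Differencing the J − 1 cumulative probabilities (padded with 0 and 1) yields a proper probability vector over the J ordered levels.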
other solvers
How else could you fit mixed effects models?
“Exact” methods:
Full Bayes: MCMC, e.g. PyStan, PyMC3, Edward
diamond, but you must specify the hyperparameters Σ
statsmodels, which only supports linear regression for Gaussian-distributed outcomes
R/lme4
Approximate methods:
simple, global L2 regularization
Full Bayes: variational inference
moment-based methods
diamond
Speed test
MovieLens, 20M observations like (userId, movieId, rating)
binarize the (ordinal!) rating: liked = 1(rating > 3.5)
the resulting binary outcome is well-balanced
Fit a model like
rating ∼ 1 + (1|user id) + (1|item id)
diamond
from diamond.glms.logistic import LogisticRegression
import numpy as np
import pandas as pd

train_df = ...
priors_df = pd.DataFrame({
    'group': ['userId', 'movieId'],
    'var1': ['intercept'] * 2,
    'var2': [np.nan, np.nan],
    'vcov': [0.9, 1.0]
})
m = LogisticRegression(train_df=train_df, priors_df=priors_df)
results = m.fit('liked ~ 1 + (1 | userId) + (1 | movieId)',
                tol=1e-5, max_its=200, verbose=True)
Speed test vs. sklearn
Diamond
estimate covariance on a sample of 1M observations in R. One-time cost: 60 minutes
σ²_user = 0.9, σ²_movie = 1.0
takes 83 minutes on my laptop to fit in diamond
sklearn LogisticRegression
use cross-validation to estimate regularization. One-time cost: 24 minutes
grid search would be a fairer comparison
refit takes 1 minute
diamond vs. sklearn predictions
Global L2 regularization is a good approximation for this problem, but may not work as well when σ²_user >> σ²_item (or vice versa), or for models with more features.
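The "simple, global L2" baseline being compared here applies one shared ridge penalty λ to every coefficient, rather than diamond's per-group Σ⁻¹ penalty. A minimal NumPy sketch (plain gradient descent, not sklearn's actual solver):

```python
import numpy as np

def fit_l2_logistic(X, y, lam=1.0, lr=0.5, iters=2000):
    """L2-regularized logistic regression by gradient descent.
    Every coefficient gets the same penalty lam -- the 'global L2'
    approximation, as opposed to a per-group covariance penalty."""
    n, p = X.shape
    w = np.zeros(p)
    for _ in range(iters):
        p_hat = 1.0 / (1.0 + np.exp(-(X @ w)))       # predicted probs
        grad = X.T @ (p_hat - y) / n + lam * w / n   # loss + ridge grad
        w -= lr * grad
    return w
```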
diamond vs. R
lme4 takes more than 360 minutes to fit
diamond vs. moment-based
active area of research by statisticians at Stanford, NYU, elsewhere
very fast to fit simple models using method of moments
e.g. rating ∼ 1 + (1 + x | user_id)
or rating ∼ 1 + (1 | user_id) + (1 | item_id)
Fitting this to MovieLens 20M took 4 minutes
but not rating ∼ 1 + (1 + x | user_id) + (1 | item_id)
diamond vs. variational inference
I fit this model in under 5 minutes using Edward, and didn’t have to input Σ.
VI is very promising!
why use diamond?
http://github.com/stitchfix/diamond
scales well with number of observations (compared to pure R, MCMC)
solves the exact problem (compared to variational, moment-based)
scales OK with p (compared to simple, global L2)
supports ordinal logistic regression: if Y ∈ {1, 2, 3, . . . , J},

    log( Pr(Y ≤ j) / (1 − Pr(Y ≤ j)) ) = α_j + βᵀx,  for j = 1, 2, . . . , J − 1

Reference: Agresti, Categorical Data Analysis
summary
mixed effects models are useful for recommender systems and other data science applications
they can be hard to fit for large datasets
they play well with other kinds of models
diamond, moment-based approaches, and variational inference are good ways to estimate models quickly
discussion
References I
Patrick Perry (2015)
Moment Based Estimation for Hierarchical Models
https://arxiv.org/abs/1504.04941
Alan Agresti (2012)
Categorical Data Analysis, 3rd Ed.
ISBN-13 978-0470463635
Gao + Owen (2016)
Estimation and Inference for Very Large Linear Mixed Effects Models
https://arxiv.org/abs/1610.08088
Edward
A Library for probabilistic modeling, inference, and criticism.
https://github.com/blei-lab/edward
References II
Minka (2003)
A comparison of numerical optimizers for logistic regression
https://tminka.github.io/papers/logreg/minka-logreg.pdf
lme4
https://cran.r-project.org/web/packages/lme4/vignettes/lmer.pdf
appendix
regularization
Usual L2 regularization: if each β_i ∼ N(0, 1/λ),

    minimize_β  loss + (1/2) βᵀ (λ I_p) β

Here, the four b coefficient vectors are samples from N(0, Σ). If we knew Σ, the regularization would be

    minimize_b  loss + (1/2) bᵀ blockdiag(Σ⁻¹, Σ⁻¹, Σ⁻¹, Σ⁻¹) b
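Assuming a 2-dimensional random effect per group and four groups (numbers invented), that block-diagonal penalty matrix can be assembled with a Kronecker product:

```python
import numpy as np

# a hypothetical 2x2 random-effect covariance, shared by 4 groups
Sigma = np.array([[1.0, 0.2],
                  [0.2, 0.5]])
Sigma_inv = np.linalg.inv(Sigma)

# block-diagonal penalty: one Sigma^{-1} block per group
P = np.kron(np.eye(4), Sigma_inv)
```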