-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update truncated
and MvNormal
#325
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6459699
Update `truncated` and `MvNormal`
devmotion ab10ad3
Apply suggestions from code review
devmotion e81416e
Update tutorials/10-bayesian-differential-equations/10_bayesian-diffe…
devmotion c3ebb5b
Update Manifest.toml
devmotion cc144bf
Remove `Truncated`
devmotion 314252b
Remove use of `describe`
devmotion File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,12 @@ using MLDataUtils: shuffleobs, splitobs, rescale! | |
# Functionality for evaluating the model predictions. | ||
using Distances | ||
|
||
# Functionality for constructing arrays with identical elements efficiently. | ||
using FillArrays | ||
|
||
# Functionality for working with scaled identity matrices. | ||
using LinearAlgebra | ||
|
||
# Set a seed for reproducibility. | ||
using Random | ||
Random.seed!(0) | ||
|
@@ -93,45 +99,47 @@ $$ | |
|
||
where $\alpha$ is an intercept term common to all observations, $\boldsymbol{\beta}$ is a coefficient vector, $\boldsymbol{X_i}$ is the observed data for car $i$, and $\sigma^2$ is a common variance term. | ||
|
||
For $\sigma^2$, we assign a prior of `truncated(Normal(0, 100), 0, Inf)`. This is consistent with [Andrew Gelman's recommendations](http://www.stat.columbia.edu/%7Egelman/research/published/taumain.pdf) on noninformative priors for variance. The intercept term ($\alpha$) is assumed to be normally distributed with a mean of zero and a variance of three. This represents our assumptions that miles per gallon can be explained mostly by our assorted variables, but a high variance term indicates our uncertainty about that. Each coefficient is assumed to be normally distributed with a mean of zero and a variance of 10. We do not know that our coefficients are different from zero, and we don't know which ones are likely to be the most important, so the variance term is quite high. Lastly, each observation $y_i$ is distributed according to the calculated `mu` term given by $\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}$. | ||
For $\sigma^2$, we assign a prior of `truncated(Normal(0, 100); lower=0)`. | ||
This is consistent with [Andrew Gelman's recommendations](http://www.stat.columbia.edu/%7Egelman/research/published/taumain.pdf) on noninformative priors for variance. | ||
The intercept term ($\alpha$) is assumed to be normally distributed with a mean of zero and a variance of three. | ||
This represents our assumptions that miles per gallon can be explained mostly by our assorted variables, but a high variance term indicates our uncertainty about that. | ||
Each coefficient is assumed to be normally distributed with a mean of zero and a variance of 10. | ||
We do not know that our coefficients are different from zero, and we don't know which ones are likely to be the most important, so the variance term is quite high. | ||
Lastly, each observation $y_i$ is distributed according to the calculated `mu` term given by $\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}$. | ||
|
||
```julia | ||
# Bayesian linear regression. | ||
@model function linear_regression(x, y) | ||
# Set variance prior. | ||
σ₂ ~ truncated(Normal(0, 100), 0, Inf) | ||
σ² ~ truncated(Normal(0, 100); lower=0) | ||
|
||
# Set intercept prior. | ||
intercept ~ Normal(0, sqrt(3)) | ||
|
||
# Set the priors on our coefficients. | ||
nfeatures = size(x, 2) | ||
coefficients ~ MvNormal(nfeatures, sqrt(10)) | ||
coefficients ~ MvNormal(Zeros(nfeatures), 10.0 * I) | ||
|
||
# Calculate all the mu terms. | ||
mu = intercept .+ x * coefficients | ||
return y ~ MvNormal(mu, sqrt(σ₂)) | ||
return y ~ MvNormal(mu, σ² * I) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is nicer than the old syntax. |
||
end | ||
``` | ||
|
||
With our model specified, we can call the sampler. We will use the No U-Turn Sampler ([NUTS](http://turing.ml/docs/library/#-turingnuts--type)) here. | ||
|
||
```julia | ||
model = linear_regression(train, train_target) | ||
chain = sample(model, NUTS(0.65), 3_000); | ||
chain = sample(model, NUTS(0.65), 3_000) | ||
``` | ||
|
||
As a visual check to confirm that our coefficients have converged, we show the densities and trace plots for our parameters using the `plot` functionality. | ||
We can also check the densities and traces of the parameters visually using the `plot` functionality. | ||
|
||
```julia | ||
plot(chain) | ||
``` | ||
|
||
It looks like each of our parameters has converged. We can check our numerical esimates using `describe(chain)`, as below. | ||
|
||
```julia | ||
describe(chain) | ||
``` | ||
It looks like all parameters have converged. | ||
|
||
## Comparing to OLS | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any reason why we are switching styles for imports?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no consistent style currently (and also not when this PR is merged), I mainly felt that the diff is simpler when adding or removing packages if they are on separate lines.