Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: aggregation of SurvSHAP values across multiple observation #74

Merged

Conversation

kapsner
Copy link
Contributor

@kapsner kapsner commented Apr 5, 2023

First of all, thanks for this helpful package and the inclusion of survshap calculations.

This PR introduces the computation of global SurvSHAP values that are aggregated over multiple observations by averaging. If I have not mistaken something, this was also suggested by @krzyzinskim et. al in the discussion of their publication on survshap where they wrote: "Future works should consider the possibility of aggregating the SurvSHAP(t) function across data distribution to introduce global explanations of machine learning survival models."

The implementation is as follows:

To facilitate calculations, I use the data.table R Package and have added it to the DESCRIPTION file.

Furthermore, some enhancements regarding error handling were made for the surv_shap function: kapsner@00d6a65#diff-28ac05366e312a7571ab5c890e7360ea1534cc403cbb14f37624f20a925315c6R27

Finally, a unit test has been added to verify that the code is working: kapsner@00d6a65#diff-98f5248ab657edc483b0a9eba0a12ef10e595ecaa84bf21399f5a68d746aa103R60

Looking forward to hearing from you.

Best,
Lorenz

@mikolajsp
Copy link
Collaborator

Wow, this looks amazing, thank you so much for the contribution @kapsner

I'm going to investigate and review this in detail tomorrow!

@mikolajsp
Copy link
Collaborator

I've also seen that you opened an issue about integrating survival forests at our treeshap repository.

I think that treeshap + aggregation of shap explanations has much potential so we'll try to incorporate these changes into both packages!

@kapsner
Copy link
Contributor Author

kapsner commented Apr 5, 2023

Hi @mikolajsp , Thanks for you quick reply.

Regarding treeshap:

Yes, I am already working on adapting treeshap to work with survival functions as a preparation for integration into survex.

The treeshap-part already works quite well: kapsner/treeshap@7d28afe...6b2aa13#diff-ef4d0c389ad9619a8514528ddc4a14db7e270dcf6363b0e4c9e63a06131d171dR61 (despite the observed differences between original values and predicted values in the unit tests).

Using this adaptions, I have also begun to work on the treeshap-integration into survex here: kapsner/survex@f17674a...bb47009#diff-28ac05366e312a7571ab5c890e7360ea1534cc403cbb14f37624f20a925315c6R216

I'd be really glad on feedback / help in this regard to implement everything statistically sound ;)

@mikolajsp
Copy link
Collaborator

Sure I'll be happy to help, so feel free to reach out with questions.

I will get back to you tomorrow regarding this pull request, and will also try to find out why the unified model predictions differ in treeshap 😁

@kapsner
Copy link
Contributor Author

kapsner commented Apr 11, 2023

Hi @mikolajsp ,
were you already able to have a look at the proposed adaptions regarding the computation of global SurvSHAP(t) values when providing multiple observations?

I have had also some thoughts on how to express this mathematically. I think, the part implemented here could be expressed in the following manner:

  • $i: {i = 1, 2, ..., n}$ be the $i$-th observation in the dataset.
  • $d_{j}: {j = 1, 2, ..., p}$ be the $j$-th feature.
  • $t \in {t_{1}, ..., t_{m}}$ be the times to the event of interest with $t_{i1} < t_{i2} < ... < t_{im}$.
  • $\phi_t(i_{*}, d_j)$ be the SurvSHAP(t) value of the $j$-th feature of observation $i_{*}$ at time point $t$.

Then global SurvSHAP(t) should be:

$$ \phi_{gt}(i_{*}, d_j) = \frac{1}{n} \sum_{i=1}^{n} \phi_{t} (i_{*},d_j) $$

Is this correct?

@mikolajsp
Copy link
Collaborator

Hi,

I've checked out your changes and they seem great, but I haven't had time to test them out myself, I would like to do some manual testing before I merge the PR.

And the math here looks all good to me!

I'll try to get to this PR and merge it hopefully this week

@kapsner
Copy link
Contributor Author

kapsner commented Apr 11, 2023

Perfect, thanks a lot for the feedback; and no hurry from my side!

Just for transparency, once the global SurvSHAP computation is validated and merged, I am also planning the following PRs in the future, which are:

@kapsner
Copy link
Contributor Author

kapsner commented May 21, 2023

Hi @mikolajsp , have you now already found some time to have a closer look a this PR? I would be happy to get some feedback, if these commits are likely to be merged or if I perhaps would need to do some further adjustments.

@mikolajsp
Copy link
Collaborator

Hello, sorry for the delay.

I think I'm going to merge the PR into a dev branch and make some modifications today.

The part you made looks good but I want to make the following modifications:

  • Make the surv_shap function return a matrix if multiple observations are provided i.e. not handle the aggregation part.
  • Create aggregate_shap function to make it possible to aggregate surv_shap profiles in different ways (e.g. mean, median etc.)
  • Create a specific plotting function to plot aggregated shap.
  • As aggregated shap is a global explanation, make it possible to call it from the model_parts function and safeguard the call from predict_parts so that it only allows one observation.

I'm going to ping you when they are ready to see what you think.

@mikolajsp mikolajsp changed the base branch from main to dev-global-survshap May 22, 2023 09:32
@mikolajsp mikolajsp merged commit 2843273 into ModelOriented:dev-global-survshap May 22, 2023
@kapsner
Copy link
Contributor Author

kapsner commented May 22, 2023

Thanks, sounds all good to me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants