Skip to content

Commit

Permalink
Add Samples and Score types, allow args...
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Mar 12, 2024
1 parent b7fae57 commit 2926b0d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
41 changes: 40 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
# ColabMPNN

[![Build Status](https://github.com/anton083/ColabMPNN.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/anton083/ColabMPNN.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Build Status](https://github.com/MurrellGroup/ColabMPNN.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/ColabMPNN.jl/actions/workflows/CI.yml?query=branch%3Amain)

A Julia wrapper for ColabDesign's MPNN module using PyCall.

For more details, see the [original Python documentation](https://github.com/sokrypton/ColabDesign/blob/main/mpnn/README.md)

## Installation

Add ColabMPNN to your Julia environment in the REPL:
```
]add https://github.com/MurrellGroup/ColabMPNN.jl
```

## Usage

Create a model using the `mk_mpnn_model` function. See arguments in the [Python code](https://github.com/sokrypton/ColabDesign/blob/main/colabdesign/mpnn/model.py#L24).

```julia
mpnn_model = mk_mpnn_model()
```

Inputs are prepared to a model using, with the model as first argument. See

```julia
prep_inputs(mpnn_model, pdb_filename="example.pdb", chain="A")
```

Sample sequences using the `sample` function, or in parallel with `sample_parallel`, with the model as the first argument. These functions return a `Samples` instance.

```julia
samples = sample_parallel(mpnn_model)
```

The `Samples` type has the following fields:
- `seq::Vector{String}`
- `seqid::Vector{Float64}`
- `score::Vector{Float64}`
- `logits::Array{Float32, 3}`
- `decoding_order::Array{Int32, 3}`
- `S::Array{Float32, 3}`
40 changes: 34 additions & 6 deletions src/ColabMPNN.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module ColabMPNN

export mpnn, mk_mpnn_model, prep_inputs, sample, sample_parallel, score, get_unconditional_logits
export mpnn
export Samples, Score
export mk_mpnn_model, prep_inputs, sample, sample_parallel, score, get_unconditional_logits

import Pkg
using Conda, PyCall
Expand All @@ -20,15 +22,41 @@ function __init__()
copy!(mpnn, pyimport_conda("colabdesign.mpnn", "colabdesign"))
end

mk_mpnn_model(; kwargs...) = mpnn.mk_mpnn_model(; kwargs...)
struct Samples
seq::Vector{String}
seqid::Vector{Float64}
score::Vector{Float64}
logits::Array{Float32, 3}
decoding_order::Array{Int32, 3}
S::Array{Float32, 3}

prep_inputs(mpnn_model; kwargs...) = mpnn_model.prep_inputs(; kwargs...)
function Samples(samples::Dict{Any, Any})
new([samples[string(f)] for f in fieldnames(Samples)]...)
end
end

struct Score
seqid::Float64
score::Float64
logits::Array{Float32, 2}
decoding_order::Array{Int32, 1}
S::Array{Float32, 2}

function Score(scores::Dict{Any, Any})
new([scores[string(f)] for f in fieldnames(Score)]...)
end
end


mk_mpnn_model(args...; kwargs...) = mpnn.mk_mpnn_model(args...; kwargs...)

prep_inputs(mpnn_model, args...; kwargs...) = mpnn_model.prep_inputs(args...; kwargs...)

sample(mpnn_model; kwargs...) = mpnn_model.sample(; kwargs...)
sample(mpnn_model, args...; kwargs...) = Samples(mpnn_model.sample(args...; kwargs...))

sample_parallel(mpnn_model; kwargs...) = mpnn_model.sample_parallel(; kwargs...)
sample_parallel(mpnn_model, args...; kwargs...) = Samples(mpnn_model.sample_parallel(args...; kwargs...))

score(mpnn_model; kwargs...) = mpnn_model.score(; kwargs...)
score(mpnn_model, args...; kwargs...) = Score(mpnn_model.score(args...; kwargs...))

get_unconditional_logits(mpnn_model) = mpnn_model.get_unconditional_logits()

Expand Down

0 comments on commit 2926b0d

Please sign in to comment.