Merge pull request #267 from FluxML/entity-embeddings
Introduce EntityEmbeddings
ablaom authored Sep 9, 2024
2 parents 1e41256 + 310cb12 commit 945016d
Showing 19 changed files with 1,539 additions and 350 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJFlux"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Ayush Shridhar <ayush.shridhar1999@gmail.com>"]
version = "0.5.1"
version = "0.6.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
11 changes: 6 additions & 5 deletions src/MLJFlux.jl
@@ -1,7 +1,6 @@
module MLJFlux

-export CUDALibs, CPU1

import Flux
using MLJModelInterface
using MLJModelInterface.ScientificTypesBase
Expand All @@ -17,22 +16,24 @@ import Metalhead
import Optimisers

include("utilities.jl")
-const MMI=MLJModelInterface
+const MMI = MLJModelInterface

include("encoders.jl")
include("entity_embedding.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
include("core.jl")
include("regressor.jl")
include("classifier.jl")
include("image.jl")
include("fit_utils.jl")
include("entity_embedding_utils.jl")
include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
+export CUDALibs, CPU1

include("deprecated.jl")


-end #module
+end # module
30 changes: 18 additions & 12 deletions src/classifier.jl
@@ -5,7 +5,6 @@
A private method that returns the shape of the input and output of the model for given
data `X` and `y`.
"""
function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
X = X isa Matrix ? Tables.table(X) : X
Expand All @@ -14,6 +13,7 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, n_output)
end
+is_embedding_enabled(::NeuralNetworkClassifier) = true

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(
Expand All @@ -29,24 +29,28 @@ MLJFlux.fitresult(
model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
chain,
y,
-) = (chain, MLJModelInterface.classes(y[1]))
+ordinal_mappings = nothing,
+embedding_matrices = nothing,
+) = (chain, MLJModelInterface.classes(y[1]), ordinal_mappings, embedding_matrices)

function MLJModelInterface.predict(
model::NeuralNetworkClassifier,
fitresult,
Xnew,
)
-chain, levels = fitresult
+chain, levels, ordinal_mappings, _ = fitresult
+Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) # TODO: handle the case where `Xnew` is a matrix
X = reformat(Xnew)
probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
return MLJModelInterface.UnivariateFinite(levels, probs)
end


MLJModelInterface.metadata_model(
NeuralNetworkClassifier,
-input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
-target_scitype=AbstractVector{<:Finite},
-load_path="MLJFlux.NeuralNetworkClassifier",
+input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)},
+target_scitype = AbstractVector{<:Finite},
+load_path = "MLJFlux.NeuralNetworkClassifier",
)

#### Binary Classifier
Expand All @@ -56,21 +60,23 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, 1) # n_output is always 1 for a binary classifier
end
+is_embedding_enabled(::NeuralNetworkBinaryClassifier) = true

function MLJModelInterface.predict(
model::NeuralNetworkBinaryClassifier,
fitresult,
Xnew,
)
-chain, levels = fitresult
+chain, levels, ordinal_mappings, _ = fitresult
+Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
X = reformat(Xnew)
probs = vec(chain(X))
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end

MLJModelInterface.metadata_model(
NeuralNetworkBinaryClassifier,
-input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
-target_scitype=AbstractVector{<:Finite{2}},
-load_path="MLJFlux.NeuralNetworkBinaryClassifier",
+input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)},
+target_scitype = AbstractVector{<:Finite{2}},
+load_path = "MLJFlux.NeuralNetworkBinaryClassifier",
)
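
For orientation, a minimal usage sketch (not part of this diff) of what the widened `input_scitype` permits: fitting a classifier on a table that mixes `Continuous` and `Finite` columns. The data and column names here are hypothetical, and the `embedding_dims` keyword mentioned in the comment is assumed from elsewhere in this PR.

using MLJ, CategoricalArrays
import MLJFlux

# hypothetical data: one continuous and one categorical (Finite) feature
X = (duration = rand(100), device = categorical(rand(["phone", "laptop"], 100)))
y = categorical(rand(["yes", "no"], 100))

clf = MLJFlux.NeuralNetworkClassifier(epochs = 5)  # embedding_dims = Dict(:device => 2) is assumed to control embedding sizes
mach = machine(clf, X, y)
fit!(mach)
yhat = predict(mach, X)  # vector of UnivariateFinite distributions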
8 changes: 8 additions & 0 deletions src/core.jl
@@ -24,6 +24,8 @@ end
y,
) -> updated_chain, updated_optimiser_state, training_loss
+**Private method.**
Update the parameters of a Flux `chain`, where:
- `model` is typically an `MLJFluxModel` instance, but could be any object such that
Expand Down Expand Up @@ -77,6 +79,8 @@ end
y,
) -> (updated_chain, updated_optimiser_state, history)
+**Private method.**
Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function
inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it
could be any object such that `model.loss` is a Flux.jl loss function.
Expand Down Expand Up @@ -162,6 +166,8 @@ end
"""
gpu_isdead()
+**Private method.**
Returns `true` if the `acceleration=CUDALibs()` option is unavailable, and
`false` otherwise.
Expand All @@ -171,6 +177,8 @@ gpu_isdead() = Flux.gpu([1.0,]) isa Array
"""
nrows(X)
+**Private method.**
Find the number of rows of `X`, where `X` is an `AbstractVector` or a
Tables.jl table.
"""
152 changes: 152 additions & 0 deletions src/encoders.jl
@@ -0,0 +1,152 @@
"""
This file contains the ordinal encoder and the entity embedding encoder. It borrows code from the MLJTransforms package.
"""

### Ordinal Encoder
"""
**Private method.**
Fits an ordinal encoder to the table `X`, using only the columns with indices in `featinds`.
Returns a dictionary mapping each column index to a dictionary mapping each level in that column to an integer.
"""
function ordinal_encoder_fit(X; featinds)
# 1. Define mapping per column per level dictionary
mapping_matrix = Dict()

# 2. Use feature mapper to compute the mapping of each level in each column
for i in featinds
feat_col = Tables.getcolumn(Tables.columns(X), i)
feat_levels = levels(feat_col)
# Skip if the feature levels are already ordinally encoded
(Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
# Compute the dict using the given feature_mapper function
mapping_matrix[i] =
Dict{Any, AbstractFloat}(
value => float(index) for (index, value) in enumerate(feat_levels)
)
end
return mapping_matrix
end
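
A minimal usage sketch (not part of the commit), assuming a toy table whose second column is categorical:

using Tables, CategoricalArrays

Xtoy = (a = [1.0, 2.0, 3.0], b = categorical(["lo", "hi", "lo"]))
mapping = ordinal_encoder_fit(Xtoy; featinds = [2])
# mapping[2] == Dict("hi" => 1.0, "lo" => 2.0), since levels(Xtoy.b) == ["hi", "lo"]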

"""
**Private method.**
Checks that all levels in `test_levels` are also in `train_levels`. If not, throws an error.
"""
function check_unknown_levels(train_levels, test_levels)
# test levels must be a subset of train levels
if !issubset(test_levels, train_levels)
# get the levels in test that are not in train
lost_levels = setdiff(test_levels, train_levels)
error(
"While transforming, found novel levels for the column: $(lost_levels) that were not seen while training.",
)
end
end
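
A one-line sketch of the failure mode this guards against:

check_unknown_levels(["a", "b"], ["a", "c"])  # errors: novel level "c" was not seen during training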

"""
**Private method.**
Transforms the table `X` using the ordinal encoder defined by `mapping_matrix`.
Returns a new table with the same column names as `X`, but with categorical columns replaced by integer columns.
"""
function ordinal_encoder_transform(X, mapping_matrix)
isnothing(mapping_matrix) && return X
isempty(mapping_matrix) && return X
feat_names = Tables.schema(X).names
numfeats = length(feat_names)
new_feats = []
for ind in 1:numfeats
col = Tables.getcolumn(Tables.columns(X), ind)

# Create the transformation function for each column
if ind in keys(mapping_matrix)
train_levels = keys(mapping_matrix[ind])
test_levels = levels(col)
check_unknown_levels(train_levels, test_levels)
level2scalar = mapping_matrix[ind]
new_col = recode(col, level2scalar...)
push!(new_feats, new_col)
else
push!(new_feats, col)
end
end

transformed_X = NamedTuple{tuple(feat_names...)}(tuple(new_feats...))
# Attempt to preserve table type
transformed_X = Tables.materializer(X)(transformed_X)
return transformed_X
end

"""
**Private method.**
Combine `ordinal_encoder_fit` and `ordinal_encoder_transform`, returning both the transformed `X` and the fitted `ordinal_mappings`.
"""
function ordinal_encoder_fit_transform(X; featinds)
ordinal_mappings = ordinal_encoder_fit(X; featinds = featinds)
return ordinal_encoder_transform(X, ordinal_mappings), ordinal_mappings
end
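
Continuing the earlier sketch, the combined helper fits and transforms in one call:

Xenc, mappings = ordinal_encoder_fit_transform(Xtoy; featinds = [2])
# column :b of Xenc now holds the codes 2.0, 1.0, 2.0 ("lo" => 2.0, "hi" => 1.0)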



## Entity Embedding Encoder (assuming precomputed weights)
"""
**Private method.**
Function to generate new feature names: feat_name_1, feat_name_2, ..., feat_name_n
"""
function generate_new_feat_names(feat_name, num_inds, existing_names)
conflict = true # remains true while any proposed name already exists
count = 1 # number of underscores in the suffix; grows by one per conflict

new_column_names = []
while conflict
suffix = repeat("_", count)
new_column_names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds]
conflict = any(name -> name in existing_names, new_column_names)
count += 1
end
return new_column_names
end
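
A sketch of the conflict handling: each collision with an existing name adds one more underscore to the suffix.

generate_new_feat_names(:age, 2, (:age_1, :age_2))
# returns [Symbol("age__1"), Symbol("age__2")], since :age_1 and :age_2 are taken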


"""
**Private method.**
Given `X` and a dict `mapping_matrices` that maps each categorical column name to a matrix, transform
each level in each categorical column into the corresponding column of that matrix.
This is used with the embedding matrices of the entity embedding layer in entity-enabled models to implement entity embeddings.
"""
function embedding_transform(X, mapping_matrices)
(isempty(mapping_matrices)) && return X
feat_names = Tables.schema(X).names
new_feat_names = Symbol[]
new_cols = []
for feat_name in feat_names
col = Tables.getcolumn(Tables.columns(X), feat_name)
# Create the transformation function for each column
if feat_name in keys(mapping_matrices)
level2vector = mapping_matrices[feat_name]
new_multi_col = map(x -> level2vector[:, Int.(unwrap(x))], col)
new_multi_col = [row for row in eachrow(hcat(new_multi_col...))]
push!(new_cols, new_multi_col...)
feat_names_with_inds = generate_new_feat_names(
feat_name,
size(level2vector, 1),
feat_names,
)
push!(new_feat_names, feat_names_with_inds...)
else
# Not to be transformed => left as is
push!(new_feat_names, feat_name)
push!(new_cols, col)
end
end

transformed_X = NamedTuple{tuple(new_feat_names...)}(tuple(new_cols...))
# Attempt to preserve table type
transformed_X = Tables.materializer(X)(transformed_X)
return transformed_X
end
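
A minimal sketch with hypothetical data and matrix: each level's ordinal code selects a column of its embedding matrix, so one categorical column expands into size(matrix, 1) numeric columns.

using Tables, CategoricalArrays

X = (b = categorical([1.0, 2.0, 1.0]),)      # levels already ordinal-encoded as 1.0, 2.0
embeddings = Dict(:b => [0.1 0.9; 0.2 0.8])  # 2-dimensional embedding; one matrix column per level
Xemb = embedding_transform(X, embeddings)
# -> (b_1 = [0.1, 0.9, 0.1], b_2 = [0.2, 0.8, 0.2])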
[Diff truncated: remaining 14 changed files not shown.]
