Skip to content

Commit

Permalink
Merge pull request #65 from BerkeleyLab/trainable-engine-inference
Browse files Browse the repository at this point in the history
Make trainable_engine_t independent
  • Loading branch information
rouson authored Jun 30, 2023
2 parents 234d57f + 7995aca commit 5b8366e
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 601 deletions.
153 changes: 0 additions & 153 deletions example/train-and-gate.f90

This file was deleted.

112 changes: 0 additions & 112 deletions example/train-xor-gate.f90

This file was deleted.

59 changes: 35 additions & 24 deletions src/inference_engine/trainable_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
! Terms of use are as specified in LICENSE.txt
module trainable_engine_m
!! Define an abstraction that supports training a neural network
use inference_engine_m_, only : inference_engine_t
use inference_strategy_m, only : inference_strategy_t
use outputs_m, only : outputs_t
use differentiable_activation_strategy_m, only : differentiable_activation_strategy_t
use string_m, only : string_t
use kind_parameters_m, only : rkind
Expand All @@ -15,57 +15,68 @@ module trainable_engine_m
private
public :: trainable_engine_t

type, extends(inference_engine_t) :: trainable_engine_t
type trainable_engine_t
!! Encapsulate the information needed to perform training
private
real(rkind), allocatable :: w(:,:,:) ! weights
real(rkind), allocatable :: b(:,:) ! biases
integer, allocatable :: n(:) ! nuerons per layer
class(differentiable_activation_strategy_t), allocatable :: differentiable_activation_strategy_
contains
procedure :: train_single_hidden_layer
procedure :: train_deep_network
generic :: train => train_deep_network, train_single_hidden_layer
procedure :: assert_consistent
procedure :: train
procedure :: infer
procedure :: num_layers
procedure :: num_inputs
end type

interface trainable_engine_t
integer, parameter :: input_layer = 0

pure module function construct_trainable_engine( &
metadata, input_weights, hidden_weights, output_weights, biases, output_biases, differentiable_activation_strategy &
) &
result(trainable_engine)
implicit none
type(string_t), intent(in) :: metadata(:)
real(rkind), intent(in), dimension(:,:) :: input_weights, output_weights, biases
real(rkind), intent(in) :: hidden_weights(:,:,:), output_biases(:)
class(differentiable_activation_strategy_t), intent(in) :: differentiable_activation_strategy
type(trainable_engine_t) trainable_engine
end function
interface trainable_engine_t

pure module function construct_from_padded_arrays(nodes, weights, biases, differentiable_activation_strategy, metadata) &
result(trainable_engine)
implicit none
integer, intent(in), allocatable :: nodes(:)
integer, intent(in) :: nodes(input_layer:)
real(rkind), intent(in) :: weights(:,:,:), biases(:,:)
class(differentiable_activation_strategy_t), intent(in) :: differentiable_activation_strategy
type(string_t), intent(in) :: metadata(:)
type(trainable_engine_t) trainable_engine

end function

end interface

interface

pure module subroutine train_single_hidden_layer(self, mini_batch, inference_strategy)
pure module subroutine assert_consistent(self)
implicit none
class(trainable_engine_t), intent(inout) :: self
type(mini_batch_t), intent(in) :: mini_batch(:)
class(inference_strategy_t), intent(in) :: inference_strategy
class(trainable_engine_t), intent(in) :: self
end subroutine

pure module subroutine train_deep_network(self, mini_batches)
pure module subroutine train(self, mini_batches)
implicit none
class(trainable_engine_t), intent(inout) :: self
type(mini_batch_t), intent(in) :: mini_batches(:)
end subroutine

elemental module function infer(self, inputs) result(outputs)
implicit none
class(trainable_engine_t), intent(in) :: self
type(inputs_t), intent(in) :: inputs
type(outputs_t) outputs
end function

elemental module function num_inputs(self) result(n_in)
implicit none
class(trainable_engine_t), intent(in) :: self
integer n_in
end function

elemental module function num_layers(self) result(n_layers)
implicit none
class(trainable_engine_t), intent(in) :: self
integer n_layers
end function

end interface

Expand Down
Loading

0 comments on commit 5b8366e

Please sign in to comment.