Skip to content

Commit

Permalink
Merge pull request #90 from BerkeleyLab/learn-icar-sat-mr-func
Browse files Browse the repository at this point in the history
feat(example): train ICAR saturated mixing ratio
  • Loading branch information
rouson authored Oct 4, 2023
2 parents aa8c8db + 14a983d commit 967b0da
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions example/learn-saturated-mixing-ratio.f90
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ program train_saturated_mixture_ratio
call system_clock(counter_start, clock_rate)

block
integer, parameter :: num_epochs = 10000000, num_mini_batches = 6
integer, parameter :: max_num_epochs = 10000000, num_mini_batches = 6
integer num_pairs ! number of input/output pairs

type(mini_batch_t), allocatable :: mini_batches(:)
Expand All @@ -35,7 +35,7 @@ program train_saturated_mixture_ratio
real, allocatable :: cost(:), random_numbers(:)
integer io_status, network_unit, plot_unit
integer, parameter :: io_success=0, diagnostics_print_interval = 1000, network_save_interval = 10000
integer, parameter :: nodes_per_layer(*) = [2, 31, 31, 1]
integer, parameter :: nodes_per_layer(*) = [2, 72, 1]
real, parameter :: cost_tolerance = 1.E-08

call random_init(image_distinct=.true., repeatable=.true.)
Expand Down Expand Up @@ -81,7 +81,7 @@ program train_saturated_mixture_ratio
print *, " Epoch | Cost Function| System_Clock | Nodes per Layer"
allocate(random_numbers(2:size(input_output_pairs)))

do e = previous_epoch + 1, previous_epoch + num_epochs
do e = previous_epoch + 1, previous_epoch + max_num_epochs
call random_number(random_numbers)
call shuffle(input_output_pairs, random_numbers)
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
Expand All @@ -91,7 +91,7 @@ program train_saturated_mixture_ratio
associate( &
cost_avg => sum(cost)/size(cost), &
cumulative_clock_time => previous_clock_time + real(counter_end - counter_start) / real(clock_rate), &
loop_ending => e == previous_epoch + num_epochs &
loop_ending => e == previous_epoch + max_num_epochs &
)
write_and_exit_if_converged: &
if (cost_avg < cost_tolerance) then
Expand Down

0 comments on commit 967b0da

Please sign in to comment.