Update global reduction implementation to improve performance, fix VP bug (#824)

* Update global reduction implementation to improve performance, fix VP bug

This was mainly done for situations like VP that need a fast global sum.
The VP global sum is still slightly faster than the one computed in the
infrastructure, so that implementation is kept (a sketch of the local-sum
approach appears just before the file diffs below). A bug in the workspace_y
calculation in VP was found and fixed. The haloupdate call in the precondition
step was also found to generally improve VP performance, so the option to skip
that haloupdate was removed.

Separately, fixed a bug in the tripoleT global sum implementation, added
a tripoleT global sum unit test, and resynced ice_exit.F90, ice_reprosum.F90,
and ice_global_reductions.F90 between serial and mpi versions.

- Refactor global sums to improve performance, move if checks outside do loops
- Fix bug in tripoleT global sums, tripole seam masking
- Update VP solver, use local global sum more often
- Update VP solver, fix bug in workspace_y calculation
- Update VP solver, always call haloupdate during precondition
- Refactor ice_exit.F90 and sync serial and mpi versions
- Sync ice_reprosum.F90 between serial and mpi versions
- Update sumchk unit test to handle grids better
- Add tripoleT sumchk test

* Update VP global sum to exclude local implementation with tripole grids
apcraig authored Apr 5, 2023
1 parent 9424497 commit 5b0418a
Showing 9 changed files with 466 additions and 268 deletions.
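
For context, the faster "local" global sum that the VP solver now uses more
often works roughly as sketched below. This is an illustrative sketch only,
not code from this commit: the names (prod, dot, dot_product, nblocks,
distrb_info, global_sum) follow the ice_dyn_vp.F90 diff below, and the
declarations and loop bounds are assumed.

      ! Illustrative sketch of the local-sum pattern (assumed declarations):
      ! sum within each block locally, then reduce a single scalar.
      do iblk = 1, nblocks
         dot(iblk) = sum(prod(:,:,iblk))   ! per-block partial sum
      enddo
      ! one scalar global_sum call instead of the array-based global_sum,
      ! avoiding the extra copies and overhead of the general array version
      dot_product = global_sum(sum(dot), distrb_info)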
6 changes: 2 additions & 4 deletions cicecore/cicedyn/analysis/ice_diagnostics.F90
@@ -261,10 +261,8 @@ subroutine runtime_diags (dt)
!$OMP END PARALLEL DO
extentn = c0
extents = c0
extentn = global_sum(work1, distrb_info, field_loc_center, &
tarean)
extents = global_sum(work1, distrb_info, field_loc_center, &
tareas)
extentn = global_sum(work1, distrb_info, field_loc_center, tarean)
extents = global_sum(work1, distrb_info, field_loc_center, tareas)
extentn = extentn * m2_to_km2
extents = extents * m2_to_km2

41 changes: 21 additions & 20 deletions cicecore/cicedyn/dynamics/ice_dyn_vp.F90
@@ -2502,7 +2502,7 @@ function global_dot_product (nx_block , ny_block , &
vector2_x , vector2_y) &
result(dot_product)

use ice_domain, only: distrb_info
use ice_domain, only: distrb_info, ns_boundary_type
use ice_domain_size, only: max_blocks
use ice_fileunits, only: bfbflag

@@ -2552,8 +2552,14 @@ function global_dot_product (nx_block , ny_block , &
enddo
!$OMP END PARALLEL DO

! Use local summation result unless bfbflag is active
if (bfbflag == 'off') then
! Use faster local summation result for several bfbflag settings.
! The local implementation sums over each block, sums over local
! blocks, and calls global_sum on a scalar and should be just as accurate as
! bfbflag = 'off', 'lsum8', and 'lsum4' without the extra copies and overhead
! in the more general array global_sum. But use the array global_sum
! if bfbflag is more strict or for tripole grids (requires special masking)
if (ns_boundary_type /= 'tripole' .and. ns_boundary_type /= 'tripoleT' .and. &
(bfbflag == 'off' .or. bfbflag == 'lsum8' .or. bfbflag == 'lsum4')) then
dot_product = global_sum(sum(dot), distrb_info)
else
dot_product = global_sum(prod, distrb_info, field_loc_NEcorner)
@@ -3120,7 +3126,7 @@ subroutine fgmres (zetax2 , etax2 , &
j = indxUj(ij, iblk)

workspace_x(i, j, iblk) = workspace_x(i, j, iblk) + rhs_hess(it) * arnoldi_basis_x(i, j, iblk, it)
workspace_y(i, j, iblk) = workspace_x(i, j, iblk) + rhs_hess(it) * arnoldi_basis_y(i, j, iblk, it)
workspace_y(i, j, iblk) = workspace_y(i, j, iblk) + rhs_hess(it) * arnoldi_basis_y(i, j, iblk, it)
enddo ! ij
enddo
!$OMP END PARALLEL DO
@@ -3151,7 +3157,6 @@ subroutine pgmres (zetax2 , etax2 , &

use ice_boundary, only: ice_HaloUpdate
use ice_domain, only: maskhalo_dyn, halo_info
use ice_fileunits, only: bfbflag
use ice_timers, only: ice_timer_start, ice_timer_stop, timer_bound

real (kind=dbl_kind), dimension(nx_block,ny_block,max_blocks,4), intent(in) :: &
@@ -3343,21 +3348,17 @@ subroutine pgmres (zetax2 , etax2 , &
workspace_x , workspace_y)

! Update workspace with boundary values
! NOTE: skipped for efficiency since this is just a preconditioner
! unless bfbflag is active
if (bfbflag /= 'off') then
call stack_fields(workspace_x, workspace_y, fld2)
call ice_timer_start(timer_bound)
if (maskhalo_dyn) then
call ice_HaloUpdate (fld2, halo_info_mask, &
field_loc_NEcorner, field_type_vector)
else
call ice_HaloUpdate (fld2, halo_info, &
field_loc_NEcorner, field_type_vector)
endif
call ice_timer_stop(timer_bound)
call unstack_fields(fld2, workspace_x, workspace_y)
call stack_fields(workspace_x, workspace_y, fld2)
call ice_timer_start(timer_bound)
if (maskhalo_dyn) then
call ice_HaloUpdate (fld2, halo_info_mask, &
field_loc_NEcorner, field_type_vector)
else
call ice_HaloUpdate (fld2, halo_info, &
field_loc_NEcorner, field_type_vector)
endif
call ice_timer_stop(timer_bound)
call unstack_fields(fld2, workspace_x, workspace_y)

!$OMP PARALLEL DO PRIVATE(iblk)
do iblk = 1, nblocks
@@ -3528,7 +3529,7 @@ subroutine pgmres (zetax2 , etax2 , &
j = indxUj(ij, iblk)

workspace_x(i, j, iblk) = workspace_x(i, j, iblk) + rhs_hess(it) * arnoldi_basis_x(i, j, iblk, it)
workspace_y(i, j, iblk) = workspace_x(i, j, iblk) + rhs_hess(it) * arnoldi_basis_y(i, j, iblk, it)
workspace_y(i, j, iblk) = workspace_y(i, j, iblk) + rhs_hess(it) * arnoldi_basis_y(i, j, iblk, it)
enddo ! ij
enddo
!$OMP END PARALLEL DO
60 changes: 32 additions & 28 deletions cicecore/cicedyn/infrastructure/comm/mpi/ice_exit.F90
@@ -1,3 +1,4 @@

!=======================================================================
!
! Exit the model.
@@ -8,7 +9,15 @@
module ice_exit

use ice_kinds_mod
use ice_fileunits, only: nu_diag, ice_stderr, flush_fileunit
use icepack_intfc, only: icepack_warnings_flush, icepack_warnings_aborted
#if (defined CESMCOUPLED)
use shr_sys_mod
#else
#ifndef SERIAL_REMOVE_MPI
use mpi ! MPI Fortran module
#endif
#endif

implicit none
public
@@ -23,70 +32,65 @@ subroutine abort_ice(error_message, file, line, doabort)

! This routine aborts the ice model and prints an error message.

#if (defined CESMCOUPLED)
use ice_fileunits, only: nu_diag, flush_fileunit
use shr_sys_mod
#else
use ice_fileunits, only: nu_diag, ice_stderr, flush_fileunit
use mpi ! MPI Fortran module
#endif

character (len=*), intent(in),optional :: error_message ! error message
character (len=*), intent(in),optional :: file ! file
integer (kind=int_kind), intent(in), optional :: line ! line number
logical (kind=log_kind), intent(in), optional :: doabort ! abort flag

! local variables

#ifndef CESMCOUPLED
integer (int_kind) :: &
ierr, & ! MPI error flag
outunit, & ! output unit
error_code ! return code
#endif
logical (log_kind) :: ldoabort ! local doabort flag
character(len=*), parameter :: subname='(abort_ice)'

ldoabort = .true.
if (present(doabort)) ldoabort = doabort

#if (defined CESMCOUPLED)
call flush_fileunit(nu_diag)
call icepack_warnings_flush(nu_diag)
write(nu_diag,*) ' '
write(nu_diag,*) subname, 'ABORTED: '
if (present(file)) write (nu_diag,*) subname,' called from ',trim(file)
if (present(line)) write (nu_diag,*) subname,' line number ',line
if (present(error_message)) write (nu_diag,*) subname,' error = ',trim(error_message)
call flush_fileunit(nu_diag)
if (ldoabort) call shr_sys_abort(subname//trim(error_message))
outunit = nu_diag
#else
outunit = ice_stderr
#endif

call flush_fileunit(nu_diag)
call icepack_warnings_flush(nu_diag)
write(ice_stderr,*) ' '
write(ice_stderr,*) subname, 'ABORTED: '
if (present(file)) write (ice_stderr,*) subname,' called from ',trim(file)
if (present(line)) write (ice_stderr,*) subname,' line number ',line
if (present(error_message)) write (ice_stderr,*) subname,' error = ',trim(error_message)
call flush_fileunit(ice_stderr)
error_code = 128
write(outunit,*) ' '
write(outunit,*) subname, 'ABORTED: '
if (present(file)) write (outunit,*) subname,' called from ',trim(file)
if (present(line)) write (outunit,*) subname,' line number ',line
if (present(error_message)) write (outunit,*) subname,' error = ',trim(error_message)
call flush_fileunit(outunit)

if (ldoabort) then
#if (defined CESMCOUPLED)
call shr_sys_abort(subname//trim(error_message))
#else
#ifndef SERIAL_REMOVE_MPI
error_code = 128
call MPI_ABORT(MPI_COMM_WORLD, error_code, ierr)
#endif
stop
endif
#endif
endif

end subroutine abort_ice

!=======================================================================

subroutine end_run

! Ends run by calling MPI_FINALIZE.
! Ends run by calling MPI_FINALIZE
! Does nothing in serial runs

integer (int_kind) :: ierr ! MPI error flag
character(len=*), parameter :: subname = '(end_run)'

#ifndef SERIAL_REMOVE_MPI
call MPI_FINALIZE(ierr)
#endif

end subroutine end_run
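
For reference, a hedged example of calling the abort_ice interface shown
above; the __FILE__ and __LINE__ preprocessor macros are an assumption about
the call-site convention and are not part of this diff.

      ! hypothetical call site; all arguments of abort_ice are optional
      call abort_ice(error_message='ERROR: invalid configuration', &
                     file=__FILE__, line=__LINE__)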
