Skip to content

Commit

Permalink
SchurSolve Work
Browse files Browse the repository at this point in the history
  • Loading branch information
OsKnoth committed Sep 21, 2024
1 parent da1d0df commit d0dcc0e
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 105 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ LocalPreferences.toml
list1
TestScore_P.jl
OutRace
src/Integration/SchurSolve.jl_Ori
2 changes: 1 addition & 1 deletion Jobs/NHSphere/BaroWaveDrySphere_32Elem
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
julia --project --check-bounds=yes Examples/testNHSphere.jl \
julia --project Examples/testNHSphere.jl \
--Problem="BaroWaveDrySphere" \
--FloatTypeBackend="Float32" \
--NumberThreadGPU=512 \
Expand Down
6 changes: 3 additions & 3 deletions src/GPU/FcnGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,10 @@ NVTX.@annotate function FcnGPU!(F,U,FE,Metric,Phys,Cache,Exchange,Global,Param,E

# KGradKernel!(F,U,p,DS,dXdxI_I,J_I,X_I,M,Glob_I,GravitationFun,ndrange=ndrangeI)
# KernelAbstractions.synchronize(backend)
KMomentumCoriolisKernel!(F,U,DS,dXdxI_I,J_I,X_I,M,Glob_I,CoriolisFun,ndrange=ndrangeI)
KernelAbstractions.synchronize(backend)
@time KMomentumCoriolisKernel!(F,U,DS,dXdxI_I,J_I,X_I,M,Glob_I,CoriolisFun,ndrange=ndrangeI)
# KernelAbstractions.synchronize(backend)
# KRhoGradKinKernel!(F,U,DS,dXdxI_I,J_I,M,Glob_I,ndrange=ndrangeI)
KGradFullKernel!(F,U,p,DS,dXdxI_I,X_I,J_I,M,Glob_I,GravitationFun,ndrange=ndrangeI)
@time KGradFullKernel!(F,U,p,DS,dXdxI_I,X_I,J_I,M,Glob_I,GravitationFun,ndrange=ndrangeI)
# KernelAbstractions.synchronize(backend)

if State == "Dry" || State == "ShallowWater"
Expand Down
4 changes: 1 addition & 3 deletions src/Integration/RosenbrockSchur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@ function RosenbrockSchur!(V,dt,Fcn!,FcnPrepare!,Jac,CG,Metric,Phys,Cache,JCache,
@views @. V = V + ROS.a[iStage,jStage] * k[:,:,:,jStage]
end
FcnPrepare!(V,CG,Metric,Phys,Cache,Exchange,Global,Param,DiscType)
@time Fcn!(fV,V,CG,Metric,Phys,Cache,Exchange,Global,Param,DiscType)
Fcn!(fV,V,CG,Metric,Phys,Cache,Exchange,Global,Param,DiscType)
if iStage == 1
Jac(JCache,V,CG,Metric,Phys,Cache,Global,Param,DiscType)
end
@inbounds for jStage = 1 : iStage - 1
fac = ROS.c[iStage,jStage] / dt
@views @. fV = fV + fac * k[:,:,:,jStage]
end
@time begin
@views SchurSolveGPU!(k[:,:,:,iStage],fV,JCache,dt*ROS.gamma,Cache,Global)
end
end
@. V = Vn
@inbounds for iStage = 1 : nStage
Expand Down
99 changes: 1 addition & 98 deletions src/Integration/SchurSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,74 +73,6 @@ end
end
end

@kernel inbounds = true function SchurSolveFacKernel!(NumVTr,Nz,k,v,tri,@Const(JRhoW),@Const(JWRho),@Const(JWRhoTh),@Const(JRhoThW),fac)
IC, = @index(Global, NTuple)

NumG = @uniform @ndrange()[1]

invfac = 1 / fac
invfac2 = invfac / fac

if IC <= NumG
@views rRho=v[:,IC,1]
@views rTh=v[:,IC,5]
@views rw=v[1:Nz-1,IC,4]
@views sw=k[1:Nz-1,IC,4]
k[end,IC,4] = 0
@views @. tri[1,:,IC] = 0
@views @. tri[2,:,IC] = invfac2
@views @. tri[3,:,IC] = 0
@views mulUL!(tri[:,:,IC],JWRho[:,:,IC],JRhoW[:,:,IC])
@views mulUL!(tri[:,:,IC],JWRhoTh[:,:,IC],JRhoThW[:,:,IC])
@. rw = invfac * rw
@views mulbiUv!(rw,JWRho[:,:,IC],rRho)
@views mulbiUv!(rw,JWRhoTh[:,:,IC],rTh)
@views triSolve!(sw,tri[:,:,IC],rw)
@views mulbiLv!(rRho,JRhoW[:,:,IC],sw)
@views mulbiLv!(rTh,JRhoThW[:,:,IC],sw)
for iz = 1 : Nz
k[iz,IC,1] = fac * v[iz,IC,1]
k[iz,IC,2] = fac * v[iz,IC,2]
k[iz,IC,3] = fac * v[iz,IC,3]
k[iz,IC,5] = fac * v[iz,IC,5]
for iT = 6 : NumVTr
k[iz,IC,iT] = fac * v[iz,IC,iT]
end
end
end
end

@kernel inbounds = true function SchurSolveKernel!(NumVTr,Nz,k,v,tri,@Const(JRhoW),@Const(JWRho),@Const(JWRhoTh),@Const(JRhoThW),fac)
IC, = @index(Global, NTuple)

NumG = @uniform @ndrange()[1]

invfac = 1 / fac
invfac2 = invfac / fac
if IC <= NumG
@views rRho=v[:,IC,1]
@views rTh=v[:,IC,5]
@views rw=v[1:Nz-1,IC,4]
@views sw=k[1:Nz-1,IC,4]
k[end,IC,4] = 0
@. rw = invfac * rw
@views mulbiUv!(rw,JWRho[:,:,IC],rRho)
@views mulbiUv!(rw,JWRhoTh[:,:,IC],rTh)
@views triSolve!(sw,tri[:,:,IC],rw)
@views mulbiLv!(rRho,JRhoW[:,:,IC],sw)
@views mulbiLv!(rTh,JRhoThW[:,:,IC],sw)
for iz = 1 : Nz
k[iz,IC,1] = fac * v[iz,IC,1]
k[iz,IC,2] = fac * v[iz,IC,2]
k[iz,IC,3] = fac * v[iz,IC,3]
k[iz,IC,5] = fac * v[iz,IC,5]
for iT = 6 : NumVTr
k[iz,IC,iT] = fac * v[iz,IC,iT]
end
end
end
end

@kernel inbounds = true function SchurSolveFKernel!(k,v,@Const(JWRho),@Const(JWRhoTh),fac)
Iz,IC, = @index(Global, NTuple)
NumG = @uniform @ndrange()[2]
Expand Down Expand Up @@ -181,7 +113,7 @@ end
end
end

@kernel inbounds = true function SchurSolveTriKernel!(k,v,@Const(tri))
@kernel inbounds = true function SchurSolveTriKernel!(Nz,k,v,@Const(tri))
IC, = @index(Global, NTuple)

NumG = @uniform @ndrange()[1]
Expand Down Expand Up @@ -212,35 +144,6 @@ NVTX.@annotate function SchurSolveGPU!(k,v,J,fac,Cache,Global)
J.CompTri = false
end

NVTX.@annotate function SchurSolveGPU1!(k,v,J,fac,Cache,Global)
backend = get_backend(k)
FT = eltype(k)

Nz = size(k,1)
NumG = size(k,2)
NumVTr = size(k,3)

group = (Nz,10)
ndrange = (Nz,NumG)
groupTriDiag = (Nz-1,10)
ndrangeTriDiag = (Nz-1,NumG)
# group = (1024)
groupTri = (64)
ndrangeTri = (NumG)

if J.CompTri
KTriDiagKernel! = TriDiagKernel!(backend,groupTriDiag)
KTriDiagKernel!(J.tri,J.JRhoW,J.JWRho,J.JWRhoTh,J.JRhoThW,fac,ndrange=ndrangeTriDiag)
J.CompTri = false
end
KSchurSolveKernelF! = SchurSolveKernelF!(backend,group)
KSchurSolveKernelF!(k,v,J.JWRho,J.JWRhoTh,fac,ndrange=ndrange)
KSchurSolveTriKernel! = SchurSolveTriKernel!(backend,groupTri)
KSchurSolveTriKernel!(k,v,J.tri,ndrange=ndrangeTri)
KSchurSolveBKernel! = SchurSolveBKernel!(backend,group)
KSchurSolveBKernel!(NumVTr,k,v,J.JRhoW,J.JRhoThW,fac,ndrange=ndrange)
end

function SchurSolve!(k,v,J,fac,Cache,Global)
# sw=(spdiags(repmat(invfac2,n,1),0,n,n)-invfac*JWW-JWRho*JRhoW-JWRhoTh*JRhoThW)\
# (invfac*rw+JWRho*rRho+JWRhoTh*rTh)
Expand Down

0 comments on commit d0dcc0e

Please sign in to comment.