dtfft_backend_cufftmp.F90 Source File


This file depends on

Dependency graph (each entry reads "file -> files it depends on"):

- dtfft_backend_cufftmp.F90 -> dtfft_abstract_backend.F90, dtfft_errors.F90, dtfft_interface_cuda_runtime.F90, dtfft_interface_cufft.F90, dtfft_interface_nvshmem.F90, dtfft_parameters.F90, dtfft_pencil.F90, dtfft_utils.F90
- dtfft_abstract_backend.F90 -> dtfft_errors.F90, dtfft_interface_cuda_runtime.F90, dtfft_parameters.F90, dtfft_pencil.F90, dtfft_utils.F90, dtfft_abstract_kernel.F90, dtfft_interface_nccl.F90, dtfft_interface_nvtx.F90
- dtfft_interface_cuda_runtime.F90 -> dtfft_parameters.F90, dtfft_utils.F90
- dtfft_interface_cufft.F90 -> dtfft_parameters.F90, dtfft_utils.F90
- dtfft_interface_nvshmem.F90 -> dtfft_parameters.F90, dtfft_utils.F90
- dtfft_pencil.F90 -> dtfft_errors.F90, dtfft_interface_cuda_runtime.F90, dtfft_parameters.F90, dtfft_utils.F90, dtfft_config.F90
- dtfft_utils.F90 -> dtfft_errors.F90, dtfft_parameters.F90
- dtfft_abstract_kernel.F90 -> dtfft_parameters.F90, dtfft_utils.F90, dtfft_config.F90, dtfft_interface_nvtx.F90, dtfft_abstract_compressor.F90
- dtfft_config.F90 -> dtfft_errors.F90, dtfft_interface_cuda_runtime.F90, dtfft_parameters.F90, dtfft_utils.F90, dtfft_abstract_compressor.F90
- dtfft_interface_nccl.F90 -> dtfft_parameters.F90, dtfft_utils.F90
- dtfft_interface_nvtx.F90 -> dtfft_utils.F90
- dtfft_abstract_compressor.F90 -> dtfft_errors.F90, dtfft_parameters.F90, dtfft_utils.F90, dtfft_interface_nvtx.F90

Files dependent on this one

Dependency graph (each entry reads "file -> files it depends on"):

- dtfft_reshape_handle_generic.F90 -> dtfft_backend_cufftmp.F90
- dtfft_reshape_plan_base.F90 -> dtfft_reshape_handle_generic.F90
- dtfft_reshape_plan.F90 -> dtfft_reshape_plan_base.F90
- dtfft_transpose_plan.F90 -> dtfft_reshape_plan_base.F90
- dtfft_plan.F90 -> dtfft_reshape_plan.F90, dtfft_transpose_plan.F90
- dtfft.F90 -> dtfft_plan.F90
- dtfft_api.F90 -> dtfft_plan.F90

Source Code

!------------------------------------------------------------------------------------------------
! Copyright (c) 2021 - 2025, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
#include "dtfft_config.h"
module dtfft_backend_cufftmp_m
!! cuFFTMp GPU Backend [[backend_cufftmp]]
use iso_fortran_env
use iso_c_binding
use dtfft_abstract_backend, only: abstract_backend, backend_helper
use dtfft_interface_nvshmem
use dtfft_interface_cuda_runtime
use dtfft_interface_cufft
use dtfft_errors
use dtfft_parameters
use dtfft_pencil, only: pencil
use dtfft_utils
#include "_dtfft_mpi.h"
#include "_dtfft_cuda.h"
#include "_dtfft_private.h"
implicit none
private
public :: backend_cufftmp

    type :: Box3D
    !! cuFFTMp Box
    !!
    !! Plain container for the three integer triplets that describe one side
    !! (input or output) of a reshape. The `create` routine of this module fills
    !! the components and passes them individually to `cufftMpMakeReshape`.
    !! All components use the `c_long_long` kind.
        integer(c_long_long) :: lower(3)   !! Lower box boundaries
        integer(c_long_long) :: upper(3)   !! Upper box boundaries; `create` sets them to start + count
        integer(c_long_long) :: strides(3) !! Strides in memory; `create` always sets the last entry to 1
    end type Box3D

    type, extends(abstract_backend) :: backend_cufftmp
    !! cuFFTMp GPU Backend
    !!
    !! Extension of `abstract_backend` that performs data redistribution through a
    !! cuFFTMp reshape plan. The three bindings below map the deferred-looking names
    !! `create_private`, `execute_private` and `destroy_private` to the module
    !! procedures `create`, `execute` and `destroy`.
    !! NOTE(review): presumably these names are declared in `abstract_backend` - confirm there.
    private
        type(cufftReshapeHandle) :: plan
            !! cuFFTMp reshape handle: created in `create`, used in `execute`, released in `destroy`
    contains
        procedure :: create_private => create       !! Builds the reshape plan
        procedure :: execute_private => execute     !! Runs the reshape on a CUDA stream
        procedure :: destroy_private => destroy     !! Releases the reshape plan
    end type backend_cufftmp

contains

    subroutine create(self, helper, base_storage)
    !! Creates cuFFTMp GPU Backend
    !!
    !! Steps performed:
    !!
    !! 1. Picks the source and destination pencils out of `helper%pencils` according to
    !!    `helper%transpose_type` (when it is a valid transpose type) or
    !!    `helper%reshape_type` (otherwise).
    !! 2. Describes both pencils as boxes with three entries each: lower bound,
    !!    upper bound (start + count) and memory strides.
    !! 3. Creates a cuFFTMp reshape plan, attaches the MPI communicator `helper%comms(1)`
    !!    to it and finalizes it with the two boxes.
    !! 4. Raises `self%aux_bytes` to at least the workspace size reported by cuFFTMp.
    !!
    !! An unknown transpose or reshape type is reported through `INTERNAL_ERROR`.
        class(backend_cufftmp), intent(inout)   :: self               !! cuFFTMp GPU Backend
        type(backend_helper),   intent(in)      :: helper             !! Backend helper
        integer(int64),         intent(in)      :: base_storage       !! Number of bytes to store single element
        type(Box3D) :: inbox, outbox  !! Reshape boxes
        type(pencil), pointer :: in, out    !! Source and destination pencils, both point into `helper%pencils`
        type(c_ptr) :: c_comm               !! C handle of the MPI communicator attached to the plan
        integer(int64) :: aux_size          !! Workspace size reported by cuFFTMp for this plan
        logical :: is_transpose             !! .true. when a transpose was requested, .false. for a reshape

        ! Step 1: select source (`in`) and destination (`out`) pencils.
        is_transpose = .false.
        if ( is_valid_transpose_type(helper%transpose_type) ) then
            is_transpose = .true.
            select case (helper%transpose_type%val)
            case (DTFFT_TRANSPOSE_X_TO_Y%val)
                in => helper%pencils(1)
                out => helper%pencils(2)
            case (DTFFT_TRANSPOSE_Y_TO_X%val)
                in => helper%pencils(2)
                out => helper%pencils(1)
            case (DTFFT_TRANSPOSE_Y_TO_Z%val)
                in => helper%pencils(2)
                out => helper%pencils(3)
            case (DTFFT_TRANSPOSE_Z_TO_Y%val)
                in => helper%pencils(3)
                out => helper%pencils(2)
            case (DTFFT_TRANSPOSE_X_TO_Z%val)
                in => helper%pencils(1)
                out => helper%pencils(3)
            case (DTFFT_TRANSPOSE_Z_TO_X%val)
                in => helper%pencils(3)
                out => helper%pencils(1)
            case default
                INTERNAL_ERROR("backend_cufftmp: unknown `transpose_type`")
            end select
        else
            ! Not a transpose: the operation is selected by `reshape_type` instead.
            select case (helper%reshape_type%val)
            case ( DTFFT_RESHAPE_X_BRICKS_TO_PENCILS%val)
                in => helper%pencils(1)
                out => helper%pencils(3)
            case ( DTFFT_RESHAPE_X_PENCILS_TO_BRICKS%val )
                in => helper%pencils(3)
                out => helper%pencils(1)
            case ( DTFFT_RESHAPE_Z_BRICKS_TO_PENCILS%val )
                in => helper%pencils(2)
                out => helper%pencils(4)
            case ( DTFFT_RESHAPE_Z_PENCILS_TO_BRICKS%val )
                in => helper%pencils(4)
                out => helper%pencils(2)
            case default
                ! The selector of this branch is `reshape_type`, so the diagnostic names that field.
                INTERNAL_ERROR("backend_cufftmp: unknown `reshape_type`")
            end select
        end if

        ! Step 2: build the boxes.
        ! The output box always lists pencil indices in the order 3, 2, 1 with a unit stride on the
        ! last entry. The input box uses a permuted index order for transposes, so that both boxes
        ! refer to the same coordinate system while the strides still follow the input memory layout.
        ! NOTE(review): presumably cuFFTMp expects the slowest-varying dimension first - confirm
        ! against the cuFFTMp reshape documentation.
        if (in%rank == 3) then
            if ( is_transpose ) then
                if (any(helper%transpose_type == [DTFFT_TRANSPOSE_X_TO_Z, DTFFT_TRANSPOSE_Y_TO_X, DTFFT_TRANSPOSE_Z_TO_Y])) then
                    inbox%lower = [in%starts(2), in%starts(1), in%starts(3)]
                    inbox%upper = [in%starts(2) + in%counts(2), in%starts(1) + in%counts(1), in%starts(3) + in%counts(3)]
                    inbox%strides = [in%counts(1) * in%counts(3), in%counts(3), 1]
                else
                    inbox%lower = [in%starts(1), in%starts(3), in%starts(2)]
                    inbox%upper = [in%starts(1) + in%counts(1), in%starts(3) + in%counts(3), in%starts(2) + in%counts(2)]
                    inbox%strides = [in%counts(2) * in%counts(3), in%counts(2), 1]
                end if
            else
                ! Reshape: input uses the same 3, 2, 1 ordering as the output box.
                inbox%lower = [in%starts(3), in%starts(2), in%starts(1)]
                inbox%upper = [in%starts(3) + in%counts(3), in%starts(2) + in%counts(2), in%starts(1) + in%counts(1)]
                inbox%strides = [in%counts(1) * in%counts(2), in%counts(1), 1]
            end if

            outbox%lower = [out%starts(3), out%starts(2), out%starts(1)]
            outbox%upper = [out%starts(3) + out%counts(3), out%starts(2) + out%counts(2), out%starts(1) + out%counts(1)]
            outbox%strides = [out%counts(1) * out%counts(2), out%counts(1), 1]
        else
            ! Any other rank is handled as two-dimensional data. The plan below is always made
            ! with rank 3, so a leading dimension spanning [0, 1) is prepended to every box.
            if ( is_transpose ) then
                inbox%lower = [0, in%starts(1), in%starts(2)]
                inbox%upper = [1, in%starts(1) + in%counts(1), in%starts(2) + in%counts(2)]
                inbox%strides = [in%counts(1) * in%counts(2), in%counts(2), 1]
            else
                inbox%lower = [0, in%starts(2), in%starts(1)]
                inbox%upper = [1, in%starts(2) + in%counts(2), in%starts(1) + in%counts(1)]
                inbox%strides = [in%counts(1) * in%counts(2), in%counts(1), 1]
            end if

            outbox%lower = [0, out%starts(2), out%starts(1)]
            outbox%upper = [1, out%starts(2) + out%counts(2), out%starts(1) + out%counts(1)]
            outbox%strides = [out%counts(1) * out%counts(2), out%counts(1), 1]
        end if

        ! Step 3: create the plan, attach the communicator, then finalize with the boxes.
        CUFFT_CALL( cufftMpCreateReshape(self%plan) )
        c_comm = Comm_f2c(GET_MPI_VALUE(helper%comms(1)))
        CUFFT_CALL( cufftMpAttachReshapeComm(self%plan, CUFFT_COMM_MPI, c_comm) )
        CUFFT_CALL( cufftMpMakeReshape(self%plan, base_storage, 3, inbox%lower, inbox%upper, outbox%lower, outbox%upper, inbox%strides, outbox%strides) )

        ! Step 4: never shrink an aux requirement that is already recorded on the backend.
        CUFFT_CALL( cufftMpGetReshapeSize(self%plan, aux_size) )
        self%aux_bytes = max(aux_size, self%aux_bytes)
    end subroutine create

    subroutine execute(self, in, out, stream, aux, error_code)
    !! Executes cuFFTMp GPU Backend
    !!
    !! Waits for the work already queued on `stream`, synchronizes the processes of
    !! `self%comm` with an MPI barrier and then enqueues the cuFFTMp reshape on `stream`.
    !! The `Async` reshape variant is used and nothing in this routine waits for it afterwards.
    !! `error_code` is set to `DTFFT_SUCCESS` when the end of the routine is reached.
    !! NOTE(review): failure handling of the CUDA and cuFFT calls is hidden inside the
    !! `CUDA_CALL` and `CUFFT_CALL` macros - confirm in `_dtfft_cuda.h`.
        class(backend_cufftmp),             intent(inout)   :: self       !! cuFFTMp GPU Backend
        real(real32),   target, contiguous, intent(inout)   :: in(:)      !! Send pointer
        real(real32),   target, contiguous, intent(inout)   :: out(:)     !! Recv pointer
        type(dtfft_stream_t),               intent(in)      :: stream     !! Main execution CUDA stream
        real(real32),   target, contiguous, intent(inout)   :: aux(:)     !! Aux pointer
        integer(int32),                     intent(out)     :: error_code !! Error code
        integer(int32) :: ierr  !! Status returned by the MPI barrier; it is not inspected here

        ! Disabled alternative: on-stream NVSHMEM synchronization instead of the
        ! host-side stream sync plus MPI barrier used below.
        ! call nvshmemx_sync_all_on_stream(stream)
        CUDA_CALL( cudaStreamSynchronize(stream) )
        call MPI_Barrier(self%comm, ierr)
        ! Pointer arguments are passed in the order: out, in, aux.
        CUFFT_CALL( cufftMpExecReshapeAsync(self%plan, c_loc(out), c_loc(in), c_loc(aux), stream) )
        error_code = DTFFT_SUCCESS
    end subroutine execute

    subroutine destroy(self)
    !! Destroys cuFFTMp GPU Backend
    !!
    !! Releases the cuFFTMp reshape plan stored in `self%plan`.
    !! No other component of `self` is touched here.
        class(backend_cufftmp), intent(inout) :: self        !! cuFFTMp GPU Backend

        CUFFT_CALL( cufftMpDestroyReshape(self%plan) )
    end subroutine destroy
end module dtfft_backend_cufftmp_m