dtfft_interface_cufft.F90 Source File


This file depends on

sourcefile~~dtfft_interface_cufft.f90~~EfferentGraph sourcefile~dtfft_interface_cufft.f90 dtfft_interface_cufft.F90 sourcefile~dtfft_parameters.f90 dtfft_parameters.F90 sourcefile~dtfft_interface_cufft.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_utils.f90 dtfft_utils.F90 sourcefile~dtfft_interface_cufft.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_utils.f90->sourcefile~dtfft_parameters.f90

Files dependent on this one

sourcefile~~dtfft_interface_cufft.f90~~AfferentGraph sourcefile~dtfft_interface_cufft.f90 dtfft_interface_cufft.F90 sourcefile~dtfft_backend_cufftmp.f90 dtfft_backend_cufftmp.F90 sourcefile~dtfft_backend_cufftmp.f90->sourcefile~dtfft_interface_cufft.f90 sourcefile~dtfft_executor_cufft_m.f90 dtfft_executor_cufft_m.F90 sourcefile~dtfft_executor_cufft_m.f90->sourcefile~dtfft_interface_cufft.f90 sourcefile~dtfft_plan.f90 dtfft_plan.F90 sourcefile~dtfft_plan.f90->sourcefile~dtfft_executor_cufft_m.f90 sourcefile~dtfft_transpose_plan_cuda.f90 dtfft_transpose_plan_cuda.F90 sourcefile~dtfft_plan.f90->sourcefile~dtfft_transpose_plan_cuda.f90 sourcefile~dtfft_transpose_handle_cuda.f90 dtfft_transpose_handle_cuda.F90 sourcefile~dtfft_transpose_handle_cuda.f90->sourcefile~dtfft_backend_cufftmp.f90 sourcefile~dtfft.f90 dtfft.F90 sourcefile~dtfft.f90->sourcefile~dtfft_plan.f90 sourcefile~dtfft_api.f90 dtfft_api.F90 sourcefile~dtfft_api.f90->sourcefile~dtfft_plan.f90 sourcefile~dtfft_transpose_plan_cuda.f90->sourcefile~dtfft_transpose_handle_cuda.f90

Source Code

!------------------------------------------------------------------------------------------------
! Copyright (c) 2021, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
module dtfft_interface_cufft
!! cuFFT Interfaces
use iso_c_binding
use iso_fortran_env
use dtfft_parameters
! #ifdef DTFFT_WITH_NVSHMEM
! use dtfft_interface_nvshmem,              only: load_nvshmem
! #endif
use dtfft_utils
implicit none
private
public :: CUFFT_R2C, CUFFT_C2R, CUFFT_C2C
public :: CUFFT_D2Z, CUFFT_Z2D, CUFFT_Z2Z
public :: cufftGetErrorString
! public :: load_cufft


  integer(c_int), parameter, public :: CUFFT_COMM_MPI = 0
    !! Communication type to pass as `commType` to `cufftMpAttachReshapeComm`
    !! when the attached communication handle is an MPI communicator.

  enum, bind(C)
  !! Transform type selectors, passed as `ffttype` to `cufftPlanMany`.
  !! Values are spelled out explicitly and are listed in increasing order.
    ! R2C / C2R / C2C family (CUFFT_R2C is the single precision real-to-complex transform)
    enumerator :: CUFFT_C2C = 41, &
                  CUFFT_R2C = 42, &
                  CUFFT_C2R = 44
    ! D2Z / Z2D / Z2Z family
    ! NOTE(review): presumably the double precision counterparts of the family above -
    ! confirm against the `cufftType` declaration in cufft.h
    enumerator :: CUFFT_Z2Z = 105, &
                  CUFFT_D2Z = 106, &
                  CUFFT_Z2D = 108
  end enum

  enum, bind(C)
  !! Result codes (`cufftResult`) returned by the cuFFT bindings of this module.
  !! `cufftGetErrorString` turns each of them into its textual name.
    enumerator :: CUFFT_SUCCESS                   = 0,  &
                  CUFFT_INVALID_PLAN              = 1,  &
                  CUFFT_ALLOC_FAILED              = 2,  &
                  CUFFT_INVALID_TYPE              = 3,  &
                  CUFFT_INVALID_VALUE             = 4,  &
                  CUFFT_INTERNAL_ERROR            = 5,  &
                  CUFFT_EXEC_FAILED               = 6,  &
                  CUFFT_SETUP_FAILED              = 7,  &
                  CUFFT_INVALID_SIZE              = 8,  &
                  CUFFT_UNALIGNED_DATA            = 9,  &
                  CUFFT_INCOMPLETE_PARAMETER_LIST = 10, &
                  CUFFT_INVALID_DEVICE            = 11, &
                  CUFFT_PARSE_ERROR               = 12, &
                  CUFFT_NO_WORKSPACE              = 13, &
                  CUFFT_NOT_IMPLEMENTED           = 14, &
                  CUFFT_LICENSE_ERROR             = 15, &
                  CUFFT_NOT_SUPPORTED             = 16
  end enum

public :: cufftReshapeHandle
  type, bind(C) :: cufftReshapeHandle
  !! An opaque handle to a reshape operation.
  !!
  !! Interoperable wrapper around a single C pointer. Every `cufftMp*Reshape*` binding of this
  !! module takes it by value, except `cufftMpCreateReshape`, which takes it by reference.
    type(c_ptr) :: cptr
      !! Raw C pointer held by the handle.
  end type cufftReshapeHandle

public :: cufftPlanMany
  interface
  !! Binding to the C routine `cufftPlanMany`.
  !! Creates an FFT plan configuration of dimension `rank`, with sizes specified in the array `n`,
  !! for `batch` transforms laid out according to the `*embed`, `*stride` and `*dist` arguments.
    function cufftPlanMany(plan, rank, n,                                                                       &
                           inembed, istride, idist,                                                             &
                           onembed, ostride, odist,                                                             &
                           ffttype, batch)                                                                      &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftPlanMany")
    import :: c_ptr, c_int
      type(c_ptr)           :: plan
        !! Pointer to an uninitialized cufftHandle object.
        !! Passed by reference: the C routine receives the address of this variable.
        !! NOTE(review): cuFFT is believed to declare cufftHandle as a C `int`, in which case only
        !! part of this `type(c_ptr)` storage is written - confirm callers initialise it first.
      integer(c_int), value :: rank
        !! Dimensionality of the transform (1, 2, or 3).
      integer(c_int)        :: n(*)
        !! Array of size `rank` holding the size of each dimension:
        !! n[0] is the outermost and n[rank-1] the innermost (contiguous) dimension of a transform.
      integer(c_int)        :: inembed(*)
        !! Array of size `rank` giving the storage dimensions of the input data in memory.
        !! If set to NULL, all other advanced data layout parameters are ignored.
      integer(c_int), value :: istride
        !! Distance between two successive input elements
        !! in the least significant (i.e., innermost) dimension.
      integer(c_int), value :: idist
        !! Distance between the first elements of two consecutive signals
        !! in a batch of the input data.
      integer(c_int)        :: onembed(*)
        !! Array of size `rank` giving the storage dimensions of the output data in memory.
        !! If set to NULL, all other advanced data layout parameters are ignored.
      integer(c_int), value :: ostride
        !! Distance between two successive output elements in the output array
        !! in the least significant (i.e., innermost) dimension.
      integer(c_int), value :: odist
        !! Distance between the first elements of two consecutive signals
        !! in a batch of the output data.
      integer(c_int), value :: ffttype
        !! The transform data type (e.g., CUFFT_R2C for single precision real to complex).
      integer(c_int), value :: batch
        !! Batch size for this transform.
      integer(c_int)        :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftPlanMany
  end interface

public :: cufftXtExec
  interface
  !! Binding to the C routine `cufftXtExec`.
  !! Executes any cuFFT transform regardless of precision and type.
  !! For complex-to-real and real-to-complex transforms the `direction` argument is ignored.
    function cufftXtExec(plan, input, output, direction)                                                        &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftXtExec")
    import :: c_ptr, c_int
      type(c_ptr),    value :: plan
        !! cufftHandle returned by cufftCreate.
      type(c_ptr),    value :: input
        !! Pointer to the input data (in GPU memory) to transform.
      type(c_ptr),    value :: output
        !! Pointer to the output data (in GPU memory).
      integer(c_int), value :: direction
        !! The transform direction: CUFFT_FORWARD or CUFFT_INVERSE
        !! (neither constant is declared in this module).
        !! Ignored for complex-to-real and real-to-complex transforms.
      integer(c_int)        :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftXtExec
  end interface

public :: cufftDestroy
  interface
  !! Binding to the C routine `cufftDestroy`.
  !! Frees all GPU resources associated with a cuFFT plan and destroys the internal plan data structure.
    function cufftDestroy(plan)                                                                                 &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftDestroy")
    import :: c_ptr, c_int
      type(c_ptr),    value :: plan
        !! Object of the plan to be destroyed, passed by value.
      integer(c_int)        :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftDestroy
  end interface

public :: cufftSetStream
  interface
  !! Binding to the C routine `cufftSetStream`.
  !! Associates a CUDA stream with a cuFFT plan.
    function cufftSetStream(plan, stream)                                                                       &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftSetStream")
    import :: c_ptr, c_int, dtfft_stream_t
      type(c_ptr),          value :: plan
        !! Plan object to associate with the stream.
      type(dtfft_stream_t), value :: stream
        !! A valid CUDA stream, passed by value.
      integer(c_int)              :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftSetStream
  end interface

public :: cufftMpCreateReshape
  interface
  !! Binding to the C routine `cufftMpCreateReshape`.
  !! Initializes a reshape handle for future use. This function is not collective.
    function cufftMpCreateReshape(reshapeHandle)                                                                &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpCreateReshape")
    import :: cufftReshapeHandle, c_int
      type(cufftReshapeHandle) :: reshapeHandle
        !! The reshape handle. Passed by reference, unlike in the other reshape bindings,
        !! so the C routine receives the address of this variable.
      integer(c_int)           :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpCreateReshape
  end interface

public :: cufftMpAttachReshapeComm
  interface
  !! Binding to the C routine `cufftMpAttachReshapeComm`.
  !! Attaches a communication handle to a reshape. This function is not collective.
    function cufftMpAttachReshapeComm(reshapeHandle, commType, comm)                                            &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpAttachReshapeComm")
    import :: cufftReshapeHandle, c_int, c_ptr
      type(cufftReshapeHandle), value :: reshapeHandle
        !! The reshape handle.
      integer(c_int),           value :: commType
        !! An enum describing the communication type of the handle, e.g. `CUFFT_COMM_MPI`.
      type(c_ptr)                     :: comm
        !! If commType is CUFFT_COMM_MPI, this should be a pointer to an MPI communicator.
        !! The pointer should remain valid until destruction of the handle.
        !! Otherwise, this should be NULL.
        !! This dummy has no `value` attribute, so the C routine receives
        !! the address of the `type(c_ptr)` variable supplied by the caller.
      integer(c_int)                  :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpAttachReshapeComm
  end interface

public :: cufftMpGetReshapeSize
  interface
  !! Binding to the C routine `cufftMpGetReshapeSize`.
  !! Returns the amount (in bytes) of workspace required to execute the handle.
    function cufftMpGetReshapeSize(reshapeHandle, workSize)                                                     &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpGetReshapeSize")
    import :: cufftReshapeHandle, c_size_t, c_int
      type(cufftReshapeHandle), value :: reshapeHandle
        !! The reshape handle.
      integer(c_size_t)               :: workSize
        !! The size, in bytes, of the workspace required during reshape execution.
        !! Passed by reference so the C routine can store the size in it.
      integer(c_int)                  :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpGetReshapeSize
  end interface

public :: cufftMpMakeReshape
  interface
  !! Binding to the C routine `cufftMpMakeReshape`.
  !! Creates a reshape intended to re-distribute a global array of 3D data.
    function cufftMpMakeReshape(reshapeHandle, elementSize, rank,                                               &
                                lower_input,   upper_input,                                                     &
                                lower_output,  upper_output,                                                    &
                                strides_input, strides_output)                                                  &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpMakeReshape")
    import :: cufftReshapeHandle, c_long_long, c_int
      type(cufftReshapeHandle), value :: reshapeHandle
        !! The reshape handle.
      integer(c_long_long),     value :: elementSize
        !! The size of the individual elements, in bytes. Allowed values are 4, 8, and 16.
        !! NOTE(review): the cuFFTMp C prototype is believed to declare this argument as `size_t`;
        !! confirm `c_long_long` has the same width on every supported platform.
      integer(c_int),           value :: rank
        !! The length of the lower_input, upper_input, lower_output, upper_output,
        !! strides_input, and strides_output arrays. rank should be 3.
      integer(c_long_long)            :: lower_input(*)
        !! An array of length rank, representing the lower-corner of the portion of the
        !! global nx * ny * nz array owned by the current process in the input descriptor.
      integer(c_long_long)            :: upper_input(*)
        !! An array of length rank, representing the upper-corner of the portion of the
        !! global nx * ny * nz array owned by the current process in the input descriptor.
      integer(c_long_long)            :: lower_output(*)
        !! An array of length rank, representing the lower-corner of the portion of the
        !! global nx * ny * nz array owned by the current process in the output descriptor.
      integer(c_long_long)            :: upper_output(*)
        !! An array of length rank, representing the upper-corner of the portion of the
        !! global nx * ny * nz array owned by the current process in the output descriptor.
      integer(c_long_long)            :: strides_input(*)
        !! An array of length rank, representing the local data layout of the input descriptor
        !! in memory. All entries must be decreasing and positive.
      integer(c_long_long)            :: strides_output(*)
        !! An array of length rank, representing the local data layout of the output descriptor
        !! in memory. All entries must be decreasing and positive.
      integer(c_int)                  :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpMakeReshape
  end interface

public :: cufftMpExecReshapeAsync
  interface
  !! Binding to the C routine `cufftMpExecReshapeAsync`.
  !! Executes the reshape, redistributing `dataIn` into `dataOut` using the workspace in `workSpace`.
  !! Note that the output pointer precedes the input pointer in the argument list.
    function cufftMpExecReshapeAsync(reshapeHandle, dataOut, dataIn, workSpace, stream)                         &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpExecReshapeAsync")
    import :: cufftReshapeHandle, c_ptr, dtfft_stream_t, c_int
      type(cufftReshapeHandle), value :: reshapeHandle
        !! The reshape handle.
      type(c_ptr),              value :: dataOut
        !! A symmetric-heap pointer to the output data.
        !! This memory should be NVSHMEM allocated and identical on all processes.
      type(c_ptr),              value :: dataIn
        !! A symmetric-heap pointer to the input data.
        !! This memory should be NVSHMEM allocated and identical on all processes.
      type(c_ptr),              value :: workSpace
        !! A symmetric-heap pointer to the workspace data.
        !! This memory should be NVSHMEM allocated and identical on all processes.
      type(dtfft_stream_t),     value :: stream
        !! The CUDA stream in which to run the reshape operation.
      integer(c_int)                  :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpExecReshapeAsync
  end interface

public :: cufftMpDestroyReshape
  interface
  !! Binding to the C routine `cufftMpDestroyReshape`.
  !! Destroys a reshape and all its associated data.
    function cufftMpDestroyReshape(reshapeHandle)                                                               &
      result(cufftResult)                                                                                       &
      bind(C, name="cufftMpDestroyReshape")
    import :: cufftReshapeHandle, c_int
      type(cufftReshapeHandle), value :: reshapeHandle
        !! The reshape handle to destroy, passed by value.
      integer(c_int)                  :: cufftResult
        !! Result code of the call, one of the `cufftResult` values.
    end function cufftMpDestroyReshape
  end interface

!   logical, save :: is_loaded = .false.
!     !! Flag indicating whether the library is loaded
!   type(c_ptr), save :: libcufft
!     !! Handle to the loaded library

! #ifdef DTFFT_WITH_NVSHMEM
!   character(len=*), parameter :: LIB_NAME = "libcufftMp.so"
!   integer(c_int),   parameter :: CUFFT_MAX_FUNCTIONS = 10
! #else
!   character(len=*), parameter :: LIB_NAME = "libcufft.so"
!   integer(c_int),   parameter :: CUFFT_MAX_FUNCTIONS = 4
! #endif

!   type(c_funptr), save :: cufftFunctions(CUFFT_MAX_FUNCTIONS)
!     !! Array of pointers to the cuFFT functions

!   procedure(cufftPlanMany_interface),             pointer, public :: cufftPlanMany
!     !! Fortran pointer to the cufftPlanMany function
!   procedure(cufftXtExec_interface),               pointer, public :: cufftXtExec
!     !! Fortran pointer to the cufftXtExec function
!   procedure(cufftDestroy_interface),              pointer, public :: cufftDestroy
!     !! Fortran pointer to the cufftDestroy function
!   procedure(cufftSetStream_interface),            pointer, public :: cufftSetStream
!     !! Fortran pointer to the cufftSetStream function
!   procedure(cufftMpCreateReshape_interface),      pointer, public :: cufftMpCreateReshape
!     !! Fortran pointer to the cufftMpCreateReshape function
!   procedure(cufftMpAttachReshapeComm_interface),  pointer, public :: cufftMpAttachReshapeComm
!     !! Fortran pointer to the cufftMpAttachReshapeComm function
!   procedure(cufftMpGetReshapeSize_interface),     pointer, public :: cufftMpGetReshapeSize
!     !! Fortran pointer to the cufftMpGetReshapeSize function
!   procedure(cufftMpMakeReshape_interface),        pointer, public :: cufftMpMakeReshape
!     !! Fortran pointer to the cufftMpMakeReshape function
!   procedure(cufftMpExecReshapeAsync_interface),   pointer, public :: cufftMpExecReshapeAsync
!     !! Fortran pointer to the cufftMpExecReshapeAsync function
!   procedure(cufftMpDestroyReshape_interface),     pointer, public :: cufftMpDestroyReshape
!     !! Fortran pointer to the cufftMpDestroyReshape function

contains

!   function load_cufft() result(error_code)
!   !! Loads the cuFFT library and needed symbols
!     integer(int32)  :: error_code !! Error code
!     type(string), allocatable :: func_names(:)

!     error_code = DTFFT_SUCCESS
! !     if ( is_loaded ) return

! !     allocate(func_names(CUFFT_MAX_FUNCTIONS))
! !     func_names(1) = string("cufftPlanMany")
! !     func_names(2) = string("cufftXtExec")
! !     func_names(3) = string("cufftDestroy")
! !     func_names(4) = string("cufftSetStream")
! ! #ifdef DTFFT_WITH_NVSHMEM
! !     func_names(5) = string("cufftMpCreateReshape")
! !     func_names(6) = string("cufftMpAttachReshapeComm")
! !     func_names(7) = string("cufftMpGetReshapeSize")
! !     func_names(8) = string("cufftMpMakeReshape")
! !     func_names(9) = string("cufftMpExecReshapeAsync")
! !     func_names(10) = string("cufftMpDestroyReshape")
! ! #endif

! !     error_code = dynamic_load(LIB_NAME, func_names, libcufft, cufftFunctions)
! !     call destroy_strings(func_names)
! !     if ( error_code /= DTFFT_SUCCESS ) return

! !     call c_f_procpointer(cufftFunctions(1), cufftPlanMany)
! !     call c_f_procpointer(cufftFunctions(2), cufftXtExec)
! !     call c_f_procpointer(cufftFunctions(3), cufftDestroy)
! !     call c_f_procpointer(cufftFunctions(4), cufftSetStream)
! ! #ifdef DTFFT_WITH_NVSHMEM
! !     call c_f_procpointer(cufftFunctions(5), cufftMpCreateReshape)
! !     call c_f_procpointer(cufftFunctions(6), cufftMpAttachReshapeComm)
! !     call c_f_procpointer(cufftFunctions(7), cufftMpGetReshapeSize)
! !     call c_f_procpointer(cufftFunctions(8), cufftMpMakeReshape)
! !     call c_f_procpointer(cufftFunctions(9), cufftMpExecReshapeAsync)
! !     call c_f_procpointer(cufftFunctions(10), cufftMpDestroyReshape)

! !     error_code = load_nvshmem(libcufft)
! ! #endif

! !     is_loaded = .true.
!   end function load_cufft



  pure function cufftGetErrorString(error_code) result(string)
  !! Returns a string representation of the cuFFT error code.
  !!
  !! The returned text is the name of the `CUFFT_*` result enumerator equal to `error_code`.
  !! A code that matches none of the enumerators declared in this module yields
  !! the text "Unknown CUFFT error code".
  !!
  !! The function has no side effects and is therefore declared `pure`,
  !! which also makes it callable from other pure procedures.
    integer(c_int32_t), intent(in)  :: error_code
      !! cuFFT error code, as returned by the cuFFT bindings of this module
    character(len=:),   allocatable :: string
      !! String representation of the cuFFT error code; its length is exactly that of the text

    ! `allocate(..., source=...)` gives the result the exact length of the literal
    ! without relying on reallocation-on-assignment of deferred-length strings.
    select case ( error_code )
    case ( CUFFT_SUCCESS )
      allocate( string, source="CUFFT_SUCCESS" )
    case ( CUFFT_INVALID_PLAN )
      allocate( string, source="CUFFT_INVALID_PLAN" )
    case ( CUFFT_ALLOC_FAILED )
      allocate( string, source="CUFFT_ALLOC_FAILED" )
    case ( CUFFT_INVALID_TYPE )
      allocate( string, source="CUFFT_INVALID_TYPE" )
    case ( CUFFT_INVALID_VALUE )
      allocate( string, source="CUFFT_INVALID_VALUE" )
    case ( CUFFT_INTERNAL_ERROR )
      allocate( string, source="CUFFT_INTERNAL_ERROR" )
    case ( CUFFT_EXEC_FAILED )
      allocate( string, source="CUFFT_EXEC_FAILED" )
    case ( CUFFT_SETUP_FAILED )
      allocate( string, source="CUFFT_SETUP_FAILED" )
    case ( CUFFT_INVALID_SIZE )
      allocate( string, source="CUFFT_INVALID_SIZE" )
    case ( CUFFT_UNALIGNED_DATA )
      allocate( string, source="CUFFT_UNALIGNED_DATA" )
    case ( CUFFT_INCOMPLETE_PARAMETER_LIST )
      allocate( string, source="CUFFT_INCOMPLETE_PARAMETER_LIST" )
    case ( CUFFT_INVALID_DEVICE )
      allocate( string, source="CUFFT_INVALID_DEVICE" )
    case ( CUFFT_PARSE_ERROR )
      allocate( string, source="CUFFT_PARSE_ERROR" )
    case ( CUFFT_NO_WORKSPACE )
      allocate( string, source="CUFFT_NO_WORKSPACE" )
    case ( CUFFT_NOT_IMPLEMENTED )
      allocate( string, source="CUFFT_NOT_IMPLEMENTED" )
    case ( CUFFT_LICENSE_ERROR )
      allocate( string, source="CUFFT_LICENSE_ERROR" )
    case ( CUFFT_NOT_SUPPORTED )
      allocate( string, source="CUFFT_NOT_SUPPORTED" )
    case default
      ! Any value outside the enumerators declared in this module
      allocate( string, source="Unknown CUFFT error code" )
    end select
  end function cufftGetErrorString
end module dtfft_interface_cufft