dtfft_executor_vkfft_m.F90 Source File


This file depends on

sourcefile~~dtfft_executor_vkfft_m.f90~~EfferentGraph sourcefile~dtfft_executor_vkfft_m.f90 dtfft_executor_vkfft_m.F90 sourcefile~dtfft_abstract_executor.f90 dtfft_abstract_executor.F90 sourcefile~dtfft_executor_vkfft_m.f90->sourcefile~dtfft_abstract_executor.f90 sourcefile~dtfft_config.f90 dtfft_config.F90 sourcefile~dtfft_executor_vkfft_m.f90->sourcefile~dtfft_config.f90 sourcefile~dtfft_interface_vkfft_m.f90 dtfft_interface_vkfft_m.F90 sourcefile~dtfft_executor_vkfft_m.f90->sourcefile~dtfft_interface_vkfft_m.f90 sourcefile~dtfft_parameters.f90 dtfft_parameters.F90 sourcefile~dtfft_executor_vkfft_m.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_abstract_executor.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_interface_nvtx.f90 dtfft_interface_nvtx.F90 sourcefile~dtfft_abstract_executor.f90->sourcefile~dtfft_interface_nvtx.f90 sourcefile~dtfft_pencil.f90 dtfft_pencil.F90 sourcefile~dtfft_abstract_executor.f90->sourcefile~dtfft_pencil.f90 sourcefile~dtfft_utils.f90 dtfft_utils.F90 sourcefile~dtfft_abstract_executor.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_config.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_interface_cuda_runtime.f90 dtfft_interface_cuda_runtime.F90 sourcefile~dtfft_config.f90->sourcefile~dtfft_interface_cuda_runtime.f90 sourcefile~dtfft_config.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_interface_vkfft_m.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_interface_vkfft_m.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_interface_cuda_runtime.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_interface_cuda_runtime.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_interface_nvtx.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_pencil.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_pencil.f90->sourcefile~dtfft_interface_cuda_runtime.f90 sourcefile~dtfft_pencil.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_utils.f90->sourcefile~dtfft_parameters.f90

Files dependent on this one

sourcefile~~dtfft_executor_vkfft_m.f90~~AfferentGraph sourcefile~dtfft_executor_vkfft_m.f90 dtfft_executor_vkfft_m.F90 sourcefile~dtfft_plan.f90 dtfft_plan.F90 sourcefile~dtfft_plan.f90->sourcefile~dtfft_executor_vkfft_m.f90 sourcefile~dtfft.f90 dtfft.F90 sourcefile~dtfft.f90->sourcefile~dtfft_plan.f90 sourcefile~dtfft_api.f90 dtfft_api.F90 sourcefile~dtfft_api.f90->sourcefile~dtfft_plan.f90

Source Code

!------------------------------------------------------------------------------------------------
! Copyright (c) 2021, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
module dtfft_executor_vkfft_m
!! This module describes VkFFT based FFT Executor: [[vkfft_executor]]
!!
!! https://github.com/DTolm/VkFFT/tree/master
use iso_c_binding,                  only: c_ptr, c_int, c_int8_t
use iso_fortran_env,                only: int8, int32, int64
use dtfft_parameters
use dtfft_abstract_executor,        only: abstract_executor, FFT_C2C, FFT_R2C, FFT_R2R
use dtfft_interface_vkfft_m
use dtfft_config,                   only: get_user_stream, get_user_platform
implicit none
private
#include "dtfft_private.h"
public :: vkfft_executor

  type, extends(abstract_executor) :: vkfft_executor
  !! VkFFT based FFT Executor
  private
    type(vkfft_wrapper), pointer  :: wrapper => null()
      !! VkFFT Wrapper. Associated inside `create`; null until then
    logical                       :: is_inverse_required = .false.
      !! Whether a separate inverse FFT plan has to be created (and later destroyed).
      !! Default-initialized so that `execute` and `destroy` never read an undefined value
      !! when `create` has not run or has left before assigning it.
  contains
    procedure, pass(self)  :: create_private => create      !! Creates FFT plan via vkFFT Interface
    procedure, pass(self)  :: execute_private => execute    !! Executes vkFFT plan
    procedure, pass(self)  :: destroy_private => destroy    !! Destroys vkFFT plan
    procedure, nopass :: mem_alloc                          !! Dummy method. Raises `error stop`
    procedure, nopass :: mem_free                           !! Dummy method. Raises `error stop`
  end type vkfft_executor

contains

  subroutine create(self, fft_rank, fft_type, precision, idist, odist, how_many, &
                    fft_sizes, inembed, onembed, error_code, r2r_kinds)
  !! Creates FFT plan via vkFFT Interface
  !!
  !! Selects the VkFFT wrapper for the user platform, translates `fft_type` / `r2r_kinds` into the
  !! `r2c`, `dct` and `dst` flags expected by the wrapper and creates the forward plan.
  !! A separate backward plan is created only for R2C transforms; otherwise the backward plan handle
  !! is a copy of the forward one.
  !!
  !! `idist`, `odist`, `inembed` and `onembed` are part of the executor interface but are not used here.
    class(vkfft_executor),            intent(inout) :: self           !! vkFFT FFT Executor
    integer(int8),                    intent(in)    :: fft_rank       !! Rank of fft: 1 or 2
    integer(int8),                    intent(in)    :: fft_type       !! Type of fft: r2r, r2c, c2c
    type(dtfft_precision_t),          intent(in)    :: precision      !! Precision of fft: DTFFT_SINGLE or DTFFT_DOUBLE
    integer(int32),                   intent(in)    :: idist          !! Distance between the first element of two consecutive signals in a batch of the input data.
    integer(int32),                   intent(in)    :: odist          !! Distance between the first element of two consecutive signals in a batch of the output data.
    integer(int32),                   intent(in)    :: how_many       !! Number of transforms to create
    integer(int32),                   intent(in)    :: fft_sizes(:)   !! Dimensions of transform
    integer(int32),                   intent(in)    :: inembed(:)     !! Storage dimensions of the input data in memory.
    integer(int32),                   intent(in)    :: onembed(:)     !! Storage dimensions of the output data in memory.
    integer(int32),                   intent(inout) :: error_code     !! Error code to be returned to user
    type(dtfft_r2r_kind_t), optional, intent(in)    :: r2r_kinds(:)   !! Kinds of r2r transform. Must be present when `fft_type` is `FFT_R2R`
    integer(c_int8_t)       :: r2c              !! Is R2C transform required
    integer(c_int8_t)       :: dct              !! Is DCT transform required
    integer(c_int8_t)       :: dst              !! Is DST transform required
    integer(c_int)          :: knd              !! Kind of r2r transform
    integer(c_int)          :: i                !! Loop index
    integer(c_int)          :: dims(2)          !! Dimensions of transform
    integer(c_int)          :: double_precision !! Precision of fft: DTFFT_SINGLE or DTFFT_DOUBLE
    type(dtfft_platform_t)  :: platform         !! Platform of the executor

    error_code = DTFFT_SUCCESS
    ! Give the flag a defined value before any early exit below: `execute` and `destroy` read it.
    self%is_inverse_required = .false.

    ! Sizes are handed to the wrapper in reversed order
    do i = 1, fft_rank
      dims(i) = fft_sizes(fft_rank - i + 1)
    enddo

    platform = get_user_platform()
    ! NOTE(review): CHECK_CALL is a project macro; presumably it leaves the procedure when loading fails.
    CHECK_CALL( load_vkfft(platform), error_code )

    if ( platform == DTFFT_PLATFORM_CUDA ) then
      self%wrapper => cuda_wrapper
    else
      ! Only the CUDA wrapper is known here. Continuing would call through an unassociated pointer.
      INTERNAL_ERROR("VkFFT wrapper is not available for selected platform")
    endif

    r2c = 0
    dct = 0
    dst = 0
    select case ( fft_type )
    case ( FFT_R2C )
      r2c = 1
      self%is_inverse_required = .true.
    case ( FFT_R2R )
      ! Referencing an absent optional argument is not allowed, so fail loudly instead.
      if ( .not. present(r2r_kinds) ) then
        INTERNAL_ERROR("r2r_kinds not passed for R2R VkFFT plan")
      endif
      ! All dimensions of the plan must share the same r2r kind
      knd = r2r_kinds(1)%val
      do i = 2, fft_rank
        if ( knd /= r2r_kinds(i)%val ) then
          error_code = DTFFT_ERROR_VKFFT_R2R_2D_PLAN
          return
        endif
      enddo
      select case ( knd )
      case ( DTFFT_DCT_1%val )
        dct = 1
      case ( DTFFT_DCT_2%val )
        dct = 2
      case ( DTFFT_DCT_3%val )
        dct = 3
      case ( DTFFT_DCT_4%val )
        dct = 4
      case ( DTFFT_DST_1%val )
        dst = 1
      case ( DTFFT_DST_2%val )
        dst = 2
      case ( DTFFT_DST_3%val )
        dst = 3
      case ( DTFFT_DST_4%val )
        dst = 4
      endselect
    endselect
    if ( precision == DTFFT_DOUBLE ) then
      double_precision = 1
    else
      double_precision = 0
    endif
    ! Forward plan: `r2c` goes into the fifth argument, zero into the sixth
    call self%wrapper%create(fft_rank, dims, double_precision, how_many, r2c, int(0, int8), dct, dst, &
                             get_user_stream(), self%plan_forward)
    if ( self%is_inverse_required ) then
      ! Backward plan: the same flag is moved to the sixth argument (presumably the c2r switch - TODO confirm)
      call self%wrapper%create(fft_rank, dims, double_precision, how_many, int(0, int8), r2c, dct, dst, &
                               get_user_stream(), self%plan_backward)
    else
      ! Single plan serves both directions; `destroy` releases it only once
      self%plan_backward = self%plan_forward
    endif
  end subroutine create

  subroutine execute(self, a, b, sign)
  !! Runs the vkFFT plan that corresponds to the requested direction
    class(vkfft_executor),  intent(in)      :: self           !! vkFFT FFT Executor
    type(c_ptr),            intent(in)      :: a              !! Source pointer
    type(c_ptr),            intent(in)      :: b              !! Target pointer
    integer(int8),          intent(in)      :: sign           !! Sign of transform

    ! The forward plan handles every call, except a backward transform
    ! of an executor that owns a dedicated inverse plan.
    if ( .not. self%is_inverse_required .or. .not. ( sign == FFT_BACKWARD ) ) then
      call self%wrapper%execute(self%plan_forward, a, b, sign)
    else
      call self%wrapper%execute(self%plan_backward, a, b, sign)
    endif
  end subroutine execute

  subroutine destroy(self)
  !! Releases the vkFFT plan(s) owned by this executor
    class(vkfft_executor), intent(inout)    :: self           !! vkFFT FFT Executor

    ! Forward plan always exists
    call self%wrapper%destroy(self%plan_forward)
    ! Backward plan is a distinct object only when it was created separately;
    ! otherwise it is a copy of the forward handle and must not be destroyed again.
    if ( self%is_inverse_required ) call self%wrapper%destroy(self%plan_backward)
  end subroutine destroy

  subroutine mem_alloc(alloc_bytes, ptr)
  !! Stub kept to satisfy the executor type: this executor provides no allocator.
  !! Unconditionally reports an internal error; `ptr` is never assigned.
    integer(int64), intent(in)  :: alloc_bytes
      !! Number of bytes to allocate
    type(c_ptr),    intent(out) :: ptr
      !! Allocated pointer

    INTERNAL_ERROR("mem_alloc for VkFFT called")
  end subroutine mem_alloc

  subroutine mem_free(ptr)
  !! Stub kept to satisfy the executor type: this executor provides no deallocator.
  !! Unconditionally reports an internal error.
    type(c_ptr), intent(in) :: ptr
      !! Pointer to free

    INTERNAL_ERROR("mem_free for VkFFT called")
  end subroutine mem_free
end module dtfft_executor_vkfft_m