dtfft_interface_nvshmem.F90 Source File


This file depends on

Dependency graph (efferent):
- dtfft_interface_nvshmem.F90 -> dtfft_parameters.F90
- dtfft_interface_nvshmem.F90 -> dtfft_utils.F90
- dtfft_utils.F90 -> dtfft_parameters.F90

Files dependent on this one

Dependency graph (afferent):
- dtfft_abstract_transpose_plan.F90 -> dtfft_interface_nvshmem.F90
- dtfft_backend_cufftmp.F90 -> dtfft_interface_nvshmem.F90
- dtfft_plan.F90 -> dtfft_interface_nvshmem.F90
- dtfft_plan.F90 -> dtfft_abstract_transpose_plan.F90
- dtfft_plan.F90 -> dtfft_transpose_plan_cuda.F90
- dtfft_plan.F90 -> dtfft_transpose_plan_host.F90
- dtfft.F90 -> dtfft_plan.F90
- dtfft_api.F90 -> dtfft_plan.F90
- dtfft_transpose_handle_cuda.F90 -> dtfft_backend_cufftmp.F90
- dtfft_transpose_plan_cuda.F90 -> dtfft_abstract_transpose_plan.F90
- dtfft_transpose_plan_cuda.F90 -> dtfft_transpose_handle_cuda.F90
- dtfft_transpose_plan_host.F90 -> dtfft_abstract_transpose_plan.F90

Source Code

!------------------------------------------------------------------------------------------------
! Copyright (c) 2021, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
module dtfft_interface_nvshmem
!! NVSHMEM Interfaces
use iso_c_binding
use iso_fortran_env
use dtfft_parameters
use dtfft_utils
implicit none
private
public :: nvshmem_team_t
public :: is_nvshmem_ptr
! public :: load_nvshmem


  type, bind(c) :: nvshmem_team_t
  !! Handle identifying an NVSHMEM team, stored as a single 32-bit C integer.
    integer(kind=c_int32_t) :: val
      !! Raw integer value of the team handle.
  end type nvshmem_team_t

  type(nvshmem_team_t), public, parameter :: NVSHMEM_TEAM_WORLD = nvshmem_team_t(val=0_c_int32_t)
    !! Predefined team handle whose raw value is 0.

public :: nvshmem_malloc
public :: nvshmem_free
public :: nvshmemx_sync_all_on_stream
public :: nvshmemx_float_alltoall_on_stream
public :: nvshmem_ptr
public :: nvshmem_my_pe
public :: nvshmemx_init_status

  interface
    function nvshmem_malloc(size) result(ptr) bind(C, name="nvshmem_malloc")
    !! Binding to the C routine `nvshmem_malloc`: allocates symmetric memory in the NVSHMEM heap.
      import :: c_ptr, c_size_t
      integer(c_size_t), value  :: size   !! Requested allocation size in bytes.
      type(c_ptr)               :: ptr    !! Address of the allocated memory.
    end function nvshmem_malloc

    subroutine nvshmem_free(ptr) bind(C, name="nvshmem_free")
    !! Binding to the C routine `nvshmem_free`: releases symmetric memory obtained from `nvshmem_malloc`.
      import :: c_ptr
      type(c_ptr), value :: ptr   !! Address of the memory to release.
    end subroutine nvshmem_free

    subroutine nvshmemx_sync_all_on_stream(stream) bind(C, name="nvshmemx_sync_all_on_stream")
    !! Binding to the C routine `nvshmemx_sync_all_on_stream`: synchronizes all PEs
    !! (Processing Elements) on the given stream.
      import :: dtfft_stream_t
      type(dtfft_stream_t), intent(in), value :: stream   !! CUDA stream on which the synchronization is enqueued.
    end subroutine nvshmemx_sync_all_on_stream

    function nvshmemx_float_alltoall_on_stream(team, dest, source, nelems, stream) result(ierr) &
      bind(C, name="nvshmemx_float_alltoall_on_stream")
    !! Binding to the C routine `nvshmemx_float_alltoall_on_stream`: all-to-all exchange of
    !! floating-point data on the given stream.
      import :: c_int, c_ptr, c_size_t, nvshmem_team_t, dtfft_stream_t
      type(nvshmem_team_t), intent(in), value :: team     !! NVSHMEM team taking part in the exchange.
      type(c_ptr),                      value :: dest     !! Destination buffer.
      type(c_ptr),                      value :: source   !! Source buffer.
      integer(c_size_t),    intent(in), value :: nelems   !! Number of elements to exchange.
      type(dtfft_stream_t), intent(in), value :: stream   !! CUDA stream on which the exchange is enqueued.
      integer(c_int)                          :: ierr     !! Completion status returned by the C routine.
    end function nvshmemx_float_alltoall_on_stream

    function nvshmem_ptr(ptr, pe) result(pe_ptr) bind(C, name="nvshmem_ptr")
    !! Binding to the C routine `nvshmem_ptr`: returns a pointer to a symmetric memory
    !! location on the requested PE.
      import :: c_int, c_ptr
      type(c_ptr),    value     :: ptr      !! Local pointer to the symmetric memory.
      integer(c_int), value     :: pe       !! PE (Processing Element) number.
      type(c_ptr)               :: pe_ptr   !! Pointer to the symmetric memory on the requested PE.
    end function nvshmem_ptr

    function nvshmem_my_pe() result(pe) bind(C, name="nvshmem_my_pe")
    !! Binding to the C routine `nvshmem_my_pe`: returns the PE (Processing Element) number of the caller.
      import :: c_int
      integer(c_int) :: pe   !! PE number of the caller.
    end function nvshmem_my_pe

    function nvshmemx_init_status() result(status) bind(C, name="nvshmemx_init_status")
    !! Binding to the C routine `nvshmemx_init_status`.
    !! NOTE(review): presumably reports the NVSHMEM initialization state; the meaning of the
    !! returned integer is defined by the NVSHMEM library - confirm against its documentation.
      import :: c_int
      integer(c_int) :: status   !! Integer status returned by the C routine.
    end function nvshmemx_init_status
  end interface

  ! logical, save :: is_loaded = .false.
  !   !! Flag indicating whether the library is loaded
  ! type(c_ptr), save :: libnvshmem
  !   !! Handle to the loaded library
  ! type(c_funptr), save :: nvshmemFunctions(7)
  !   !! Array of pointers to the NVSHMEM functions

  ! procedure(nvshmem_malloc_interface),              pointer, public :: nvshmem_malloc
  !   !! Fortran pointer to the nvshmem_malloc function
  ! procedure(nvshmem_free_interface),                pointer, public :: nvshmem_free
  !   !! Fortran pointer to the nvshmem_free function
  ! procedure(nvshmemx_sync_all_on_stream_interface), pointer, public :: nvshmemx_sync_all_on_stream
  !   !! Fortran pointer to the nvshmemx_sync_all_on_stream function
  ! procedure(nvshmemx_float_alltoall_on_stream_interface), pointer, public :: nvshmemx_float_alltoall_on_stream
  !   !! Fortran pointer to the nvshmemx_float_alltoall_on_stream function
  ! procedure(nvshmem_ptr_interface),                 pointer, public :: nvshmem_ptr
  !   !! Fortran pointer to the nvshmem_ptr function
  ! procedure(nvshmem_my_pe_interface),               pointer, public :: nvshmem_my_pe
  !   !! Fortran pointer to the nvshmem_my_pe function
  ! procedure(nvshmemx_init_status_interface),                pointer,  public  :: nvshmemx_init_status

contains

  ! function load_nvshmem(cufftmp_handle) result(error_code)
  ! !! Loads the NVSHMEM library and needed symbols
  !   type(c_ptr),  intent(in)  :: cufftmp_handle
  !   integer(int32)  :: error_code !! Error code
  !   type(string), allocatable :: func_names(:)
  !   integer(int32) :: i

  !   error_code = DTFFT_SUCCESS
    ! if ( is_loaded ) return

    ! allocate(func_names(7))
    ! func_names(1) = string("nvshmem_malloc")
    ! func_names(2) = string("nvshmem_free")
    ! func_names(3) = string("nvshmemx_sync_all_on_stream")
    ! func_names(4) = string("nvshmemx_float_alltoall_on_stream")
    ! func_names(5) = string("nvshmem_ptr")
    ! func_names(6) = string("nvshmem_my_pe")
    ! func_names(7) = string("nvshmemx_init_status")

    ! ! error_code = dynamic_load(NVSHMEM_HOST_LIB, func_names, libnvshmem, nvshmemFunctions)
    ! ! call destroy_strings(func_names)
    ! ! if ( error_code /= DTFFT_SUCCESS ) return
    ! do i = 1, size(func_names)
    !   nvshmemFunctions(i) = load_symbol(cufftmp_handle, func_names(i)%raw)
    ! enddo

    ! call c_f_procpointer(nvshmemFunctions(1), nvshmem_malloc)
    ! call c_f_procpointer(nvshmemFunctions(2), nvshmem_free)
    ! call c_f_procpointer(nvshmemFunctions(3), nvshmemx_sync_all_on_stream)
    ! call c_f_procpointer(nvshmemFunctions(4), nvshmemx_float_alltoall_on_stream)
    ! call c_f_procpointer(nvshmemFunctions(5), nvshmem_ptr)
    ! call c_f_procpointer(nvshmemFunctions(6), nvshmem_my_pe)
    ! call c_f_procpointer(nvshmemFunctions(7), nvshmemx_init_status)

    ! print*, 'nvshmemx_init_status after load = ', nvshmemx_init_status()

    ! is_loaded = .true.
  ! end function load_nvshmem

  function is_nvshmem_ptr(ptr) result(bool)
  !! Checks if pointer is a symmetric nvshmem allocated pointer.
  !!
  !! The address is handed to `nvshmem_ptr` together with the PE number returned by
  !! `nvshmem_my_pe`; the pointer is reported as an NVSHMEM pointer when that lookup
  !! does not yield a null address.
  !!
  !! NOTE(review): presumably NVSHMEM must already be initialized when this is called -
  !! confirm against the callers.
    type(c_ptr), intent(in) :: ptr        !! Device pointer to test (only read, never modified)
    logical                 :: bool       !! `.true.` when the lookup result is not null
    type(c_ptr)             :: local_ptr  !! Address returned by `nvshmem_ptr` for the calling PE

    local_ptr = nvshmem_ptr(ptr, nvshmem_my_pe())
    bool = .not. is_null_ptr(local_ptr)
  end function is_nvshmem_ptr
end module dtfft_interface_nvshmem