dtfft_transpose_handle_datatype.F90 Source File


This file depends on

- dtfft_abstract_transpose_handle.F90
- dtfft_errors.F90
- dtfft_interface_nvtx.F90
- dtfft_parameters.F90
- dtfft_pencil.F90
- dtfft_utils.F90

(These in turn pull in dtfft_abstract_backend.F90, dtfft_abstract_kernel.F90, dtfft_config.F90, dtfft_interface_cuda_runtime.F90 and dtfft_interface_nccl.F90.)

Files dependent on this one

- dtfft_transpose_plan.F90 (directly)
- dtfft_plan.F90 (via dtfft_transpose_plan.F90)
- dtfft.F90 and dtfft_api.F90 (via dtfft_plan.F90)

Source Code

!------------------------------------------------------------------------------------------------
! Copyright (c) 2021 - 2025, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
#include "dtfft_config.h"
module dtfft_transpose_handle_datatype
!! This module describes the [[transpose_handle_datatype]] class.
!! The class implements transposition using MPI_Ialltoall(w)
!! with custom MPI datatypes.
!! To the end user it is exposed as the `DTFFT_BACKEND_MPI_DATATYPE` backend,
!! but since it does not perform the transpose -> exchange -> unpack sequence,
!! it is internally treated as a transpose_handle.
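!! For even decompositions a single datatype per direction is reused through
!! MPI_(I)alltoall; for uneven decompositions per-rank datatypes, counts and
!! displacements are built and exchanged through MPI_(I)alltoallw or, in the
!! point-to-point paths, through MPI_Isend/MPI_Irecv.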
use iso_fortran_env
use dtfft_abstract_transpose_handle,  only: abstract_transpose_handle, create_args, execute_args
use dtfft_errors
use dtfft_parameters
use dtfft_pencil,     only: pencil, get_transpose_type
use dtfft_utils
#include "_dtfft_mpi.h"
#include "_dtfft_private.h"
#include "_dtfft_profile.h"
#include "_dtfft_cuda.h"
implicit none
private
public :: transpose_handle_datatype

  integer(MPI_ADDRESS_KIND), parameter :: LB = 0
  !! Lower bound for all derived datatypes

#if defined (ENABLE_PERSISTENT_COMM)
# if defined(ENABLE_PERSISTENT_COLLECTIVES)
#   if defined(OMPI_FIX_REQUIRED)
      logical, parameter :: IS_P2P_ENABLED = .true.
#   else
      logical, parameter :: IS_P2P_ENABLED = .false.
#   endif
# else
      logical, parameter :: IS_P2P_ENABLED = .true.
# endif
#else
# if defined(OMPI_FIX_REQUIRED)
      logical, parameter :: IS_P2P_ENABLED = .true.
# else
      logical, parameter :: IS_P2P_ENABLED = .false.
# endif
#endif
  !! Is point-to-point communication enabled
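  !! P2P is used when persistent communication is enabled without persistent
  !! collectives, or whenever the Open MPI workaround (OMPI_FIX_REQUIRED) applies;
  !! otherwise the collective alltoall paths are taken.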

  type :: handle_t
  !! Transposition handle class
    TYPE_MPI_DATATYPE,     allocatable :: dtypes(:)           !! Datatypes buffer
    integer(int32),        allocatable :: counts(:)           !! Counts passed to MPI (1, or 0 for ranks holding no data)
    integer(int32),        allocatable :: displs(:)           !! Displacements in bytes (converted to element offsets when P2P is enabled)
  contains
    procedure, pass(self) :: create => create_handle          !! Creates transposition handle
    procedure, pass(self) :: destroy => destroy_handle        !! Destroys transposition handle
  end type handle_t

  type, extends(abstract_transpose_handle) :: transpose_handle_datatype
  !! Transpose backend that uses MPI_Ialltoall(w) with custom MPI datatypes
  private
    TYPE_MPI_COMM                   :: comm                   !! 1d communicator
    logical                         :: is_even = .false.      !! Is decomposition even
    logical                         :: is_active = .false.    !! Is async transposition active
    type(handle_t)                  :: send                   !! Handle to send data
    type(handle_t)                  :: recv                   !! Handle to receive data
    TYPE_MPI_REQUEST, allocatable   :: requests(:)            !! Requests for communication
    integer(int32)                  :: n_requests             !! Actual number of requests, can be less than size(requests)
#if defined(ENABLE_PERSISTENT_COMM)
    logical                         :: is_request_created = .false.     !! Is request created
#endif
  contains
  private
    procedure, pass(self),  public  :: create_private => create !! Initializes class
    procedure, pass(self),  public  :: execute                !! Performs MPI_Ialltoall(w)
    procedure, pass(self),  public  :: execute_end            !! Waits for MPI_Ialltoall(w) to complete
    procedure, pass(self),  public  :: destroy                !! Destroys class
    procedure, pass(self),  public  :: get_async_active       !! Returns .true. if async transposition is active
  end type transpose_handle_datatype
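
  ! A minimal usage sketch (illustrative only; these entry points are normally
  ! driven through abstract_transpose_handle, and the exact argument lists may differ):
  !
  !   type(transpose_handle_datatype) :: handle
  !   call handle%create(comm_1d, send_pencil, recv_pencil, DTFFT_TRANSPOSE_X_TO_Y, base_storage, kwargs)
  !   call handle%execute(in, out, exec_kwargs, error_code)   ! posts MPI_Ialltoall(w)
  !   call handle%execute_end(exec_kwargs, error_code)        ! waits for completion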

contains

  subroutine create_handle(self, n)
  !! Creates transposition handle
    class(handle_t),  intent(inout) :: self   !! Transposition handle
    integer(int32),   intent(in)    :: n      !! Number of datatypes to be created

    call self%destroy()
    allocate(self%dtypes(n), source = MPI_DATATYPE_NULL)
    allocate(self%counts(n), source = 1_int32)
    allocate(self%displs(n), source = 0_int32)
  end subroutine create_handle

  subroutine destroy_handle(self)
  !! Destroys transposition handle
    class(handle_t),  intent(inout)   :: self   !! Transposition handle
    integer(int32)                    :: i      !! Counter
    integer(int32)                    :: ierr   !! Error code

    if ( allocated(self%dtypes) ) then
      do i = 1, size(self%dtypes)
        call MPI_Type_free(self%dtypes(i), ierr)
      enddo
      deallocate(self%dtypes)
    endif
    if ( allocated(self%displs) ) deallocate(self%displs)
    if ( allocated(self%counts) ) deallocate(self%counts)
  end subroutine destroy_handle

  subroutine create(self, comm, send, recv, transpose_type, base_storage, kwargs)
  !! Creates `transpose_handle_datatype` class
    class(transpose_handle_datatype), intent(inout) :: self           !! Transpose handle
    TYPE_MPI_COMM,                    intent(in)    :: comm           !! MPI Communicator
    type(pencil),                     intent(in)    :: send           !! Send pencil
    type(pencil),                     intent(in)    :: recv           !! Recv pencil
    type(dtfft_transpose_t),          intent(in)    :: transpose_type !! Type of transpose to create
    integer(int64),                   intent(in)    :: base_storage   !! Base storage
    type(create_args),                intent(in)    :: kwargs         !! Additional arguments
    integer(int32)                              :: comm_size          !! Size of 1d communicator
    integer(int32)                              :: n_neighbors        !! Number of datatypes to be created
    integer(int32),               allocatable   :: recv_counts(:,:)   !! Receive counts of every rank, gathered over the communicator
    integer(int32),               allocatable   :: send_counts(:,:)   !! Send counts of every rank, gathered over the communicator
    integer(int32)                              :: i                  !! Counter
    integer(int32)                              :: ierr               !! Error code
    integer(int32)                              :: send_displ         !! Send displacement in bytes
    integer(int32)                              :: recv_displ         !! Recv displacement in bytes

    call self%destroy()
    self%comm = comm
    call MPI_Comm_size(comm, comm_size, ierr)
    self%is_even = send%is_even .and. recv%is_even
    n_neighbors = comm_size;  if ( self%is_even ) n_neighbors = 1
    self%is_active = .false.

    allocate(self%requests(2 * comm_size))
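    ! Worst case for the point-to-point paths: one receive and one send request per rank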

#if defined(ENABLE_PERSISTENT_COMM)
    self%is_request_created = .false.
#endif

    call self%send%create(n_neighbors)
    call self%recv%create(n_neighbors)

    allocate(recv_counts(recv%rank, comm_size), source = 0_int32)
    allocate(send_counts, source = recv_counts)
    call MPI_Allgather(recv%counts, int(recv%rank, int32), MPI_INTEGER4, recv_counts, int(recv%rank, int32), MPI_INTEGER4, comm, ierr)
    call MPI_Allgather(send%counts, int(send%rank, int32), MPI_INTEGER4, send_counts, int(send%rank, int32), MPI_INTEGER4, comm, ierr)
    do i = 1, n_neighbors
      if ( send%rank == 2 ) then
        call create_transpose_2d(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage,      &
          self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( any( transpose_type == [DTFFT_TRANSPOSE_X_TO_Y, DTFFT_TRANSPOSE_Y_TO_Z]) ) then
        call create_forw_permutation(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage,  &
          self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( any( transpose_type == [DTFFT_TRANSPOSE_Y_TO_X, DTFFT_TRANSPOSE_Z_TO_Y]) ) then
        call create_back_permutation(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage,  &
          self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( transpose_type == DTFFT_TRANSPOSE_X_TO_Z ) then
        call create_transpose_XZ(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage,      &
          self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else
        call create_transpose_ZX(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage,      &
          self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      endif
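      ! Even decomposition: the single datatype extent (send_displ/recv_displ)
      ! doubles as the per-rank stride. Uneven decomposition: byte offsets are
      ! accumulated rank by rank, left unchanged for ranks holding no data.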
      if ( self%is_even ) then
        self%send%displs(i) = send_displ
        self%recv%displs(i) = recv_displ
      else
        if ( i < n_neighbors ) then
          if ( any(send%counts == 0) ) then
            self%send%displs(i + 1) = self%send%displs(i)
          else
            self%send%displs(i + 1) = self%send%displs(i) + send_displ
          endif
          if ( any(recv%counts == 0) ) then
            self%recv%displs(i + 1) = self%recv%displs(i)
          else
            self%recv%displs(i + 1) = self%recv%displs(i) + recv_displ
          endif
        endif
      endif

    enddo

    if ( IS_P2P_ENABLED ) then
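      ! The point-to-point paths index the real32 buffers directly, so byte
      ! displacements are converted to element offsets here.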
      self%send%displs(:) = self%send%displs(:) / int(FLOAT_STORAGE_SIZE, int32)
      self%recv%displs(:) = self%recv%displs(:) / int(FLOAT_STORAGE_SIZE, int32)
      if ( .not. self%is_even ) then
        if ( any(send%counts == 0) ) self%send%counts(:) = 0
        if ( any(recv%counts == 0) ) self%recv%counts(:) = 0
        do i = 1, n_neighbors
          if ( any(recv_counts(:,i) == 0) ) then
            self%send%counts(i) = 0
          endif
          if ( any(send_counts(:,i) == 0) ) then
            self%recv%counts(i) = 0
          endif
        enddo
      endif
    endif
    deallocate(recv_counts, send_counts)
  end subroutine create

  subroutine execute(self, in, out, kwargs, error_code)
  !! Executes transposition
    class(transpose_handle_datatype), intent(inout) :: self       !! Transpose handle
    real(real32),                     intent(inout) :: in(:)      !! Send buffer
    real(real32),                     intent(inout) :: out(:)     !! Recv buffer
    type(execute_args),               intent(inout) :: kwargs     !! Additional arguments
    integer(int32),                   intent(out)   :: error_code !! Result of execution
    integer(int32) :: i, comm_size, ierr

    error_code = DTFFT_SUCCESS
    if ( self%is_active ) then
      error_code = DTFFT_ERROR_TRANSPOSE_ACTIVE
      return
    endif

    call MPI_Comm_size(self%comm, comm_size, ierr)
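    ! Three compile-time variants follow: persistent collectives
    ! (MPI_Alltoall(w)_init), persistent point-to-point (MPI_Recv/Send_init),
    ! and non-persistent MPI_Ialltoall(w) or MPI_Irecv/MPI_Isend.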

#if defined (ENABLE_PERSISTENT_COMM)
    if ( .not. self%is_request_created ) then
# if defined(ENABLE_PERSISTENT_COLLECTIVES)
      if ( self%is_even ) then
        self%n_requests = 1
        call MPI_Alltoall_init(in, 1, self%send%dtypes(1), out, 1, self%recv%dtypes(1),     &
                                self%comm, MPI_INFO_NULL, self%requests(1), ierr)
      else
#   if defined(OMPI_FIX_REQUIRED)
        self%n_requests = 0
        do i = 1, comm_size
          if ( self%recv%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Recv_init(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,  &
                                self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo

        do i = 1, comm_size
          if ( self%send%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Send_init(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,   &
                                self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
#   else
        self%n_requests = 1
        call MPI_Alltoallw_init(in, self%send%counts, self%send%displs, self%send%dtypes,   &
                                out, self%recv%counts, self%recv%displs, self%recv%dtypes,  &
                                self%comm, MPI_INFO_NULL, self%requests(1), ierr)
#   endif
      endif
# else
      self%n_requests = 0
      if ( self%is_even ) then
        do i = 1, comm_size
          self%n_requests = self%n_requests + 1
          call MPI_Recv_init(out((i - 1) * self%recv%displs(1) + 1), 1, self%recv%dtypes(1), i - 1, 0_int32,  &
                              self%comm, self%requests(self%n_requests), ierr)
        enddo

        do i = 1, comm_size
          self%n_requests = self%n_requests + 1
          call MPI_Send_init(in((i - 1) * self%send%displs(1) + 1), 1, self%send%dtypes(1), i - 1, 0_int32,   &
                              self%comm, self%requests(self%n_requests), ierr)
        enddo
      else
        do i = 1, comm_size
          if ( self%recv%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Recv_init(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,  &
                                self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo

        do i = 1, comm_size
          if ( self%send%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Send_init(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,   &
                                self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
      endif
# endif
      self%is_request_created = .true.
    endif
    call MPI_Startall(self%n_requests, self%requests, ierr)
#else
    if ( self%is_even ) then
      self%n_requests = 1
      call MPI_Ialltoall(in, 1, self%send%dtypes(1), out, 1, self%recv%dtypes(1),           &
                          self%comm, self%requests(1), ierr)
    else
# if defined(OMPI_FIX_REQUIRED)
      self%n_requests = 0
      do i = 1, comm_size
        if ( self%recv%counts(i) > 0 ) then
          self%n_requests = self%n_requests + 1
          call MPI_Irecv(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,    &
                          self%comm, self%requests(self%n_requests), ierr)
        endif
      enddo
      do i = 1, comm_size
        if ( self%send%counts(i) > 0 ) then
          self%n_requests = self%n_requests + 1
          call MPI_Isend(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,     &
                          self%comm, self%requests(self%n_requests), ierr)
        endif
      enddo
# else
      self%n_requests = 1
      call MPI_Ialltoallw(in, self%send%counts, self%send%displs, self%send%dtypes,         &
                          out, self%recv%counts, self%recv%displs, self%recv%dtypes,        &
                          self%comm, self%requests(1), ierr)
# endif
    endif
#endif
    self%is_active = .true.
    if ( kwargs%exec_type == EXEC_BLOCKING ) call self%execute_end(kwargs, error_code)
  end subroutine execute

  subroutine execute_end(self, kwargs, error_code)
  !! Ends execution of transposition
    class(transpose_handle_datatype), intent(inout) :: self       !! Transpose handle
    type(execute_args),               intent(inout) :: kwargs     !! Additional arguments
    integer(int32),                   intent(out)   :: error_code !! Error code
    integer(int32)  :: ierr         !! Error code

    error_code = DTFFT_SUCCESS
    if ( .not. self%is_active ) then
      error_code = DTFFT_ERROR_TRANSPOSE_NOT_ACTIVE
      return
    endif
    call MPI_Waitall(self%n_requests, self%requests, MPI_STATUSES_IGNORE, ierr)
    self%is_active = .false.
  end subroutine execute_end

  elemental logical function get_async_active(self)
  !! Returns if async transpose is active
    class(transpose_handle_datatype), intent(in)    :: self         !! Transpose handle
    get_async_active = self%is_active
  end function get_async_active

  subroutine destroy(self)
  !! Destroys `transpose_handle_datatype` class
    class(transpose_handle_datatype), intent(inout) :: self         !! Transpose handle

    call self%send%destroy()
    call self%recv%destroy()
#if defined(ENABLE_PERSISTENT_COMM)
    block
      integer(int32) :: i, ierr
      if( self%is_request_created ) then
        do i = 1, self%n_requests
          call MPI_Request_free(self%requests(i), ierr)
        enddo
        self%is_request_created = .false.
      endif
    endblock
#endif
    if( allocated(self%requests) ) deallocate( self%requests )
    self%is_active = .false.
    self%is_even = .false.
  end subroutine destroy

  subroutine create_transpose_2d(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage, send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates two-dimensional transposition datatypes
    class(pencil),                intent(in)    :: send               !! Information about send buffer
    integer(int32),               intent(in)    :: send_counts(:)     !! Counts that rank i sends
    class(pencil),                intent(in)    :: recv               !! Information about recv buffer
    integer(int32),               intent(in)    :: recv_counts(:)     !! Counts that rank i receives
    integer(int8),                intent(in)    :: datatype_id        !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,            intent(in)    :: base_type          !! Base MPI_Datatype
    integer(int64),               intent(in)    :: base_storage       !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,            intent(out)   :: send_dtype         !! Datatype used to send data
    integer(int32),               intent(out)   :: send_displ         !! Send displacement in bytes
    TYPE_MPI_DATATYPE,            intent(out)   :: recv_dtype         !! Datatype used to recv data
    integer(int32),               intent(out)   :: recv_displ         !! Recv displacement in bytes
    TYPE_MPI_DATATYPE   :: temp1              !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp2              !! Temporary datatype
    integer(int32)      :: ierr               !! Error code

    send_displ = recv_counts(2) * int(base_storage, int32)
    recv_displ = send_counts(2) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
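    ! This datatype_id has "contiguous" send and strided receive datatypes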
      call MPI_Type_vector(send%counts(2), recv_counts(2), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(2), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(2), temp2, recv_dtype, ierr)
      call free_datatypes(temp1, temp2)
    else
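    ! This datatype_id has strided send and "contiguous" receive datatypes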
      call MPI_Type_vector(send%counts(2), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(2), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), send_counts(2), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif

    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_2d

  subroutine create_forw_permutation(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage, send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional X --> Y and Y -> Z transposition datatypes
    class(pencil),                intent(in)    :: send               !! Information about send buffer
    integer(int32),               intent(in)    :: send_counts(:)     !! Counts that rank i sends
    class(pencil),                intent(in)    :: recv               !! Information about recv buffer
    integer(int32),               intent(in)    :: recv_counts(:)     !! Counts that rank i receives
    integer(int8),                intent(in)    :: datatype_id        !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,            intent(in)    :: base_type          !! Base MPI_Datatype
    integer(int64),               intent(in)    :: base_storage       !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,            intent(out)   :: send_dtype         !! Datatype used to send data
    integer(int32),               intent(out)   :: send_displ         !! Send displacement in bytes
    TYPE_MPI_DATATYPE,            intent(out)   :: recv_dtype         !! Datatype used to recv data
    integer(int32),               intent(out)   :: recv_displ         !! Recv displacement in bytes
    TYPE_MPI_DATATYPE   :: temp1                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp2                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp3                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp4                !! Temporary datatype
    integer(int32)      :: ierr                 !! Error code

    send_displ = recv_counts(3) * int(base_storage, int32)
    recv_displ = send_counts(2) * int(base_storage, int32)
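    ! The constructions below follow a common pattern: MPI_Type_vector selects
    ! strided elements, MPI_Type_create_resized shrinks the extent so that
    ! successive instances interleave correctly, and MPI_Type_contiguous or
    ! MPI_Type_create_hvector replicates the result across the remaining dimension.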
    if ( datatype_id == 1 ) then
    ! This datatype_id has "contiguous" send and strided receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(3), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(3), 1, recv%counts(1) * recv%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(2), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv%counts(2), 1, int(recv%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    elseif ( datatype_id == 2 ) then
    ! This datatype_id has strided send and "contiguous" receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), send_counts(2), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_create_hvector(recv%counts(3), 1, int(recv%counts(1) * recv%counts(2) * base_storage, MPI_ADDRESS_KIND), temp2, temp3, ierr)
      call MPI_Type_create_resized(temp3, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3)
    endif

    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_forw_permutation

  subroutine create_back_permutation(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage, send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional Y --> X and Z --> Y transposition datatypes
    class(pencil),                intent(in)    :: send               !! Information about send buffer
    integer(int32),               intent(in)    :: send_counts(:)     !! Counts that rank i sends
    class(pencil),                intent(in)    :: recv               !! Information about recv buffer
    integer(int32),               intent(in)    :: recv_counts(:)     !! Counts that rank i receives
    integer(int8),                intent(in)    :: datatype_id        !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,            intent(in)    :: base_type          !! Base MPI_Datatype
    integer(int64),               intent(in)    :: base_storage       !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,            intent(out)   :: send_dtype         !! Datatype used to send data
    integer(int32),               intent(out)   :: send_displ         !! Send displacement in bytes
    TYPE_MPI_DATATYPE,            intent(out)   :: recv_dtype         !! Datatype used to recv data
    integer(int32),               intent(out)   :: recv_displ         !! Recv displacement in bytes
    TYPE_MPI_DATATYPE   :: temp1                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp2                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp3                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp4                !! Temporary datatype
    integer(int32)      :: ierr                 !! Error code

    send_displ = recv_counts(2) * int(base_storage, int32)
    recv_displ = send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
    ! This datatype_id has "contiguous" send and strided receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(2), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(3), temp2, temp3, ierr)
      call MPI_Type_create_resized(temp3, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3)
    elseif ( datatype_id == 2 ) then
    ! This datatype_id has strided send and "contiguous" receive datatypes
      call MPI_Type_vector(send%counts(3), 1, send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(2), temp2, temp3, ierr)
      call MPI_Type_create_hvector(send%counts(2), 1, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), send_counts(3), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif

    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_back_permutation

  subroutine create_transpose_XZ(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage, send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional X --> Z transposition datatypes
  !! Can only be used with a 3D slab decomposition when the slabs are distributed in the Z direction
    class(pencil),                intent(in)    :: send               !! Information about send buffer
    integer(int32),               intent(in)    :: send_counts(:)     !! Counts that rank i sends
    class(pencil),                intent(in)    :: recv               !! Information about recv buffer
    integer(int32),               intent(in)    :: recv_counts(:)     !! Counts that rank i receives
    integer(int8),                intent(in)    :: datatype_id        !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,            intent(in)    :: base_type          !! Base MPI_Datatype
    integer(int64),               intent(in)    :: base_storage       !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,            intent(out)   :: send_dtype         !! Datatype used to send data
    integer(int32),               intent(out)   :: send_displ         !! Send displacement in bytes
    TYPE_MPI_DATATYPE,            intent(out)   :: recv_dtype         !! Datatype used to recv data
    integer(int32),               intent(out)   :: recv_displ         !! Recv displacement in bytes
    TYPE_MPI_DATATYPE   :: temp1                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp2                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp3                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp4                !! Temporary datatype
    integer(int32)      :: ierr                 !! Error code

    send_displ = send%counts(1) * recv_counts(3) * int(base_storage, int32)
    recv_displ = send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
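    ! This datatype_id has "contiguous" send and strided receive datatypes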
      call MPI_Type_vector(send%counts(3), send%counts(1), send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(3), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv%counts(3), 1, int(recv%counts(1) * recv%counts(2) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    else
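    ! This datatype_id has strided send and "contiguous" receive datatypes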
      call MPI_Type_vector(send%counts(3), 1, send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send%counts(1), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv_counts(3), 1, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), send_counts(3), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif

    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_XZ

  subroutine create_transpose_ZX(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage, send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional Z --> X transposition datatypes
  !! Can only be used with a 3D slab decomposition when the slabs are distributed in the Z direction
    class(pencil),                intent(in)    :: send               !! Information about send buffer
    integer(int32),               intent(in)    :: send_counts(:)     !! Counts that rank i sends
    class(pencil),                intent(in)    :: recv               !! Information about recv buffer
    integer(int32),               intent(in)    :: recv_counts(:)     !! Counts that rank i receives
    integer(int8),                intent(in)    :: datatype_id        !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,            intent(in)    :: base_type          !! Base MPI_Datatype
    integer(int64),               intent(in)    :: base_storage       !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,            intent(out)   :: send_dtype         !! Datatype used to send data
    integer(int32),               intent(out)   :: send_displ         !! Send displacement in bytes
    TYPE_MPI_DATATYPE,            intent(out)   :: recv_dtype         !! Datatype used to recv data
    integer(int32),               intent(out)   :: recv_displ         !! Recv displacement in bytes
    TYPE_MPI_DATATYPE   :: temp1                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp2                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp3                !! Temporary datatype
    TYPE_MPI_DATATYPE   :: temp4                !! Temporary datatype
    integer(int32)      :: ierr                 !! Error code

    send_displ = recv_counts(3) * int(base_storage, int32)
    recv_displ = recv%counts(1) * send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
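    ! This datatype_id has "contiguous" send and strided receive datatypes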
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(3), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(3), 1, recv%counts(1) * recv%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv%counts(1), temp2, temp3, ierr)
      call MPI_Type_create_hvector(send_counts(3), 1, int(recv%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    else
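    ! This datatype_id has strided send and "contiguous" receive datatypes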
      call MPI_Type_vector(send%counts(2) * send%counts(3), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(3), recv%counts(1) * send_counts(3), recv%counts(1) * recv%counts(2),  base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif

    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_ZX

  subroutine free_datatypes(t1, t2, t3, t4)
  !! Frees temporary datatypes
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t1     !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t2     !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t3     !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t4     !! Temporary datatype
    integer(int32)                              :: ierr   !! Error code

    if ( present(t1) ) call MPI_Type_free(t1, ierr)
    if ( present(t2) ) call MPI_Type_free(t2, ierr)
    if ( present(t3) ) call MPI_Type_free(t3, ierr)
    if ( present(t4) ) call MPI_Type_free(t4, ierr)
  end subroutine free_datatypes
end module dtfft_transpose_handle_datatype