!------------------------------------------------------------------------------------------------
! Copyright (c) 2021 - 2025, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program. If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
#include "dtfft_config.h"
module dtfft_transpose_handle_datatype
!! This module describes the [[transpose_handle_datatype]] class.
!! This class implements transposition using MPI_Ialltoall(w)
!! with custom MPI datatypes.
!! For the end user this is the `DTFFT_BACKEND_MPI_DATATYPE` backend.
!! But since it does not perform the sequence transpose -> exchange -> unpack,
!! it is internally treated as a transpose_handle.
use iso_fortran_env
use dtfft_abstract_transpose_handle,  only: abstract_transpose_handle, create_args, execute_args
use dtfft_errors
use dtfft_parameters
use dtfft_pencil,                     only: pencil, get_transpose_type
use dtfft_utils
#include "_dtfft_mpi.h"
#include "_dtfft_private.h"
#include "_dtfft_profile.h"
#include "_dtfft_cuda.h"
implicit none
private
public :: transpose_handle_datatype

  integer(MPI_ADDRESS_KIND), parameter :: LB = 0
    !! Lower bound for all derived datatypes

#if defined(ENABLE_PERSISTENT_COMM)
# if defined(ENABLE_PERSISTENT_COLLECTIVES)
#   if defined(OMPI_FIX_REQUIRED)
  logical, parameter :: IS_P2P_ENABLED = .true.
#   else
  logical, parameter :: IS_P2P_ENABLED = .false.
#   endif
# else
  logical, parameter :: IS_P2P_ENABLED = .true.
# endif
#else
# if defined(OMPI_FIX_REQUIRED)
  logical, parameter :: IS_P2P_ENABLED = .true.
# else
  logical, parameter :: IS_P2P_ENABLED = .false.
# endif
#endif
    !! Is point-to-point communication enabled

  type :: handle_t
  !! Transposition handle class
    TYPE_MPI_DATATYPE,  allocatable :: dtypes(:)    !! Datatypes buffer
    integer(int32),     allocatable :: counts(:)    !! Number of datatypes (always equals 1)
    integer(int32),     allocatable :: displs(:)    !! Displacements in bytes
  contains
    procedure, pass(self) :: create => create_handle    !! Creates transposition handle
    procedure, pass(self) :: destroy => destroy_handle  !! Destroys transposition handle
  end type handle_t

  type, extends(abstract_transpose_handle) :: transpose_handle_datatype
  !! Transpose backend that uses MPI_Ialltoall(w) with custom MPI datatypes
  private
    TYPE_MPI_COMM                 :: comm                 !! 1d communicator
    logical                       :: is_even = .false.    !! Is decomposition even
    logical                       :: is_active = .false.  !! Is async transposition active
    type(handle_t)                :: send                 !! Handle to send data
    type(handle_t)                :: recv                 !! Handle to receive data
    TYPE_MPI_REQUEST, allocatable :: requests(:)          !! Requests for communication
    integer(int32)                :: n_requests           !! Actual number of requests, can be less than size(requests)
#if defined(ENABLE_PERSISTENT_COMM)
    logical                       :: is_request_created = .false.  !! Is persistent request created
#endif
  contains
  private
    procedure, pass(self), public :: create_private => create   !! Initializes class
    procedure, pass(self), public :: execute                    !! Performs MPI_Ialltoall(w)
    procedure, pass(self), public :: execute_end                !! Waits for MPI_Ialltoall(w) to complete
    procedure, pass(self), public :: destroy                    !! Destroys class
    procedure, pass(self), public :: get_async_active           !! Returns .true. if async transposition is active
  end type transpose_handle_datatype

contains

  subroutine create_handle(self, n)
  !! Creates transposition handle
    class(handle_t),  intent(inout) :: self   !! Transposition handle
    integer(int32),   intent(in)    :: n      !! Number of datatypes to be created

    call self%destroy()
    allocate(self%dtypes(n), source = MPI_DATATYPE_NULL)
    allocate(self%counts(n), source = 1_int32)
    allocate(self%displs(n), source = 0_int32)
  end subroutine create_handle

  subroutine destroy_handle(self)
  !! Destroys transposition handle
    class(handle_t),  intent(inout) :: self   !! Transposition handle
    integer(int32) :: i     !! Counter
    integer(int32) :: ierr  !! Error code

    if ( allocated(self%dtypes) ) then
      do i = 1, size(self%dtypes)
        call MPI_Type_free(self%dtypes(i), ierr)
      enddo
      deallocate(self%dtypes)
    endif
    if ( allocated(self%displs) ) deallocate(self%displs)
    if ( allocated(self%counts) ) deallocate(self%counts)
  end subroutine destroy_handle

  subroutine create(self, comm, send, recv, transpose_type, base_storage, kwargs)
  !! Creates `transpose_handle_datatype` class
    class(transpose_handle_datatype), intent(inout) :: self            !! Transpose handle
    TYPE_MPI_COMM,                    intent(in)    :: comm            !! MPI Communicator
    type(pencil),                     intent(in)    :: send            !! Send pencil
    type(pencil),                     intent(in)    :: recv            !! Recv pencil
    type(dtfft_transpose_t),          intent(in)    :: transpose_type  !! Type of transpose to create
    integer(int64),                   intent(in)    :: base_storage    !! Number of bytes needed to store single element
    type(create_args),                intent(in)    :: kwargs          !! Additional arguments
    integer(int32)              :: comm_size               !! Size of 1d communicator
    integer(int32)              :: n_neighbors             !! Number of datatypes to be created
    integer(int32), allocatable :: recv_counts(:,:)        !! Each processor should know how much data each processor receives
    integer(int32), allocatable :: send_counts(:,:)        !! Each processor should know how much data each processor sends
    integer(int32)              :: i                       !! Counter
    integer(int32)              :: ierr                    !! Error code
    integer(int32)              :: send_displ, recv_displ  !! Displacements returned by datatype constructors, in bytes

    call self%destroy()
    self%comm = comm
    call MPI_Comm_size(comm, comm_size, ierr)
    self%is_even = send%is_even .and. recv%is_even
    n_neighbors = comm_size
    if ( self%is_even ) n_neighbors = 1
    self%is_active = .false.
    allocate(self%requests(2 * comm_size))
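    ! 2 * comm_size covers the worst case: the point-to-point paths in `execute`
    ! post one receive and one send request per rank of the 1d communicator.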
#if defined(ENABLE_PERSISTENT_COMM)
    self%is_request_created = .false.
#endif
    call self%send%create(n_neighbors)
    call self%recv%create(n_neighbors)

    allocate(recv_counts(recv%rank, comm_size), source = 0_int32)
    allocate(send_counts, source = recv_counts)
    call MPI_Allgather(recv%counts, int(recv%rank, int32), MPI_INTEGER4, recv_counts, int(recv%rank, int32), MPI_INTEGER4, comm, ierr)
    call MPI_Allgather(send%counts, int(send%rank, int32), MPI_INTEGER4, send_counts, int(send%rank, int32), MPI_INTEGER4, comm, ierr)

    do i = 1, n_neighbors
      if ( send%rank == 2 ) then
        call create_transpose_2d(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage, &
                                 self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( any(transpose_type == [DTFFT_TRANSPOSE_X_TO_Y, DTFFT_TRANSPOSE_Y_TO_Z]) ) then
        call create_forw_permutation(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage, &
                                     self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( any(transpose_type == [DTFFT_TRANSPOSE_Y_TO_X, DTFFT_TRANSPOSE_Z_TO_Y]) ) then
        call create_back_permutation(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage, &
                                     self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else if ( transpose_type == DTFFT_TRANSPOSE_X_TO_Z ) then
        call create_transpose_XZ(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage, &
                                 self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      else
        call create_transpose_ZX(send, send_counts(:,i), recv, recv_counts(:,i), kwargs%datatype_id, kwargs%base_type, base_storage, &
                                 self%send%dtypes(i), send_displ, self%recv%dtypes(i), recv_displ)
      endif
      if ( self%is_even ) then
        self%send%displs(i) = send_displ
        self%recv%displs(i) = recv_displ
      else
        if ( i < n_neighbors ) then
          if ( any(send%counts == 0) ) then
            self%send%displs(i + 1) = self%send%displs(i)
          else
            self%send%displs(i + 1) = self%send%displs(i) + send_displ
          endif
          if ( any(recv%counts == 0) ) then
            self%recv%displs(i + 1) = self%recv%displs(i)
          else
            self%recv%displs(i + 1) = self%recv%displs(i) + recv_displ
          endif
        endif
      endif
    enddo

    if ( IS_P2P_ENABLED ) then
      ! Point-to-point paths index directly into real32 buffers, so displacements
      ! are converted from bytes to elements
      self%send%displs(:) = self%send%displs(:) / int(FLOAT_STORAGE_SIZE, int32)
      self%recv%displs(:) = self%recv%displs(:) / int(FLOAT_STORAGE_SIZE, int32)
      if ( .not. self%is_even ) then
        ! Zero the counts for peers that exchange no data, so no message is posted
        if ( any(send%counts == 0) ) self%send%counts(:) = 0
        if ( any(recv%counts == 0) ) self%recv%counts(:) = 0
        do i = 1, n_neighbors
          if ( any(recv_counts(:,i) == 0) ) then
            self%send%counts(i) = 0
          endif
          if ( any(send_counts(:,i) == 0) ) then
            self%recv%counts(i) = 0
          endif
        enddo
      endif
    endif
    deallocate(recv_counts, send_counts)
  end subroutine create

  subroutine execute(self, in, out, kwargs, error_code)
  !! Executes transposition
    class(transpose_handle_datatype), intent(inout) :: self        !! Transpose handle
    real(real32),                     intent(inout) :: in(:)       !! Send buffer
    real(real32),                     intent(inout) :: out(:)      !! Recv buffer
    type(execute_args),               intent(inout) :: kwargs      !! Additional arguments
    integer(int32),                   intent(out)   :: error_code  !! Result of execution
    integer(int32) :: i, comm_size, ierr

    error_code = DTFFT_SUCCESS
    if ( self%is_active ) then
      error_code = DTFFT_ERROR_TRANSPOSE_ACTIVE
      return
    endif
    call MPI_Comm_size(self%comm, comm_size, ierr)
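    ! Three compile-time communication strategies follow:
    !  - ENABLE_PERSISTENT_COMM + ENABLE_PERSISTENT_COLLECTIVES: persistent
    !    MPI_Alltoall_init / MPI_Alltoallw_init (persistent point-to-point when
    !    OMPI_FIX_REQUIRED is set), started with MPI_Startall;
    !  - ENABLE_PERSISTENT_COMM only: persistent MPI_Send_init / MPI_Recv_init;
    !  - otherwise: nonblocking MPI_Ialltoall(w) or MPI_Isend / MPI_Irecv.
    ! Persistent requests are created on the first call and reused afterwards.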
#if defined(ENABLE_PERSISTENT_COMM)
    if ( .not. self%is_request_created ) then
# if defined(ENABLE_PERSISTENT_COLLECTIVES)
      if ( self%is_even ) then
        self%n_requests = 1
        call MPI_Alltoall_init(in, 1, self%send%dtypes(1), out, 1, self%recv%dtypes(1),   &
                               self%comm, MPI_INFO_NULL, self%requests(1), ierr)
      else
#   if defined(OMPI_FIX_REQUIRED)
        self%n_requests = 0
        do i = 1, comm_size
          if ( self%recv%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Recv_init(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,  &
                               self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
        do i = 1, comm_size
          if ( self%send%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Send_init(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,   &
                               self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
#   else
        self%n_requests = 1
        call MPI_Alltoallw_init(in, self%send%counts, self%send%displs, self%send%dtypes,   &
                                out, self%recv%counts, self%recv%displs, self%recv%dtypes,  &
                                self%comm, MPI_INFO_NULL, self%requests(1), ierr)
#   endif
      endif
# else
      self%n_requests = 0
      if ( self%is_even ) then
        do i = 1, comm_size
          self%n_requests = self%n_requests + 1
          call MPI_Recv_init(out((i - 1) * self%recv%displs(1) + 1), 1, self%recv%dtypes(1), i - 1, 0_int32,  &
                             self%comm, self%requests(self%n_requests), ierr)
        enddo
        do i = 1, comm_size
          self%n_requests = self%n_requests + 1
          call MPI_Send_init(in((i - 1) * self%send%displs(1) + 1), 1, self%send%dtypes(1), i - 1, 0_int32,   &
                             self%comm, self%requests(self%n_requests), ierr)
        enddo
      else
        do i = 1, comm_size
          if ( self%recv%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Recv_init(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,  &
                               self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
        do i = 1, comm_size
          if ( self%send%counts(i) > 0 ) then
            self%n_requests = self%n_requests + 1
            call MPI_Send_init(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,   &
                               self%comm, self%requests(self%n_requests), ierr)
          endif
        enddo
      endif
# endif
      self%is_request_created = .true.
    endif
    call MPI_Startall(self%n_requests, self%requests, ierr)
#else
    if ( self%is_even ) then
      self%n_requests = 1
      call MPI_Ialltoall(in, 1, self%send%dtypes(1), out, 1, self%recv%dtypes(1),   &
                         self%comm, self%requests(1), ierr)
    else
# if defined(OMPI_FIX_REQUIRED)
      self%n_requests = 0
      do i = 1, comm_size
        if ( self%recv%counts(i) > 0 ) then
          self%n_requests = self%n_requests + 1
          call MPI_Irecv(out(self%recv%displs(i) + 1), 1, self%recv%dtypes(i), i - 1, 0_int32,  &
                         self%comm, self%requests(self%n_requests), ierr)
        endif
      enddo
      do i = 1, comm_size
        if ( self%send%counts(i) > 0 ) then
          self%n_requests = self%n_requests + 1
          call MPI_Isend(in(self%send%displs(i) + 1), 1, self%send%dtypes(i), i - 1, 0_int32,   &
                         self%comm, self%requests(self%n_requests), ierr)
        endif
      enddo
# else
      self%n_requests = 1
      call MPI_Ialltoallw(in, self%send%counts, self%send%displs, self%send%dtypes,   &
                          out, self%recv%counts, self%recv%displs, self%recv%dtypes,  &
                          self%comm, self%requests(1), ierr)
# endif
    endif
#endif
    self%is_active = .true.
    if ( kwargs%exec_type == EXEC_BLOCKING ) call self%execute_end(kwargs, error_code)
  end subroutine execute
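  ! Transposition is split-phase: `execute` starts the exchange, `execute_end`
  ! waits for it. With kwargs%exec_type == EXEC_BLOCKING the wait happens inside
  ! `execute` itself; otherwise the caller must invoke `execute_end`.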
  subroutine execute_end(self, kwargs, error_code)
  !! Ends execution of transposition
    class(transpose_handle_datatype), intent(inout) :: self        !! Transpose handle
    type(execute_args),               intent(inout) :: kwargs      !! Additional arguments
    integer(int32),                   intent(out)   :: error_code  !! Error code
    integer(int32) :: ierr  !! Error code

    error_code = DTFFT_SUCCESS
    if ( .not. self%is_active ) then
      error_code = DTFFT_ERROR_TRANSPOSE_NOT_ACTIVE
      return
    endif
    call MPI_Waitall(self%n_requests, self%requests, MPI_STATUSES_IGNORE, ierr)
    self%is_active = .false.
  end subroutine execute_end

  elemental logical function get_async_active(self)
  !! Returns .true. if async transposition is active
    class(transpose_handle_datatype), intent(in) :: self  !! Transpose handle
    get_async_active = self%is_active
  end function get_async_active

  subroutine destroy(self)
  !! Destroys `transpose_handle_datatype` class
    class(transpose_handle_datatype), intent(inout) :: self  !! Transpose handle

    call self%send%destroy()
    call self%recv%destroy()
#if defined(ENABLE_PERSISTENT_COMM)
    ! Persistent requests are created by any ENABLE_PERSISTENT_COMM path,
    ! collective or point-to-point, and must be freed in either case
    block
      integer(int32) :: i, ierr
      if ( self%is_request_created ) then
        do i = 1, self%n_requests
          call MPI_Request_free(self%requests(i), ierr)
        enddo
        self%is_request_created = .false.
      endif
    endblock
#endif
    if ( allocated(self%requests) ) deallocate(self%requests)
    self%is_active = .false.
    self%is_even = .false.
  end subroutine destroy
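  ! The constructors below build one send and one receive datatype per neighbor.
  ! Each datatype is assembled from MPI_Type_vector / MPI_Type_create_hvector
  ! over the pencil layout and then given a reduced extent with
  ! MPI_Type_create_resized, so that consecutive neighbors' blocks tile the
  ! buffer at the returned byte displacement; `create` accumulates these
  ! displacements into the displs arrays for uneven decompositions.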
  subroutine create_transpose_2d(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage,  &
                                 send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates two-dimensional transposition datatypes
    class(pencil),      intent(in)  :: send            !! Information about send buffer
    integer(int32),     intent(in)  :: send_counts(:)  !! Counts that rank i is sending
    class(pencil),      intent(in)  :: recv            !! Information about recv buffer
    integer(int32),     intent(in)  :: recv_counts(:)  !! Counts that rank i is receiving
    integer(int8),      intent(in)  :: datatype_id     !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,  intent(in)  :: base_type       !! Base MPI_Datatype
    integer(int64),     intent(in)  :: base_storage    !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,  intent(out) :: send_dtype      !! Datatype used to send data
    integer(int32),     intent(out) :: send_displ      !! Send displacement in bytes
    TYPE_MPI_DATATYPE,  intent(out) :: recv_dtype      !! Datatype used to recv data
    integer(int32),     intent(out) :: recv_displ      !! Recv displacement in bytes
    TYPE_MPI_DATATYPE :: temp1  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp2  !! Temporary datatype
    integer(int32)    :: ierr   !! Error code

    send_displ = recv_counts(2) * int(base_storage, int32)
    recv_displ = send_counts(2) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
      call MPI_Type_vector(send%counts(2), recv_counts(2), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(2), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(2), temp2, recv_dtype, ierr)
      call free_datatypes(temp1, temp2)
    else
      call MPI_Type_vector(send%counts(2), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(2), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), send_counts(2), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif
    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_2d
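  ! In the 3d constructors the strided side is built in layers:
  ! vector (stride within a line) -> resized (single-element extent) ->
  ! contiguous (line) -> hvector (plane), so one datatype describes the fully
  ! permuted placement of a neighbor's block.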
  subroutine create_forw_permutation(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage,  &
                                     send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional X --> Y and Y --> Z transposition datatypes
    class(pencil),      intent(in)  :: send            !! Information about send buffer
    integer(int32),     intent(in)  :: send_counts(:)  !! Counts that rank i is sending
    class(pencil),      intent(in)  :: recv            !! Information about recv buffer
    integer(int32),     intent(in)  :: recv_counts(:)  !! Counts that rank i is receiving
    integer(int8),      intent(in)  :: datatype_id     !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,  intent(in)  :: base_type       !! Base MPI_Datatype
    integer(int64),     intent(in)  :: base_storage    !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,  intent(out) :: send_dtype      !! Datatype used to send data
    integer(int32),     intent(out) :: send_displ      !! Send displacement in bytes
    TYPE_MPI_DATATYPE,  intent(out) :: recv_dtype      !! Datatype used to recv data
    integer(int32),     intent(out) :: recv_displ      !! Recv displacement in bytes
    TYPE_MPI_DATATYPE :: temp1  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp2  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp3  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp4  !! Temporary datatype
    integer(int32)    :: ierr   !! Error code

    send_displ = recv_counts(3) * int(base_storage, int32)
    recv_displ = send_counts(2) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
      ! This datatype_id has "contiguous" send and strided receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(3), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(3), 1, recv%counts(1) * recv%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(2), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv%counts(2), 1, int(recv%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    elseif ( datatype_id == 2 ) then
      ! This datatype_id has strided send and "contiguous" receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), send_counts(2), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_create_hvector(recv%counts(3), 1, int(recv%counts(1) * recv%counts(2) * base_storage, MPI_ADDRESS_KIND), temp2, temp3, ierr)
      call MPI_Type_create_resized(temp3, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3)
    endif
    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_forw_permutation
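  ! `datatype_id` is expected to be 1 or 2; no other plan ids are defined for
  ! these constructors.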
  subroutine create_back_permutation(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage,  &
                                     send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional Y --> X and Z --> Y transposition datatypes
    class(pencil),      intent(in)  :: send            !! Information about send buffer
    integer(int32),     intent(in)  :: send_counts(:)  !! Counts that rank i is sending
    class(pencil),      intent(in)  :: recv            !! Information about recv buffer
    integer(int32),     intent(in)  :: recv_counts(:)  !! Counts that rank i is receiving
    integer(int8),      intent(in)  :: datatype_id     !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,  intent(in)  :: base_type       !! Base MPI_Datatype
    integer(int64),     intent(in)  :: base_storage    !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,  intent(out) :: send_dtype      !! Datatype used to send data
    integer(int32),     intent(out) :: send_displ      !! Send displacement in bytes
    TYPE_MPI_DATATYPE,  intent(out) :: recv_dtype      !! Datatype used to recv data
    integer(int32),     intent(out) :: recv_displ      !! Recv displacement in bytes
    TYPE_MPI_DATATYPE :: temp1  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp2  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp3  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp4  !! Temporary datatype
    integer(int32)    :: ierr   !! Error code

    send_displ = recv_counts(2) * int(base_storage, int32)
    recv_displ = send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
      ! This datatype_id has "contiguous" send and strided receive datatypes
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(2), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(3), temp2, temp3, ierr)
      call MPI_Type_create_resized(temp3, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3)
    elseif ( datatype_id == 2 ) then
      ! This datatype_id has strided send and "contiguous" receive datatypes
      call MPI_Type_vector(send%counts(3), 1, send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(2), temp2, temp3, ierr)
      call MPI_Type_create_hvector(send%counts(2), 1, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), send_counts(3), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif
    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_back_permutation
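  ! The X --> Z and Z --> X constructors below perform in a single exchange a
  ! reordering that would otherwise require two transposes; hence the
  ! slab-decomposition restriction noted in their headers.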
  subroutine create_transpose_XZ(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage,  &
                                 send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional X --> Z transposition datatypes
  !! Can only be used with 3D slab decomposition when slabs are distributed in Z direction
    class(pencil),      intent(in)  :: send            !! Information about send buffer
    integer(int32),     intent(in)  :: send_counts(:)  !! Counts that rank i is sending
    class(pencil),      intent(in)  :: recv            !! Information about recv buffer
    integer(int32),     intent(in)  :: recv_counts(:)  !! Counts that rank i is receiving
    integer(int8),      intent(in)  :: datatype_id     !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,  intent(in)  :: base_type       !! Base MPI_Datatype
    integer(int64),     intent(in)  :: base_storage    !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,  intent(out) :: send_dtype      !! Datatype used to send data
    integer(int32),     intent(out) :: send_displ      !! Send displacement in bytes
    TYPE_MPI_DATATYPE,  intent(out) :: recv_dtype      !! Datatype used to recv data
    integer(int32),     intent(out) :: recv_displ      !! Recv displacement in bytes
    TYPE_MPI_DATATYPE :: temp1  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp2  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp3  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp4  !! Temporary datatype
    integer(int32)    :: ierr   !! Error code

    send_displ = send%counts(1) * recv_counts(3) * int(base_storage, int32)
    recv_displ = send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
      call MPI_Type_vector(send%counts(3), send%counts(1), send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(2), 1, recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send_counts(3), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv%counts(3), 1, int(recv%counts(1) * recv%counts(2) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    else
      call MPI_Type_vector(send%counts(3), 1, send%counts(1) * send%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(send%counts(1), temp2, temp3, ierr)
      call MPI_Type_create_hvector(recv_counts(3), 1, int(send%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)

      call MPI_Type_vector(recv%counts(2) * recv%counts(3), send_counts(3), recv%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif
    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_XZ
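  ! Z --> X reverses the X --> Z exchange for the same Z-distributed slab
  ! decomposition.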
  subroutine create_transpose_ZX(send, send_counts, recv, recv_counts, datatype_id, base_type, base_storage,  &
                                 send_dtype, send_displ, recv_dtype, recv_displ)
  !! Creates three-dimensional Z --> X transposition datatypes
  !! Can only be used with 3D slab decomposition when slabs are distributed in Z direction
    class(pencil),      intent(in)  :: send            !! Information about send buffer
    integer(int32),     intent(in)  :: send_counts(:)  !! Counts that rank i is sending
    class(pencil),      intent(in)  :: recv            !! Information about recv buffer
    integer(int32),     intent(in)  :: recv_counts(:)  !! Counts that rank i is receiving
    integer(int8),      intent(in)  :: datatype_id     !! Id of transpose plan to use
    TYPE_MPI_DATATYPE,  intent(in)  :: base_type       !! Base MPI_Datatype
    integer(int64),     intent(in)  :: base_storage    !! Number of bytes needed to store single element
    TYPE_MPI_DATATYPE,  intent(out) :: send_dtype      !! Datatype used to send data
    integer(int32),     intent(out) :: send_displ      !! Send displacement in bytes
    TYPE_MPI_DATATYPE,  intent(out) :: recv_dtype      !! Datatype used to recv data
    integer(int32),     intent(out) :: recv_displ      !! Recv displacement in bytes
    TYPE_MPI_DATATYPE :: temp1  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp2  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp3  !! Temporary datatype
    TYPE_MPI_DATATYPE :: temp4  !! Temporary datatype
    integer(int32)    :: ierr   !! Error code

    send_displ = recv_counts(3) * int(base_storage, int32)
    recv_displ = recv%counts(1) * send_counts(3) * int(base_storage, int32)
    if ( datatype_id == 1 ) then
      call MPI_Type_vector(send%counts(2) * send%counts(3), recv_counts(3), send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(send_displ, MPI_ADDRESS_KIND), send_dtype, ierr)
      call free_datatypes(temp1)

      call MPI_Type_vector(recv%counts(3), 1, recv%counts(1) * recv%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv%counts(1), temp2, temp3, ierr)
      call MPI_Type_create_hvector(send_counts(3), 1, int(recv%counts(1) * base_storage, MPI_ADDRESS_KIND), temp3, temp4, ierr)
      call MPI_Type_create_resized(temp4, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1, temp2, temp3, temp4)
    else
      call MPI_Type_vector(send%counts(2) * send%counts(3), 1, send%counts(1), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(base_storage, MPI_ADDRESS_KIND), temp2, ierr)
      call MPI_Type_contiguous(recv_counts(3), temp2, send_dtype, ierr)
      call free_datatypes(temp1, temp2)

      call MPI_Type_vector(recv%counts(3), recv%counts(1) * send_counts(3), recv%counts(1) * recv%counts(2), base_type, temp1, ierr)
      call MPI_Type_create_resized(temp1, LB, int(recv_displ, MPI_ADDRESS_KIND), recv_dtype, ierr)
      call free_datatypes(temp1)
    endif
    call MPI_Type_commit(send_dtype, ierr)
    call MPI_Type_commit(recv_dtype, ierr)
  end subroutine create_transpose_ZX

  subroutine free_datatypes(t1, t2, t3, t4)
  !! Frees temporary datatypes
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t1  !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t2  !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t3  !! Temporary datatype
    TYPE_MPI_DATATYPE,  intent(inout), optional :: t4  !! Temporary datatype
    integer(int32) :: ierr  !! Error code

    if ( present(t1) ) call MPI_Type_free(t1, ierr)
    if ( present(t2) ) call MPI_Type_free(t2, ierr)
    if ( present(t3) ) call MPI_Type_free(t3, ierr)
    if ( present(t4) ) call MPI_Type_free(t4, ierr)
  end subroutine free_datatypes
end module dtfft_transpose_handle_datatype