test_host_kernels.F90 Source File


This file depends on

sourcefile~~test_host_kernels.f90~~EfferentGraph sourcefile~test_host_kernels.f90 test_host_kernels.F90 sourcefile~dtfft_kernel_host.f90 dtfft_kernel_host.F90 sourcefile~test_host_kernels.f90->sourcefile~dtfft_kernel_host.f90 sourcefile~dtfft_abstract_kernel.f90 dtfft_abstract_kernel.F90 sourcefile~dtfft_kernel_host.f90->sourcefile~dtfft_abstract_kernel.f90 sourcefile~dtfft_config.f90 dtfft_config.F90 sourcefile~dtfft_kernel_host.f90->sourcefile~dtfft_config.f90 sourcefile~dtfft_interface_nvtx.f90 dtfft_interface_nvtx.F90 sourcefile~dtfft_kernel_host.f90->sourcefile~dtfft_interface_nvtx.f90 sourcefile~dtfft_parameters.f90 dtfft_parameters.F90 sourcefile~dtfft_kernel_host.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_utils.f90 dtfft_utils.F90 sourcefile~dtfft_kernel_host.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_abstract_kernel.f90->sourcefile~dtfft_config.f90 sourcefile~dtfft_abstract_kernel.f90->sourcefile~dtfft_interface_nvtx.f90 sourcefile~dtfft_abstract_kernel.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_abstract_kernel.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_abstract_compressor.f90 dtfft_abstract_compressor.F90 sourcefile~dtfft_abstract_kernel.f90->sourcefile~dtfft_abstract_compressor.f90 sourcefile~dtfft_config.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_config.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_config.f90->sourcefile~dtfft_abstract_compressor.f90 sourcefile~dtfft_errors.f90 dtfft_errors.F90 sourcefile~dtfft_config.f90->sourcefile~dtfft_errors.f90 sourcefile~dtfft_interface_cuda_runtime.f90 dtfft_interface_cuda_runtime.F90 sourcefile~dtfft_config.f90->sourcefile~dtfft_interface_cuda_runtime.f90 sourcefile~dtfft_interface_nvtx.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_utils.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_utils.f90->sourcefile~dtfft_errors.f90 sourcefile~dtfft_abstract_compressor.f90->sourcefile~dtfft_interface_nvtx.f90 
sourcefile~dtfft_abstract_compressor.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_abstract_compressor.f90->sourcefile~dtfft_utils.f90 sourcefile~dtfft_abstract_compressor.f90->sourcefile~dtfft_errors.f90 sourcefile~dtfft_interface_cuda_runtime.f90->sourcefile~dtfft_parameters.f90 sourcefile~dtfft_interface_cuda_runtime.f90->sourcefile~dtfft_utils.f90

Source Code

program test_host_kernels
!! Driver that runs the single-precision host kernels made available by
!! `dtfft_kernel_host` and checks each variant against a reference result.
!!
!! The buffers declared here are shared with the internal procedures through
!! host association: every `run_*` routine allocates them, fills `in`,
!! produces `gold`, and lets `compare` check `out` against `gold`.
use, intrinsic :: iso_fortran_env, only: int32, real32
use dtfft_kernel_host
implicit none
    integer(int32), parameter :: dims3(3) = [33, 77, 21]
        !! Three-dimensional test shape
    integer(int32), parameter :: dims2(2) = [90, 57]
        !! Two-dimensional test shape
    real(real32), allocatable :: in(:)
        !! Input data handed to the kernels
    real(real32), allocatable :: out(:)
        !! Output of the kernel under test
    real(real32), allocatable :: gold(:)
        !! Reference output
    real(real32), allocatable :: temp(:)
        !! Intermediate buffer used by the two-stage backward test
    integer(int32) :: i
        !! Loop counter shared with the internal procedures

    ! Forward permutation is exercised for both shapes.
    call run_permute_forward(dims3)
    call run_permute_forward(dims2)

    ! The backward permutation tests are only run with the 3-D shape.
    call run_permute_backward(dims3)
    call run_permute_backward_start(dims3)
    call run_permute_backward_end(dims3)

    ! Plain pack/unpack is exercised for both shapes.
    call run_pack_unpack(dims3)
    call run_pack_unpack(dims2)

contains

    subroutine run_permute_forward(dims)
    !! Checks every forward-permutation kernel variant against the output of
    !! `permute_forward_write_f32`, which serves as the reference (`gold`).
    !!
    !! Three groups are run in turn, each call being followed by `compare`:
    !! the `permute_forward_*` kernels, the `pack_forward_*` kernels (given the
    !! single descriptor column `locals(:, 1)`), and the `unpack_forward_*`
    !! kernels (given the permuted extents and the whole `locals` array).
    !!
    !! Uses the program-level buffers `in`, `out` and `gold`; they are
    !! allocated on entry and released before returning. Any mismatch
    !! terminates the program inside `compare`.
        integer(int32), intent(in) :: dims(:)
            !! Extents of the input array; two or three entries are accepted
        integer(int32) :: locals(5, 1)
            !! Descriptor passed to the pack/unpack kernels.
            !! NOTE(review): rows 1:3 look like extents and rows 4:5 like
            !! offsets -- confirm against `dtfft_kernel_host`.
        integer(int32) :: unpack_dims(3)
            !! Extents after the forward permutation, padded to three entries
        integer(int32) :: nd
            !! Number of dimensions
        integer(int32) :: n_elem
            !! Total number of elements
        integer(int32) :: k
            !! Local loop index

        nd = size(dims)
        ! dims(3) and locals(1:nd, 1) are used below, so reject other ranks.
        if ( nd /= 2 .and. nd /= 3 ) error stop "run_permute_forward: dims must have 2 or 3 entries"

        print*,'Testing permute_forward kernels: ndims = ',size(dims)

        n_elem = product(dims)
        allocate( in(n_elem), out(n_elem), gold(n_elem) )

        do k = 1, n_elem
            in(k) = real(k, real32)
        enddo

        call permute_forward_write_f32(in, gold, dims)

        call permute_forward_read_f32(in, out, dims)
        call compare("permute_forward_read_f32")
        call permute_forward_write_f32_block_4(in, out, dims)
        call compare("permute_forward_write_f32_block_4")
        call permute_forward_write_f32_block_8(in, out, dims)
        call compare("permute_forward_write_f32_block_8")
        call permute_forward_write_f32_block_16(in, out, dims)
        call compare("permute_forward_write_f32_block_16")
        call permute_forward_write_f32_block_32(in, out, dims)
        call compare("permute_forward_write_f32_block_32")
        call permute_forward_write_f32_block_64(in, out, dims)
        call compare("permute_forward_write_f32_block_64")
        call permute_forward_read_f32_block_4(in, out, dims)
        call compare("permute_forward_read_f32_block_4")
        call permute_forward_read_f32_block_8(in, out, dims)
        call compare("permute_forward_read_f32_block_8")
        call permute_forward_read_f32_block_16(in, out, dims)
        call compare("permute_forward_read_f32_block_16")
        call permute_forward_read_f32_block_32(in, out, dims)
        call compare("permute_forward_read_f32_block_32")
        call permute_forward_read_f32_block_64(in, out, dims)
        call compare("permute_forward_read_f32_block_64")

        ! Define every row before use. With two dimensions row 3 would
        ! otherwise be undefined; 1 matches the padding used for
        ! unpack_dims further down.
        locals(1:3, 1) = 1
        locals(1:nd, 1) = dims
        locals(4, 1) = 0
        locals(5, 1) = 0

        call pack_forward_write_f32(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32")
        call pack_forward_write_f32_block_4(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32_block_4")
        call pack_forward_write_f32_block_8(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32_block_8")
        call pack_forward_write_f32_block_16(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32_block_16")
        call pack_forward_write_f32_block_32(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32_block_32")
        call pack_forward_write_f32_block_64(in, out, dims, locals(:, 1))
        call compare("pack_forward_write_f32_block_64")
        call pack_forward_read_f32(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32")
        call pack_forward_read_f32_block_4(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32_block_4")
        call pack_forward_read_f32_block_8(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32_block_8")
        call pack_forward_read_f32_block_16(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32_block_16")
        call pack_forward_read_f32_block_32(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32_block_32")
        call pack_forward_read_f32_block_64(in, out, dims, locals(:, 1))
        call compare("pack_forward_read_f32_block_64")

        ! The unpack kernels are given the extents in permuted order.
        if ( nd == 2 ) then
            unpack_dims = [dims(2), dims(1), 1]
        else
            unpack_dims = [dims(2), dims(3), dims(1)]
        endif
        locals(1:3, 1) = unpack_dims

        ! One statement per line keeps these calls inside the 132-column
        ! limit of free-form source.
        call unpack_forward_write_f32(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32")
        call unpack_forward_write_f32_block_4(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32_block_4")
        call unpack_forward_write_f32_block_8(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32_block_8")
        call unpack_forward_write_f32_block_16(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32_block_16")
        call unpack_forward_write_f32_block_32(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32_block_32")
        call unpack_forward_write_f32_block_64(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_write_f32_block_64")
        call unpack_forward_read_f32(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32")
        call unpack_forward_read_f32_block_4(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32_block_4")
        call unpack_forward_read_f32_block_8(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32_block_8")
        call unpack_forward_read_f32_block_16(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32_block_16")
        call unpack_forward_read_f32_block_32(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32_block_32")
        call unpack_forward_read_f32_block_64(in, out, unpack_dims(1:nd), locals)
        call compare("unpack_forward_read_f32_block_64")

        ! Disabled timing experiment, kept for reference:
        !     print*,'Testing transpose_3d_cache_oblivious kernel'
        !     call cpu_time(start_time)
        !     call transpose_3d_cache_oblivious(in, out, dims)!;call compare
        !     call cpu_time(end_time)
        !     print*,'Time for transpose_3d_cache_oblivious: ', end_time - start_time
        deallocate(in, out, gold)

        print*,'SUCCESS'
    end subroutine run_permute_forward

    subroutine run_permute_backward(dims)
    !! Checks every backward-permutation kernel variant against the output of
    !! `permute_backward_write_f32`, which serves as the reference (`gold`).
    !!
    !! Three groups are run in turn, each call being followed by `compare`:
    !! the `permute_backward_*` kernels, the `pack_backward_*` kernels (given
    !! the single descriptor column `locals(:, 1)`), and the
    !! `unpack_backward_*` kernels (given the permuted extents and the whole
    !! `locals` array).
    !!
    !! Uses the program-level buffers `in`, `out` and `gold`; they are
    !! allocated on entry and released before returning. Any mismatch
    !! terminates the program inside `compare`.
        integer(int32), intent(in) :: dims(:)
            !! Extents of the input array; exactly three entries are required
        integer(int32) :: locals(5, 1)
            !! Descriptor passed to the pack/unpack kernels.
            !! NOTE(review): rows 1:3 look like extents and rows 4:5 like
            !! offsets -- confirm against `dtfft_kernel_host`.
        integer(int32) :: unpack_dims(3)
            !! Extents in the order handed to the unpack kernels
        integer(int32) :: n_elem
            !! Total number of elements
        integer(int32) :: k
            !! Local loop index

        ! dims(3) is read below, so a shorter array would be indexed out of bounds.
        if ( size(dims) /= 3 ) error stop "run_permute_backward: dims must have exactly 3 entries"

        print*,'Testing permute_backward kernels'

        n_elem = product(dims)
        allocate( in(n_elem), out(n_elem), gold(n_elem) )

        do k = 1, n_elem
            in(k) = real(k, real32)
        enddo

        call permute_backward_write_f32(in, gold, dims)

        call permute_backward_read_f32(in, out, dims);              call compare("permute_backward_read_f32")
        call permute_backward_write_f32_block_4(in, out, dims);     call compare("permute_backward_write_f32_block_4")
        call permute_backward_write_f32_block_8(in, out, dims);     call compare("permute_backward_write_f32_block_8")
        call permute_backward_write_f32_block_16(in, out, dims);    call compare("permute_backward_write_f32_block_16")
        call permute_backward_write_f32_block_32(in, out, dims);    call compare("permute_backward_write_f32_block_32")
        call permute_backward_write_f32_block_64(in, out, dims);    call compare("permute_backward_write_f32_block_64")
        call permute_backward_read_f32_block_4(in, out, dims);      call compare("permute_backward_read_f32_block_4")
        call permute_backward_read_f32_block_8(in, out, dims);      call compare("permute_backward_read_f32_block_8")
        call permute_backward_read_f32_block_16(in, out, dims);     call compare("permute_backward_read_f32_block_16")
        call permute_backward_read_f32_block_32(in, out, dims);     call compare("permute_backward_read_f32_block_32")
        call permute_backward_read_f32_block_64(in, out, dims);     call compare("permute_backward_read_f32_block_64")

        locals(1:3, 1) = dims
        locals(4, 1) = 0
        locals(5, 1) = 0

        call pack_backward_write_f32(in, out, dims, locals(:, 1));                call compare("pack_backward_write_f32")
        call pack_backward_write_f32_block_4(in, out, dims, locals(:, 1));        call compare("pack_backward_write_f32_block_4")
        call pack_backward_write_f32_block_8(in, out, dims, locals(:, 1));        call compare("pack_backward_write_f32_block_8")
        call pack_backward_write_f32_block_16(in, out, dims, locals(:, 1));       call compare("pack_backward_write_f32_block_16")
        call pack_backward_write_f32_block_32(in, out, dims, locals(:, 1));       call compare("pack_backward_write_f32_block_32")
        call pack_backward_write_f32_block_64(in, out, dims, locals(:, 1));       call compare("pack_backward_write_f32_block_64")
        call pack_backward_read_f32(in, out, dims, locals(:, 1));                 call compare("pack_backward_read_f32")
        call pack_backward_read_f32_block_4(in, out, dims, locals(:, 1));         call compare("pack_backward_read_f32_block_4")
        call pack_backward_read_f32_block_8(in, out, dims, locals(:, 1));         call compare("pack_backward_read_f32_block_8")
        call pack_backward_read_f32_block_16(in, out, dims, locals(:, 1));        call compare("pack_backward_read_f32_block_16")
        call pack_backward_read_f32_block_32(in, out, dims, locals(:, 1));        call compare("pack_backward_read_f32_block_32")
        call pack_backward_read_f32_block_64(in, out, dims, locals(:, 1));        call compare("pack_backward_read_f32_block_64")

        ! The unpack kernels are given the extents in permuted order.
        unpack_dims = [dims(3), dims(1), dims(2)]
        locals(1:3, 1) = unpack_dims

        call unpack_backward_write_f32(in, out, unpack_dims, locals);            call compare("unpack_backward_write_f32")
        call unpack_backward_write_f32_block_4(in, out, unpack_dims, locals);    call compare("unpack_backward_write_f32_block_4")
        call unpack_backward_write_f32_block_8(in, out, unpack_dims, locals);    call compare("unpack_backward_write_f32_block_8")
        call unpack_backward_write_f32_block_16(in, out, unpack_dims, locals);   call compare("unpack_backward_write_f32_block_16")
        call unpack_backward_write_f32_block_32(in, out, unpack_dims, locals);   call compare("unpack_backward_write_f32_block_32")
        call unpack_backward_write_f32_block_64(in, out, unpack_dims, locals);   call compare("unpack_backward_write_f32_block_64")
        call unpack_backward_read_f32(in, out, unpack_dims, locals);             call compare("unpack_backward_read_f32")
        call unpack_backward_read_f32_block_4(in, out, unpack_dims, locals);     call compare("unpack_backward_read_f32_block_4")
        call unpack_backward_read_f32_block_8(in, out, unpack_dims, locals);     call compare("unpack_backward_read_f32_block_8")
        call unpack_backward_read_f32_block_16(in, out, unpack_dims, locals);    call compare("unpack_backward_read_f32_block_16")
        call unpack_backward_read_f32_block_32(in, out, unpack_dims, locals);    call compare("unpack_backward_read_f32_block_32")
        call unpack_backward_read_f32_block_64(in, out, unpack_dims, locals);    call compare("unpack_backward_read_f32_block_64")

        deallocate(in, out, gold)

        print*,'SUCCESS'
    end subroutine run_permute_backward

    subroutine run_permute_backward_start(dims)
    !! Checks each `permute_backward_start_*` kernel variant against the
    !! output of `permute_backward_start_write_f32`, used as the reference.
    !!
    !! Uses the program-level buffers `in`, `out` and `gold`; they are
    !! allocated on entry and released before returning. Any mismatch
    !! terminates the program inside `compare`.
        integer(int32), intent(in) :: dims(:)
            !! Extents of the input array
        integer(int32) :: total
            !! Total number of elements
        integer(int32) :: idx
            !! Local loop index

        print*,'Testing permute_backward_start kernels'

        total = product(dims)
        allocate( in(total), out(total), gold(total) )

        ! Every element gets a distinct value so misplaced data is detectable.
        do idx = 1, total
            in(idx) = real(idx, real32)
        enddo

        ! Reference result
        call permute_backward_start_write_f32(in, gold, dims)

        ! Unblocked read variant
        call permute_backward_start_read_f32(in, out, dims)
        call compare("permute_backward_start_read_f32")

        ! Blocked write variants
        call permute_backward_start_write_f32_block_4(in, out, dims)
        call compare("permute_backward_start_write_f32_block_4")
        call permute_backward_start_write_f32_block_8(in, out, dims)
        call compare("permute_backward_start_write_f32_block_8")
        call permute_backward_start_write_f32_block_16(in, out, dims)
        call compare("permute_backward_start_write_f32_block_16")
        call permute_backward_start_write_f32_block_32(in, out, dims)
        call compare("permute_backward_start_write_f32_block_32")
        call permute_backward_start_write_f32_block_64(in, out, dims)
        call compare("permute_backward_start_write_f32_block_64")

        ! Blocked read variants
        call permute_backward_start_read_f32_block_4(in, out, dims)
        call compare("permute_backward_start_read_f32_block_4")
        call permute_backward_start_read_f32_block_8(in, out, dims)
        call compare("permute_backward_start_read_f32_block_8")
        call permute_backward_start_read_f32_block_16(in, out, dims)
        call compare("permute_backward_start_read_f32_block_16")
        call permute_backward_start_read_f32_block_32(in, out, dims)
        call compare("permute_backward_start_read_f32_block_32")
        call permute_backward_start_read_f32_block_64(in, out, dims)
        call compare("permute_backward_start_read_f32_block_64")

        deallocate(in, out, gold)

        print*,'SUCCESS'
    end subroutine run_permute_backward_start

    subroutine run_permute_backward_end(dims)
    !! Checks each `permute_backward_end_*` kernel variant.
    !!
    !! The reference (`gold`) is the output of `permute_backward_write_f32`
    !! applied to `in`. The kernels under test are instead fed `temp`, the
    !! output of `permute_backward_start_write_f32`, so the test asserts that
    !! running the "start" stage followed by the "end" stage reproduces the
    !! single-step backward permutation.
    !!
    !! Uses the program-level buffers `in`, `temp`, `out` and `gold`; they are
    !! allocated on entry and released before returning. Any mismatch
    !! terminates the program inside `compare`.
        integer(int32), intent(in) :: dims(:)
            !! Extents of the input array; exactly three entries are required
        integer(int32) :: temp_dims(3)
            !! Extents of `temp` as handed to the "end" kernels
        integer(int32) :: locals(5, 1)
            !! Descriptor passed to the "end" kernels.
            !! NOTE(review): rows 1:3 look like extents and rows 4:5 like
            !! offsets -- confirm against `dtfft_kernel_host`.
        integer(int32) :: n_elem
            !! Total number of elements
        integer(int32) :: k
            !! Local loop index

        ! dims(3) is read below, so a shorter array would be indexed out of bounds.
        if ( size(dims) /= 3 ) error stop "run_permute_backward_end: dims must have exactly 3 entries"

        print*,'Testing permute_backward_end kernels'

        n_elem = product(dims)
        allocate( in(n_elem), temp(n_elem), out(n_elem), gold(n_elem) )

        do k = 1, n_elem
            in(k) = real(k, real32)
        enddo
        gold(:) = -1._real32
        out(:) = -1._real32

        call permute_backward_write_f32(in, gold, dims)
        call permute_backward_start_write_f32(in, temp, dims)

        temp_dims = [dims(3), dims(1), dims(2)]

        locals(1:3, 1) = temp_dims
        locals(4, 1) = 0
        locals(5, 1) = 0

        ! One statement per line keeps these calls inside the 132-column
        ! limit of free-form source.
        call permute_backward_end_write_f32(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32")
        call permute_backward_end_write_f32_block_4(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32_block_4")
        call permute_backward_end_write_f32_block_8(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32_block_8")
        call permute_backward_end_write_f32_block_16(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32_block_16")
        call permute_backward_end_write_f32_block_32(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32_block_32")
        call permute_backward_end_write_f32_block_64(temp, out, temp_dims, locals)
        call compare("permute_backward_end_write_f32_block_64")
        call permute_backward_end_read_f32(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32")
        call permute_backward_end_read_f32_block_4(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32_block_4")
        call permute_backward_end_read_f32_block_8(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32_block_8")
        call permute_backward_end_read_f32_block_16(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32_block_16")
        call permute_backward_end_read_f32_block_32(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32_block_32")
        call permute_backward_end_read_f32_block_64(temp, out, temp_dims, locals)
        call compare("permute_backward_end_read_f32_block_64")
        deallocate(in, out, gold, temp)

        print*,'SUCCESS'
    end subroutine run_permute_backward_end

    subroutine run_pack_unpack(dims)
    !! Checks the `unpack_f32*` and `pack_f32*` kernel variants.
    !!
    !! `gold` is a copy of `in`, so with the single descriptor built here
    !! (extents equal to `dims`, rows 4 and 5 set to zero) every kernel is
    !! expected to reproduce its input unchanged.
    !!
    !! Uses the program-level buffers `in`, `out` and `gold`; they are
    !! allocated on entry and released before returning. Any mismatch
    !! terminates the program inside `compare`.
        integer(int32), intent(in) :: dims(:)
            !! Extents of the input array; at most three entries are accepted
        integer(int32) :: locals(5, 1)
            !! Descriptor passed to the kernels.
            !! NOTE(review): rows 1:3 look like extents and rows 4:5 like
            !! offsets -- confirm against `dtfft_kernel_host`.
        integer(int32) :: nd
            !! Number of dimensions
        integer(int32) :: n_elem
            !! Total number of elements
        integer(int32) :: k
            !! Local loop index

        nd = size(dims)
        ! Rows 4 and 5 of locals are reserved, so only rows 1:3 may hold extents.
        if ( nd < 1 .or. nd > 3 ) error stop "run_pack_unpack: dims must have between 1 and 3 entries"

        print*,'Testing pack/unpack kernels, ndims = ',size(dims)

        n_elem = product(dims)
        allocate( in(n_elem), out(n_elem), gold(n_elem) )

        do k = 1, n_elem
            in(k) = real(k, real32)
            gold(k) = in(k)
        enddo
        ! gold(:) = -1._real32
        out(:) = -1._real32

        ! Define every row before use; with fewer than three dimensions the
        ! remaining extent rows would otherwise be undefined.
        locals(1:3, 1) = 1
        locals(1:nd, 1) = dims
        locals(4, 1) = 0
        locals(5, 1) = 0

        call unpack_f32(in, out, dims, locals); call compare("unpack_f32")
        call unpack_f32_block_4(in, out, dims, locals); call compare("unpack_f32_block_4")
        call unpack_f32_block_8(in, out, dims, locals); call compare("unpack_f32_block_8")
        call unpack_f32_block_16(in, out, dims, locals); call compare("unpack_f32_block_16")
        call unpack_f32_block_32(in, out, dims, locals); call compare("unpack_f32_block_32")
        call unpack_f32_block_64(in, out, dims, locals); call compare("unpack_f32_block_64")

        call pack_f32(in, out, dims, locals); call compare("pack_f32")
        call pack_f32_block_4(in, out, dims, locals); call compare("pack_f32_block_4")
        call pack_f32_block_8(in, out, dims, locals); call compare("pack_f32_block_8")
        call pack_f32_block_16(in, out, dims, locals); call compare("pack_f32_block_16")
        call pack_f32_block_32(in, out, dims, locals); call compare("pack_f32_block_32")
        call pack_f32_block_64(in, out, dims, locals); call compare("pack_f32_block_64")

        deallocate(in, out, gold)

        print*,'SUCCESS'
    end subroutine run_pack_unpack


    subroutine compare(test_name)
    !! Compares the program-level `out` buffer with `gold` element by element.
    !!
    !! On the first element that is not within `tolerance` of the reference,
    !! the index and both values are printed and the program is terminated
    !! with `error stop`. When every element matches, `out` is refilled with
    !! -1 so that a later kernel which fails to write an element cannot pass
    !! on stale data.
        character(len=*), intent(in) :: test_name
            !! Name of the kernel being checked; appended to the failure message
        real(real32), parameter :: tolerance = 1.e-6_real32
            !! Largest accepted absolute difference
        integer(int32) :: k
            !! Local loop index (the host program's counter is left untouched)

        if ( size(out) /= size(gold) ) error stop "Test failed (size mismatch): "//test_name

        do k = 1, size(gold)
            ! Written as .not.(<=) so that a NaN difference counts as a failure;
            ! `abs(diff) > tolerance` would be false for NaN and let it pass.
            if ( .not. ( abs(out(k) - gold(k)) <= tolerance ) ) then
                print*, k, out(k), gold(k)
                error stop "Test failed: "//test_name
            endif
        enddo
        out(:) = -1._real32
    end subroutine compare


!    recursive subroutine transpose_3d_recursive(in, out, &
!                                               in_start, in_dims, in_strides, &
!                                               out_start, out_dims, out_strides, &
!                                               base_case_size)
!         real(real32), intent(in) :: in(*)
!         real(real32), intent(out) :: out(*)
!         integer(int32), intent(in) :: in_start(3), in_dims(3), in_strides(3)
!         integer(int32), intent(in) :: out_start(3), out_dims(3), out_strides(3)
!         integer(int32), intent(in) :: base_case_size

!         integer(int32) :: new_start(3), new_dims(3)
!         integer(int32) :: split_dim


!         if (all(in_dims < base_case_size)) then
!             call transpose_3d_base_case(in, out, &
!                                       in_start, in_dims, in_strides, &
!                                       out_start, out_dims, out_strides)
!             return
!         end if

!         split_dim = maxloc(in_dims, dim=1)


!         new_dims = in_dims
!         new_dims(split_dim) = in_dims(split_dim) / 2

!         new_start = in_start
!         call transpose_3d_recursive(in, out, &
!                                   new_start, new_dims, in_strides, &
!                                   out_start, out_dims, out_strides, &
!                                   base_case_size)

!         new_start(split_dim) = in_start(split_dim) + new_dims(split_dim)
!         new_dims(split_dim) = in_dims(split_dim) - new_dims(split_dim)
!         call transpose_3d_recursive(in, out, &
!                                   new_start, new_dims, in_strides, &
!                                   out_start, out_dims, out_strides, &
!                                   base_case_size)
!     end subroutine transpose_3d_recursive

!     subroutine transpose_3d_base_case(in, out, &
!                                     in_start, in_dims, in_strides, &
!                                     out_start, out_dims, out_strides)
!         real(real32), intent(in) :: in(*)
!         real(real32), intent(out) :: out(*)
!         integer(int32), intent(in) :: in_start(3), in_dims(3), in_strides(3)
!         integer(int32), intent(in) :: out_start(3), out_dims(3), out_strides(3)

!         integer(int32) :: x, y, z
!         integer(int32) :: in_idx, out_idx
!         integer(int32) :: nx, ny, nz

!         nx = in_dims(1)
!         ny = in_dims(2)
!         nz = in_dims(3)

!         do x = in_start(1), in_start(1) + nx - 1
!             do z = in_start(3), in_start(3) + nz - 1
!                 do y = in_start(2), in_start(2) + ny - 1
!                     in_idx = x * in_strides(1) + &
!                              y * in_strides(2) + &
!                              z * in_strides(3) + 1

!                     out_idx = y * out_strides(1) + &
!                               z * out_strides(2) + &
!                               x * out_strides(3) + 1

!                     out(out_idx) = in(in_idx)
!                 end do
!             end do
!         end do
!     end subroutine transpose_3d_base_case

    ! subroutine transpose_3d_cache_oblivious(in, out, dims)
    !     real(real32), intent(in) :: in(*)
    !     real(real32), intent(out) :: out(*)
    !     integer(int32), intent(in) :: dims(3)

    !     integer(int32) :: in_start(3), in_dims(3), in_strides(3)
    !     integer(int32) :: out_start(3), out_dims(3), out_strides(3)
    !     integer(int32) :: base_case_size

    !     in_start = [0, 0, 0]
    !     in_dims = dims
    !     in_strides = [1, dims(1), dims(1) * dims(2)]

    !     out_start = [0, 0, 0]
    !     out_dims = [dims(2), dims(3), dims(1)]  ! (y, z, x)
    !     out_strides = [1, dims(2), dims(2) * dims(3)]

    !     base_case_size = 32

    !     call transpose_3d_recursive(in, out, &
    !                               in_start, in_dims, in_strides, &
    !                               out_start, out_dims, out_strides, &
    !                               base_case_size)
    ! end subroutine transpose_3d_cache_oblivious

end program test_host_kernels