!------------------------------------------------------------------------------------------------
! Copyright (c) 2021 - 2025, Oleg Shatrov
! All rights reserved.
! This file is part of dtFFT library.

! dtFFT is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! dtFFT is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program. If not, see <https://www.gnu.org/licenses/>.
!------------------------------------------------------------------------------------------------
module dtfft_interface_nvrtc
!! nvRTC Interfaces.
!!
!! nvRTC is loaded at runtime via dynamic loading due to explicit cuda_driver linking by cmake.
!!
!! All public procedure pointers of this module are disassociated until [[load_nvrtc]]
!! has returned ``DTFFT_SUCCESS``.
use iso_c_binding
use iso_fortran_env, only: int32
use dtfft_errors,    only: DTFFT_SUCCESS
use dtfft_utils,     only: dynamic_load, string, destroy_strings, string_c2f
implicit none
private
#include "_dtfft_private.h"
public :: nvrtcGetErrorString
public :: load_nvrtc
public :: nvrtcProgram

  type, bind(C) :: nvrtcProgram
  !! nvrtcProgram is the unit of compilation, and an opaque handle for a program.
    type(c_ptr) :: cptr
      !! Actual pointer
  end type nvrtcProgram

  ! NOTE: every abstract interface below carries bind(C). The pointers declared with these
  ! interfaces are associated through c_f_procpointer, which requires an interoperable
  ! interface; bind(C) also guarantees that no hidden character length arguments are passed.

  abstract interface
    function nvrtcGetErrorString_interface(error_code) &
      result(string) bind(C)
    !! Helper function that returns a string describing the given nvrtcResult code
    !! For unrecognized enumeration values, it returns "NVRTC_ERROR unknown"
    import
      integer(c_int), value :: error_code
        !! CUDA Runtime Compilation API result code.
      type(c_ptr)           :: string
        !! Pointer to C string
    end function nvrtcGetErrorString_interface
  end interface

  abstract interface
    function nvrtcCreateProgram_interface(prog, src, name, numHeaders, headers, includeNames) &
      result(nvrtcResult) bind(C)
    !! Creates an instance of nvrtcProgram with the given input parameters,
    !! and sets the output parameter prog with it.
    import
      type(nvrtcProgram)    :: prog
        !! CUDA Runtime Compilation program.
      character(c_char)     :: src(*)
        !! CUDA program source.
      character(c_char)     :: name(*)
        !! CUDA program name.
      integer(c_int), value :: numHeaders
        !! Number of headers used. Must be greater than or equal to 0.
      type(c_ptr),    value :: headers
        !! Sources of the headers
      type(c_ptr),    value :: includeNames
        !! Name of each header by which they can be included in the CUDA program source
      integer(c_int)        :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcCreateProgram_interface
  end interface

  abstract interface
    function nvrtcDestroyProgram_interface(prog) &
      result(nvrtcResult) bind(C)
    !! Destroys the given program.
    import
      type(nvrtcProgram) :: prog
        !! CUDA Runtime Compilation program.
      integer(c_int)     :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcDestroyProgram_interface
  end interface

  abstract interface
    function nvrtcCompileProgram_interface(prog, numOptions, options) &
      result(nvrtcResult) bind(C)
    !! Compiles the given program.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      integer(c_int),     value :: numOptions
        !! Number of compiler options passed.
      type(c_ptr)               :: options(*)
        !! Compiler options in the form of C string array
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcCompileProgram_interface
  end interface

  abstract interface
    function nvrtcGetProgramLogSize_interface(prog, logSizeRet) &
      result(nvrtcResult) bind(C)
    !! Sets the value of ``logSizeRet`` with the size of the log generated by the previous compilation of ``prog``.
    !! The log is a null-terminated string.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      integer(c_size_t)         :: logSizeRet
        !! Size of the compilation log.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcGetProgramLogSize_interface
  end interface

  abstract interface
    function nvrtcGetProgramLog_interface(prog, log) &
      result(nvrtcResult) bind(C)
    !! Stores the log generated by the previous compilation of prog in the memory pointed by log
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      type(c_ptr),        value :: log
        !! Compilation log.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcGetProgramLog_interface
  end interface

  abstract interface
    function nvrtcGetCUBINSize_interface(prog, cubinSizeRet) &
      result(nvrtcResult) bind(C)
    !! Sets the value of ``cubinSizeRet`` with the size of the cubin generated by the previous compilation of ``prog``.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      integer(c_size_t)         :: cubinSizeRet
        !! Size of the generated cubin.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcGetCUBINSize_interface
  end interface

  abstract interface
    function nvrtcGetCUBIN_interface(prog, cubin) &
      result(nvrtcResult) bind(C)
    !! Stores the cubin generated by the previous compilation of ``prog`` in the memory pointed by ``cubin``.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      type(c_ptr),        value :: cubin
        !! Compiled and assembled result.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcGetCUBIN_interface
  end interface

  abstract interface
    function nvrtcGetLoweredName_interface(prog, name_expression, lowered_name) &
      result(nvrtcResult) bind(C)
    !! Extracts the lowered (mangled) name for a global function or device/__constant__ variable,
    !! and updates *lowered_name to point to it.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      character(c_char)         :: name_expression(*)
        !! Name expression.
      type(c_ptr)               :: lowered_name
        !! Mangled name.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcGetLoweredName_interface
  end interface

  abstract interface
    function nvrtcAddNameExpression_interface(prog, name_expression) &
      result(nvrtcResult) bind(C)
    !! Notes the given name expression denoting the address of a global function or device/__constant__ variable.
    import
      type(nvrtcProgram), value :: prog
        !! CUDA Runtime Compilation program.
      character(c_char)         :: name_expression(*)
        !! Name expression.
      integer(c_int)            :: nvrtcResult
        !! The enumerated type nvrtcResult defines API call result codes.
    end function nvrtcAddNameExpression_interface
  end interface

  integer(int32), parameter :: N_FUNCTIONS_TO_LOAD = 10
    !! Number of functions to load from nvrtc library
  logical,        save      :: is_loaded = .false.
    !! Flag indicating whether the library is loaded
  type(c_ptr),    save      :: libnvrtc = c_null_ptr
    !! Handle to the loaded library
  type(c_funptr), save      :: nvrtcFunctions(N_FUNCTIONS_TO_LOAD)
    !! Array of pointers to the nvRTC functions

  ! Every procedure pointer starts out disassociated, so that its association status is
  ! well defined (and can be tested with associated()) before load_nvrtc has succeeded.
  procedure(nvrtcGetErrorString_interface),    pointer         :: nvrtcGetErrorString_c  => null()
    !! Fortran pointer to the nvrtcGetErrorString function
  procedure(nvrtcCreateProgram_interface),     pointer, public :: nvrtcCreateProgram     => null()
    !! Fortran pointer to the nvrtcCreateProgram function
  procedure(nvrtcDestroyProgram_interface),    pointer, public :: nvrtcDestroyProgram    => null()
    !! Fortran pointer to the nvrtcDestroyProgram function
  procedure(nvrtcCompileProgram_interface),    pointer, public :: nvrtcCompileProgram    => null()
    !! Fortran pointer to the nvrtcCompileProgram function
  procedure(nvrtcGetProgramLogSize_interface), pointer, public :: nvrtcGetProgramLogSize => null()
    !! Fortran pointer to the nvrtcGetProgramLogSize function
  procedure(nvrtcGetProgramLog_interface),     pointer, public :: nvrtcGetProgramLog     => null()
    !! Fortran pointer to the nvrtcGetProgramLog function
  procedure(nvrtcGetCUBINSize_interface),      pointer, public :: nvrtcGetCUBINSize      => null()
    !! Fortran pointer to the nvrtcGetCUBINSize function
  procedure(nvrtcGetCUBIN_interface),          pointer, public :: nvrtcGetCUBIN          => null()
    !! Fortran pointer to the nvrtcGetCUBIN function
  procedure(nvrtcGetLoweredName_interface),    pointer, public :: nvrtcGetLoweredName    => null()
    !! Fortran pointer to the nvrtcGetLoweredName function
  procedure(nvrtcAddNameExpression_interface), pointer, public :: nvrtcAddNameExpression => null()
    !! Fortran pointer to the nvrtcAddNameExpression function

contains

  function nvrtcGetErrorString(error_code) result(string)
  !! Helper function that returns a string describing the given nvrtcResult code
  !! For unrecognized enumeration values, it returns "NVRTC_ERROR unknown"
  !!
  !! If the nvRTC library has not been loaded yet, a fixed diagnostic message is returned
  !! instead of calling through a disassociated procedure pointer.
    integer(c_int), intent(in)    :: error_code
      !! CUDA Runtime Compilation API result code.
    character(len=:), allocatable :: string
      !! Result string
    type(c_ptr)                   :: c_string
      !! Pointer to C string

    ! Guard: the pointer is only associated after a successful load_nvrtc call
    if ( .not. associated(nvrtcGetErrorString_c) ) then
      string = "nvRTC library is not loaded"
      return
    endif

    c_string = nvrtcGetErrorString_c(error_code)
    call string_c2f(c_string, string)
  end function nvrtcGetErrorString

  function load_nvrtc() result(error_code)
  !! Dynamically loads nvRTC library and its functions
  !!
  !! The call is a no-op returning ``DTFFT_SUCCESS`` when the library is already loaded.
  !! When ``dynamic_load`` reports an error, that error code is returned and the
  !! procedure pointers of this module are left untouched.
    integer(int32)            :: error_code
      !! Error code
    type(string), allocatable :: func_names(:)
      !! Array of function names to load

    error_code = DTFFT_SUCCESS
    if ( is_loaded ) return

    ! The position of each name must match the index used in the
    ! c_f_procpointer calls below.
    allocate(func_names(N_FUNCTIONS_TO_LOAD))
    func_names(1)  = string("nvrtcGetErrorString")
    func_names(2)  = string("nvrtcCreateProgram")
    func_names(3)  = string("nvrtcDestroyProgram")
    func_names(4)  = string("nvrtcCompileProgram")
    func_names(5)  = string("nvrtcGetProgramLog")
    func_names(6)  = string("nvrtcGetCUBINSize")
    func_names(7)  = string("nvrtcGetCUBIN")
    func_names(8)  = string("nvrtcGetProgramLogSize")
    func_names(9)  = string("nvrtcGetLoweredName")
    func_names(10) = string("nvrtcAddNameExpression")

    error_code = dynamic_load("libnvrtc.so", func_names, libnvrtc, nvrtcFunctions)
    call destroy_strings(func_names)
    if ( error_code /= DTFFT_SUCCESS ) return

    call c_f_procpointer(nvrtcFunctions(1),  nvrtcGetErrorString_c)
    call c_f_procpointer(nvrtcFunctions(2),  nvrtcCreateProgram)
    call c_f_procpointer(nvrtcFunctions(3),  nvrtcDestroyProgram)
    call c_f_procpointer(nvrtcFunctions(4),  nvrtcCompileProgram)
    call c_f_procpointer(nvrtcFunctions(5),  nvrtcGetProgramLog)
    call c_f_procpointer(nvrtcFunctions(6),  nvrtcGetCUBINSize)
    call c_f_procpointer(nvrtcFunctions(7),  nvrtcGetCUBIN)
    call c_f_procpointer(nvrtcFunctions(8),  nvrtcGetProgramLogSize)
    call c_f_procpointer(nvrtcFunctions(9),  nvrtcGetLoweredName)
    call c_f_procpointer(nvrtcFunctions(10), nvrtcAddNameExpression)

    is_loaded = .true.
  end function load_nvrtc
end module dtfft_interface_nvrtc