Skip to content

How to properly make batch inference? #343

@IsidreMas

Description

@IsidreMas

Hello FTorch developers,

I'm a new user of the library and am trying to integrate it into an atmospheric chemistry emulator that interfaces with ML models. So far I have managed to reproduce the inference results when running the model one sample at a time, but when I attempt batch inference the results are numerically incorrect. Do you know what I could be doing wrong?

Here's the example code I'm using to load a dataset from a NetCDF file and write the model's predictions to a file:

batchintest.F90
program read_toy
    ! Read gas-species concentrations from toy.nc, run ONE batched forward pass
    ! of a TorchScript model through FTorch, and write one line of predictions
    ! per sample to predictions.txt.
    !
    ! Usage: <executable> <device> <model_file>
    !   device     : one of cpu | cuda | xpu | mps
    !   model_file : TorchScript file passed to torch_model_load
    use, intrinsic :: iso_fortran_env, only: sp => real32, int64, dp => real64, &
         error_unit
    use netcdf
    ! IMPORT TORCH INTERFACE MODULES
    use ftorch, only: torch_model, torch_tensor, &
         torch_kCPU, torch_kCUDA, torch_kXPU, torch_kMPS, &
         torch_tensor_from_array, torch_model_load, torch_model_forward, &
         torch_delete
    implicit none

    ! SET PRECISION
    integer, parameter :: wp = sp

    ! Number of leading samples echoed to stdout as a sanity check.
    integer, parameter :: n_preview = 5

    integer :: ncid, status, varid, i, num_samples, ndims
    integer :: dimids(3), dim_lens(3)
    real(wp), allocatable :: gas_conc(:,:,:) ! (species, experiment, time) as read by Fortran

    ! COMMAND-LINE ARGUMENTS
    integer :: num_args, ix
    character(len=128), dimension(:), allocatable :: args

    ! USER CONFIGURATION: MODEL INPUT/OUTPUT WIDTHS.
    integer, parameter :: n_features = 3   ! must equal the species dimension of the file
    integer, parameter :: n_outputs  = 3

    ! BATCHED INPUT/OUTPUT: row i holds sample i (shape = num_samples x width).
    real(wp), dimension(:,:), target, allocatable :: in_data
    real(wp), dimension(:,:), target, allocatable :: out_data

    ! Layout [1,2] keeps the Fortran index order, so the tensor shape is
    ! (num_samples, n_features) with element [i-1, j-1] <-> array(i, j)
    ! (per FTorch's tensor-layout documentation). Works for any batch size.
    integer, parameter :: tensor_layout(2) = [1, 2]

    ! TORCH OBJECTS
    type(torch_model) :: model
    type(torch_tensor), dimension(:), allocatable :: in_tensors
    type(torch_tensor), dimension(:), allocatable :: out_tensors

    ! DEVICE SELECTION
    integer :: device_type

    ! OUTPUT FILE
    integer :: outfile_unit, ios
    character(len=256) :: io_msg

    ! TIMING VARIABLES
    integer(int64) :: start_time, end_time, count_rate
    real(dp) :: elapsed_time, total_time

    total_time = 0.0_dp

    ! PARSE COMMAND-LINE ARGUMENTS (device and model file are both required)
    num_args = command_argument_count()
    if (num_args < 2) then
        write (error_unit, '(a)') "Usage: <executable> <cpu|cuda|xpu|mps> <model_file>"
        error stop 2
    end if
    allocate(args(num_args))
    do ix = 1, num_args
        call get_command_argument(ix, args(ix))
    end do

    ! DEVICE SELECTION BASED ON FIRST ARGUMENT
    select case (trim(args(1)))
    case ("cuda")
        device_type = torch_kCUDA
    case ("xpu")
        device_type = torch_kXPU
    case ("mps")
        device_type = torch_kMPS
    case ("cpu")
        device_type = torch_kCPU
    case default
        write (*,*) "Error: invalid device type", trim(args(1))
        stop 999
    end select

    ! Open NetCDF file
    status = nf90_open("toy.nc", NF90_NOWRITE, ncid)
    call handle_err(status)

    status = nf90_inq_varid(ncid, "gas_species_concentrations", varid)
    call handle_err(status)

    ! Size the buffer from the file's own dimensions instead of hard-coding them.
    status = nf90_inquire_variable(ncid, varid, ndims=ndims)
    call handle_err(status)
    if (ndims /= 3) then
        write (error_unit, '(a, i0)') "Error: expected a 3D variable, got ndims = ", ndims
        error stop 3
    end if
    status = nf90_inquire_variable(ncid, varid, dimids=dimids)
    call handle_err(status)
    do ix = 1, 3
        status = nf90_inquire_dimension(ncid, dimids(ix), len=dim_lens(ix))
        call handle_err(status)
    end do
    if (dim_lens(1) /= n_features) then
        write (error_unit, '(a, i0, a, i0)') "Error: species dimension is ", dim_lens(1), &
             " but n_features = ", n_features
        error stop 3
    end if
    if (any(dim_lens(2:3) < 1)) then
        write (error_unit, '(a)') "Error: experiment/time dimensions must be non-empty"
        error stop 3
    end if

    allocate(gas_conc(dim_lens(1), dim_lens(2), dim_lens(3)))
    status = nf90_get_var(ncid, varid, gas_conc)
    call handle_err(status)

    ! Everything needed has been read; release the file now.
    status = nf90_close(ncid)
    call handle_err(status)

    print *, "First gas concentration:", gas_conc(:,1,1)

    num_samples = size(gas_conc, 2)

    allocate(in_tensors(1))
    allocate(in_data(num_samples, n_features))
    allocate(out_tensors(1))
    allocate(out_data(num_samples, n_outputs))
    out_data = 0.0_wp

    ! LOAD THE TORCHSCRIPT MODEL FROM THE FILE (SECOND ARGUMENT)
    call torch_model_load(model, trim(args(2)), device_type)

    ! Samples become rows: in_data(i, :) holds the n_features values of sample i
    ! at the first time index.
    in_data(:, :) = transpose(gas_conc(:, :, 1))

    print *, "First few rows of gas_conc:"
    do i = 1, min(n_preview, num_samples)
        print *, gas_conc(:, i, 1)
    end do

    ! Get the system clock rate
    call system_clock(count_rate=count_rate)

    print *, "First few rows of input data:"
    do i = 1, min(n_preview, num_samples)
        print *, in_data(i, :)
    end do

    ! CREATE THE TORCH INPUT TENSOR ON THE SELECTED DEVICE
    call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, device_type)

    ! CREATE THE TORCH OUTPUT TENSOR ON THE CPU (backed by out_data)
    call torch_tensor_from_array(out_tensors(1), out_data, tensor_layout, torch_kCPU)

    print *, "First input tensor:", in_tensors(1)%get_shape()
    print *, "First output tensor:", out_tensors(1)%get_shape()
    print *, "in_data shape:", shape(in_data)
    print *, "out_data shape:", shape(out_data)

    ! TIME THE FORWARD PASS
    call system_clock(start_time)
    call torch_model_forward(model, in_tensors, out_tensors)
    call system_clock(end_time)

    ! Calculate elapsed time in seconds
    elapsed_time = real(end_time - start_time, dp) / real(count_rate, dp)
    total_time = total_time + elapsed_time

    ! Print total prediction time
    print '(a, f12.6, a)', 'Total prediction time: ', total_time, ' seconds'

    ! Open predictions.txt file for writing (overwrite if exists)
    open(newunit=outfile_unit, file="predictions.txt", action="write", status="replace", &
         iostat=ios, iomsg=io_msg)
    if (ios /= 0) then
        write (error_unit, '(2a)') "Error opening predictions.txt: ", trim(io_msg)
        error stop 4
    end if

    ! WRITE ONE SAMPLE PER LINE. A bare `write(...) out_data` walks the array in
    ! column-major order (output 1 of every sample, then output 2, ...), which
    ! scrambles batched results; with batch size 1 the two orders coincide.
    do i = 1, num_samples
        write(outfile_unit, '(*(F20.10))') out_data(i, :)
    end do

    ! Close the output file
    close(outfile_unit)

    ! CLEANUP: DELETE THE MODEL AND TENSORS before freeing the arrays they wrap
    call torch_delete(model)
    call torch_delete(in_tensors)
    call torch_delete(out_tensors)

    deallocate(gas_conc, in_data, out_data, args)

contains
    ! Abort with a non-zero exit status and the NetCDF error text
    ! if status is not nf90_noerr.
    subroutine handle_err(status)
        integer, intent(in) :: status
        if (status /= nf90_noerr) then
            write (error_unit, '(2a)') "Error: ", trim(nf90_strerror(status))
            error stop 1
        end if
    end subroutine handle_err
end program read_toy

Thank you for any hint or advice that you could provide on this issue.

Metadata

Metadata

Assignees

No one assigned

    Labels

    question — Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions