!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
!! $Id: states.F90 10978 2013-07-11 15:28:46Z micael $

#include "global.h"

module states_m
  use blacs_proc_grid_m
  use calc_mode_m
#ifdef HAVE_OPENCL
  use cl
#endif
  use comm_m
  use batch_m
  use batch_ops_m
  use blas_m
  use datasets_m
  use derivatives_m
  use distributed_m
  use geometry_m
  use global_m
  use grid_m
  use hardware_m
  use io_m
  use kpoints_m
  use lalg_adv_m
  use lalg_basic_m
  use loct_pointer_m
  use math_m
  use mesh_m
  use mesh_function_m
  use messages_m
  use modelmb_particles_m
  use mpi_m ! if not before parser_m, ifort 11.072 can`t compile with MPI2
  use mpi_lib_m
  use multicomm_m
#ifdef HAVE_OPENMP
  use omp_lib
#endif
  use opencl_m
  use parser_m
  use profiling_m
  use simul_box_m
  use smear_m
  use states_dim_m
  use symmetrizer_m
  use types_m
  use unit_m
  use unit_system_m
  use utils_m
  use varinfo_m

  implicit none

  private

  public ::                           &
    states_t,                         &
    states_priv_t,                    &
    states_lead_t,                    &
    states_init,                      &
    states_look,                      &
    states_densities_init,            &
    states_exec_init,                 &
    states_allocate_wfns,             &
    states_allocate_intf_wfns,        &
    states_deallocate_wfns,           &
    states_null,                      &
    states_end,                       &
    states_copy,                      &
    states_generate_random,           &
    states_fermi,                     &
    states_eigenvalues_sum,           &
    states_lead_densities_init,       &
    states_lead_densities_end,        &
    states_spin_channel,              &
    states_calc_quantities,           &
    state_is_local,                   &
    states_distribute_nodes,          &
    states_wfns_memory,               &
    states_are_complex,               &
    states_are_real,                  &
    states_set_complex,               &
    states_blacs_blocksize,           &
    states_get_state,                 &
    states_set_state,                 &
    states_pack,                      &
    states_unpack,                    &
    states_sync,                      &
    states_are_packed,                &
    states_write_info,                &
    states_set_zero,                  &
    states_block_min,                 &
    states_block_max,                 &
    states_block_size,                &
    states_resize_unocc,              &
    zstates_eigenvalues_sum,          &
    cmplx_array2_t,                   &
    states_wfs_t

  !> cmplxscl: Left and Right eigenstates.
  !!
  !! Every array is dimensioned (np, st%d%dim, st%nst, st%d%nik).
  !! The pointers are default-initialized to null(), so their association
  !! status is defined as soon as an object of this type comes into
  !! existence and associated() may be queried safely.
  type states_wfs_t
    CMPLX, pointer     :: zL(:, :, :, :) => null() !< complex left states  (np, st%d%dim, st%nst, st%d%nik)
    CMPLX, pointer     :: zR(:, :, :, :) => null() !< complex right states (np, st%d%dim, st%nst, st%d%nik)
    FLOAT, pointer     :: dL(:, :, :, :) => null() !< real left states     (np, st%d%dim, st%nst, st%d%nik)
    FLOAT, pointer     :: dR(:, :, :, :) => null() !< real right states    (np, st%d%dim, st%nst, st%d%nik)
  end type states_wfs_t
  
  !> cmplxscl: complex 2D matrices, stored as two separate real arrays.
  !!
  !! Both pointers are default-initialized to null(), so their association
  !! status is always defined and associated() may be queried safely.
  type cmplx_array2_t
    FLOAT, pointer     :: Re(:, :) => null() !< Real components
    FLOAT, pointer     :: Im(:, :) => null() !< Imaginary components
  end type cmplx_array2_t

  !> Per-lead data used in open-boundary calculations.
  !!
  !! All pointers are default-initialized to null(), so their association
  !! status is always defined and associated() may be queried safely.
  type states_lead_t
    !> (np, st%d%dim, st%nst, st%d%nik)
    CMPLX, pointer     :: intf_psi(:, :, :, :) => null()
    !> Density of the lead unit cells.
    FLOAT, pointer     :: rho(:, :) => null()
    !> (np, np, nspin, ncs, nik) self-energy of the leads.
    CMPLX, pointer     :: self_energy(:, :, :, :, :) => null()
  end type states_lead_t

  !> Part of states_t whose components are not accessible outside this module.
  type :: states_priv_t
    private
    !> kind of wavefunctions in use: real (TYPE_FLOAT) or complex (TYPE_CMPLX)
    type(type_t) :: wfs_type
  end type states_priv_t

  !> Aggregate describing the electronic states: wavefunctions, densities,
  !! eigenvalues, occupations, open-boundary data and the bookkeeping used
  !! for blocking and for the parallelization in states.
  !!
  !! All pointer components are default-initialized to null(), so their
  !! association status is defined from the moment an object is declared.
  !! states_null() can still be used to reset an object explicitly.
  type states_t
    type(states_dim_t)       :: d
    type(modelmb_particle_t) :: modelmbparticles
    type(states_priv_t)      :: priv                  !< the private components
    integer                  :: nst                   !< Number of states in each irreducible subspace

    logical                  :: only_userdef_istates  !< only use user-defined states as initial states in propagation
    !> pointers to the wavefunctions
    FLOAT, pointer           :: dpsi(:,:,:,:) => null() !< dpsi(sys%gr%mesh%np_part, st%d%dim, st%nst, st%d%nik)
    CMPLX, pointer           :: zpsi(:,:,:,:) => null() !< zpsi(sys%gr%mesh%np_part, st%d%dim, st%nst, st%d%nik)

    !> Pointers to complexified quantities.
    !! When we use complex scaling the Hamiltonian is no longer hermitian.
    !! In this case we have to distinguish between left and right eigenstates of H and
    !! both density and eigenvalues become complex.
    !! In order to modify the code to include these changes we allocate the general structures and
    !! make the restricted quantities point to a part of the structure.
    !! For instance for the orbitals we allocate psi and make zpsi point only to the Right states as follows:
    !! zpsi => psi%zR
    !! Similarly for density and eigenvalues we make the old quantities point to the real part:
    !! rho => zrho%Re
    !! eigenval => zeigenval%Re
    type(states_wfs_t)       :: psi          !< cmplxscl: Left psi%zL(:,:,:,:) and Right psi%zR(:,:,:,:) orbitals
    type(cmplx_array2_t)     :: zrho         !< cmplxscl: the complexified density <psi%zL(:,:,:,:)|psi%zR(:,:,:,:)>
    type(cmplx_array2_t)     :: zeigenval    !< cmplxscl: the complexified eigenvalues
    FLOAT,           pointer :: Imrho_core(:)      => null() !< presumably the imaginary part of rho_core (TODO confirm)
    FLOAT,           pointer :: Imfrozen_rho(:, :) => null() !< presumably the imaginary part of frozen_rho (TODO confirm)


    type(batch_t), pointer   :: psib(:, :) => null()  !< A set of wave-functions blocks
    integer                  :: nblocks               !< The number of blocks
    integer                  :: block_start           !< The lowest index of local blocks
    integer                  :: block_end             !< The highest index of local blocks
    !> A map, that for each state index, returns the index of block containing it
    integer, pointer         :: iblock(:, :) => null()
    !> Each block contains states from block_range(:, 1) to block_range(:, 2)
    integer, pointer         :: block_range(:, :) => null()
    !> The number of states in each block.
    integer, pointer         :: block_size(:) => null()
    !> It is true if the block is in this node.
    logical, pointer         :: block_is_local(:, :) => null()
    logical                  :: block_initialized     !< For keeping track of the blocks to avoid memory leaks

    logical             :: open_boundaries
    CMPLX, pointer      :: zphi(:, :, :, :) => null()  !< Free states for open-boundary calculations.
    FLOAT, pointer      :: ob_eigenval(:, :) => null() !< Eigenvalues of free states.
    type(states_dim_t)  :: ob_d              !< Dims. of the unscattered systems.
    integer             :: ob_nst            !< nst of the unscattered systems.
    FLOAT, pointer      :: ob_occ(:, :) => null()      !< occupations
    type(states_lead_t) :: ob_lead(2*MAX_DIM)

    !> used for the user-defined wavefunctions (they are stored as formula strings)
    !! (st%d%dim, st%nst, st%d%nik)
    character(len=1024), pointer :: user_def_states(:,:,:) => null()

    !> the densities and currents (after all we are doing DFT :)
    FLOAT, pointer :: rho(:,:) => null()         !< rho(gr%mesh%np_part, st%d%nspin)
    FLOAT, pointer :: current(:, :, :) => null() !< current(gr%mesh%np_part, gr%sb%dim, st%d%nspin)


    FLOAT, pointer :: rho_core(:) => null()      !< core charge for nl core corrections
    logical        :: current_in_tau   !< are we using in tau the term which depends on the paramagnetic current?

    !> It may be required to "freeze" the deepest orbitals during the evolution; the density
    !! of these orbitals is kept in frozen_rho. It is different from rho_core.
    FLOAT, pointer :: frozen_rho(:, :) => null()

    FLOAT, pointer :: eigenval(:,:) => null() !< the eigenvalues
    logical        :: fixed_occ     !< should the occupation numbers be fixed?
    logical        :: restart_fixed_occ !< should the occupation numbers be fixed by restart?
    logical        :: restart_reorder_occs !< used for restart with altered occupation numbers
    FLOAT, pointer :: occ(:,:) => null()      !< the occupation numbers
    logical        :: fixed_spins   !< In spinors mode, the spin direction is set
                                    !< for the initial (random) orbitals.
    FLOAT, pointer :: spin(:, :, :) => null()

    FLOAT          :: qtot          !< (-) The total charge in the system (used in Fermi)
    FLOAT          :: val_charge    !< valence charge

    logical        :: fromScratch
    type(smear_t)  :: smear         !< smearing of the electronic occupations

    !> This is stuff needed for the parallelization in states.
    logical                     :: parallel_in_states !< Am I parallel in states?
    type(mpi_grp_t)             :: mpi_grp            !< The MPI group related to the parallelization in states.
    type(mpi_grp_t)             :: dom_st_mpi_grp     !< The MPI group related to the domains-states "plane".
    type(mpi_grp_t)             :: st_kpt_mpi_grp     !< The MPI group related to the states-kpoints "plane".
    type(mpi_grp_t)             :: dom_st_kpt_mpi_grp !< The MPI group related to the domains-states-kpoints "cube".
#ifdef HAVE_SCALAPACK
    type(blacs_proc_grid_t)     :: dom_st_proc_grid   !< The BLACS process grid for the domains-states plane
#endif
    integer                     :: lnst               !< Number of states on local node.
    integer                     :: st_start, st_end   !< Range of states processed by local node.
    integer, pointer            :: node(:) => null()  !< To which node belongs each state.
    !> Node r manages states st_range(1, r) to
    !! st_range(2, r) for r = 0, ..., mpi_grp%size-1,
    !! i. e. st_start = st_range(1, r) and
    !! st_end = st_range(2, r) on node r.
    integer, pointer            :: st_range(:, :) => null()
    !> Number of states on node r; presumably
    !! st_num(r) = st_range(2, r) - st_range(1, r) + 1 (TODO confirm).
    integer, pointer            :: st_num(:) => null()
    type(multicomm_all_pairs_t) :: ap                 !< All-pairs schedule.

    logical                     :: symmetrize_density
    logical                     :: packed
  end type states_t

  !> Generic name for reading a state out of a states_t object.
  !! It resolves to one of the four specific procedures listed below, which
  !! are defined elsewhere in this module (d/z presumably denote the real
  !! and complex variants, as for dpsi/zpsi).
  interface states_get_state
    module procedure dstates_get_state1
    module procedure zstates_get_state1
    module procedure dstates_get_state2
    module procedure zstates_get_state2
  end interface states_get_state

  !> Generic name for writing a state into a states_t object.
  !! It resolves to one of the four specific procedures listed below, which
  !! are defined elsewhere in this module (d/z presumably denote the real
  !! and complex variants, as for dpsi/zpsi).
  interface states_set_state
    module procedure dstates_set_state1
    module procedure zstates_set_state1
    module procedure dstates_set_state2
    module procedure zstates_set_state2
  end interface states_set_state

contains

  ! ---------------------------------------------------------
  !> Reset a states_t object to a pristine, unallocated condition.
  !!
  !! Every pointer component of the object is disassociated. Nothing is
  !! deallocated here, so calling this on an object that still owns memory
  !! leaks that memory. st%d and st%ob_d are reset through states_dim_null,
  !! and the flags are set to their defaults: real wavefunctions, no complex
  !! scaling, blocks not initialized, no open boundaries, not parallel in
  !! states, not packed.
  !!
  !! \param[inout] st  the states object to reset
  subroutine states_null(st)
    type(states_t), intent(inout) :: st

    integer :: il

    PUSH_SUB(states_null)

    call states_dim_null(st%d)
    st%d%orth_method = 0
    call modelmb_particles_nullify(st%modelmbparticles)
    st%priv%wfs_type = TYPE_FLOAT ! By default, calculations use real wavefunctions

    ! complex-scaling (cmplxscl) quantities
    st%d%cmplxscl = .false.
    nullify(st%psi%dL, st%psi%dR)
    nullify(st%psi%zL, st%psi%zR)
    nullify(st%zeigenval%Re, st%zeigenval%Im)
    nullify(st%zrho%Re, st%zrho%Im)
    nullify(st%Imrho_core, st%Imfrozen_rho)

    ! wavefunctions
    nullify(st%dpsi, st%zpsi)

    ! block bookkeeping: all five pointer components have to be covered,
    ! including block_size, otherwise its association status stays undefined
    nullify(st%psib, st%iblock, st%block_is_local)
    nullify(st%block_range, st%block_size)
    st%block_initialized = .false.

    ! open-boundary data
    nullify(st%zphi, st%ob_eigenval, st%ob_occ)
    st%open_boundaries = .false.
    call states_dim_null(st%ob_d)
    do il = 1, 2*MAX_DIM
      nullify(st%ob_lead(il)%intf_psi, st%ob_lead(il)%rho, st%ob_lead(il)%self_energy)
    end do

    ! user-defined states, densities, eigenvalues and occupations
    nullify(st%user_def_states)
    nullify(st%rho, st%current)
    nullify(st%rho_core, st%frozen_rho)
    nullify(st%eigenval, st%occ, st%spin)

    ! parallelization in states
    st%parallel_in_states = .false.
#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_nullify(st%dom_st_proc_grid)
#endif
    nullify(st%node,st%st_range, st%st_num)
    nullify(st%ap%schedule)

    st%packed = .false.

    POP_SUB(states_null)
  end subroutine states_null


  ! ---------------------------------------------------------
  subroutine states_init(st, gr, geo)
    type(states_t), target, intent(inout) :: st
    type(grid_t),           intent(in)    :: gr
    type(geometry_t),       intent(in)    :: geo

    FLOAT :: excess_charge
    integer :: nempty, ierr, il, ntot, default, nthreads
    integer, allocatable :: ob_k(:), ob_st(:), ob_d(:)
    character(len=256)   :: restart_dir

    PUSH_SUB(states_init)

    st%fromScratch = .true. ! this will be reset if restart_read is called
    call states_null(st)

    !%Variable StatesBlockSize
    !%Type integer
    !%Section Execution::Optimization
    !%Description
    !% Some routines work over blocks of eigenfunctions, which
    !% generally improves performance at the expense of increased
    !% memory consumption. This variable selects the size of the
    !% blocks to be used. If OpenCl is enabled, the default is 32;
    !% otherwise it is max(4, 2*nthreads).
    !%End

    nthreads = 1
#ifdef HAVE_OPENMP
    !$omp parallel
    !$omp master
    nthreads = omp_get_num_threads()
    !$omp end master
    !$omp end parallel
#endif    

    if(opencl_is_enabled()) then
      default = 32
    else
      default = max(4, 2*nthreads)
    end if

    call parse_integer(datasets_check('StatesBlockSize'), default, st%d%block_size)
    if(st%d%block_size < 1) then
      message(1) = "The variable 'StatesBlockSize' must be greater than 0."
      call messages_fatal(1)
    end if

    ASSERT(st%d%block_size > 0)

    !%Variable SpinComponents
    !%Type integer
    !%Default unpolarized
    !%Section States
    !%Description
    !% The calculations may be done in three different ways: spin-restricted (TD)DFT (<i>i.e.</i>, doubly
    !% occupied "closed shells"), spin-unrestricted or "spin-polarized" (TD)DFT (<i>i.e.</i> we have two
    !% electronic systems, one with spin up and one with spin down), or making use of two-component
    !% spinors.
    !%Option unpolarized 1
    !% Spin-restricted calculations.
    !%Option polarized 2
    !%Option spin_polarized 2
    !% Spin-unrestricted, also known as spin-DFT, SDFT. This mode will double the number of
    !% wavefunctions necessary for a spin-unpolarized calculation.
    !%Option non_collinear 3
    !%Option spinors 3
    !% (Synonym: <tt>non_collinear</tt>.) The spin-orbitals are two-component spinors. This effectively allows the spin-density to
    !% be oriented non-collinearly: <i>i.e.</i> the magnetization vector is allowed to take different
    !% directions at different points. This vector is always in 3D regardless of <tt>Dimensions</tt>.
    !%End
    call parse_integer(datasets_check('SpinComponents'), UNPOLARIZED, st%d%ispin)
    if(.not.varinfo_valid_option('SpinComponents', st%d%ispin)) call input_error('SpinComponents')
    call messages_print_var_option(stdout, 'SpinComponents', st%d%ispin)
    ! Use of spinors requires complex wavefunctions.
    if (st%d%ispin == SPINORS) st%priv%wfs_type = TYPE_CMPLX


    !%Variable ExcessCharge
    !%Type float
    !%Default 0.0
    !%Section States
    !%Description
    !% The net charge of the system. A negative value means that we are adding
    !% electrons, while a positive value means we are taking electrons
    !% from the system.
    !%End
    call parse_float(datasets_check('ExcessCharge'), M_ZERO, excess_charge)


    !%Variable TotalStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% This variable sets the total number of states that Octopus will
    !% use. This is normally not necessary since by default Octopus
    !% sets the number of states to the minimum necessary to hold the
    !% electrons present in the system. (This default behavior is
    !% obtained by setting <tt>TotalStates</tt> to 0).
    !%
    !% If you want to add some unoccupied states, probably it is more convenient to use the variable
    !% <tt>ExtraStates</tt>.
    !%
    !% Note that this number is unrelated to <tt>CalculationMode == unocc</tt>.
    !%End
    call parse_integer(datasets_check('TotalStates'), 0, ntot)
    if (ntot < 0) then
      write(message(1), '(a,i5,a)') "Input: '", ntot, "' is not a valid value for TotalStates."
      call messages_fatal(1)
    end if

    !%Variable ExtraStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% The number of states is in principle calculated considering the minimum
    !% numbers of states necessary to hold the electrons present in the system.
    !% The number of electrons is
    !% in turn calculated considering the nature of the species supplied in the
    !% <tt>Species</tt> block, and the value of the <tt>ExcessCharge</tt> variable.
    !% However, one may command <tt>Octopus</tt> to use more states, which is necessary if one wants to
    !% use fractional occupational numbers, either fixed from the beginning through
    !% the <tt>Occupations</tt> block or by prescribing
    !% an electronic temperature with <tt>Smearing</tt>.
    !%
    !% Note that this number is unrelated to <tt>CalculationMode == unocc</tt>.
    !% <tt>ExtraStates</tt> is used for a self-consistent calculation and
    !% the usual convergence criteria on the density do not take into account the
    !% eigenvalues, whereas <tt>unocc</tt> is a non-self-consistent calculation,
    !% and explicitly considers the eigenvalues of the unoccupied states as the
    !% convergence criteria.
    !%End
    call parse_integer(datasets_check('ExtraStates'), 0, nempty)
    if (nempty < 0) then
      write(message(1), '(a,i5,a)') "Input: '", nempty, "' is not a valid value for ExtraStates."
      message(2) = '(0 <= ExtraStates)'
      call messages_fatal(2)
    end if

    if(ntot > 0 .and. nempty > 0) then
      message(1) = 'You cannot set TotalStates and ExtraStates at the same time.'
      call messages_fatal(1)
    end if

    ! For non-periodic systems this should just return the Gamma point
    call states_choose_kpoints(st%d, gr%sb)

    call geometry_val_charge(geo, st%val_charge)

    if(gr%ob_grid%open_boundaries) then
      ! renormalize charge of central region to match leads (open system, not finite)
      st%val_charge = st%val_charge * (gr%ob_grid%lead(LEFT)%sb%lsize(TRANS_DIR) / gr%sb%lsize(TRANS_DIR))
    end if

    st%qtot = -(st%val_charge + excess_charge)

    do il = 1, NLEADS
      nullify(st%ob_lead(il)%intf_psi)
    end do
    ! When doing open-boundary calculations the number of free states is
    ! determined by the previous periodic calculation.
    st%open_boundaries = gr%ob_grid%open_boundaries
    if(gr%ob_grid%open_boundaries) then
      SAFE_ALLOCATE( ob_k(1:NLEADS))
      SAFE_ALLOCATE(ob_st(1:NLEADS))
      SAFE_ALLOCATE( ob_d(1:NLEADS))
      do il = 1, NLEADS
        restart_dir = trim(trim(gr%ob_grid%lead(il)%info%restart_dir)//'/'// GS_DIR)
        ! first get nst and kpoints of all states
        call states_look(restart_dir, mpi_world, ob_k(il), ob_d(il), ob_st(il), ierr)
        if(ierr.ne.0) then
          message(1) = 'Could not read the number of states of the periodic calculation'
          message(2) = 'from '//restart_dir//'.'
          call messages_fatal(2)
        end if
      end do
      if(NLEADS.gt.1) then
        if(ob_k(LEFT).ne.ob_k(RIGHT).or. &
          ob_st(LEFT).ne.ob_st(LEFT).or. &
          ob_d(LEFT).ne.ob_d(RIGHT)) then
          message(1) = 'The number of states for the left and right leads are not equal.'
          call messages_fatal(1)
        end if
      end if
      st%ob_d%dim = ob_d(LEFT)
      st%ob_nst   = ob_st(LEFT)
      st%ob_d%nik = ob_k(LEFT)
      st%d%nik = st%ob_d%nik
      SAFE_DEALLOCATE_A(ob_d)
      SAFE_DEALLOCATE_A(ob_st)
      SAFE_DEALLOCATE_A(ob_k)
      call distributed_nullify(st%ob_d%kpt, 0)
      if((st%d%ispin.eq.UNPOLARIZED.and.st%ob_d%dim.ne.1) .or.   &
        (st%d%ispin.eq.SPIN_POLARIZED.and.st%ob_d%dim.ne.1) .or. &
        (st%d%ispin.eq.SPINORS.and.st%ob_d%dim.ne.2)) then
        message(1) = 'The spin type of the leads calculation from '&
                     //gr%ob_grid%lead(LEFT)%info%restart_dir
        message(2) = 'and SpinComponents of the current run do not match.'
        call messages_fatal(2)
      end if
      SAFE_DEALLOCATE_P(st%d%kweights)
      SAFE_ALLOCATE(st%d%kweights(1:st%d%nik))
      st%d%kweights = M_ZERO
      st%d%kweights(1) = M_ONE
      SAFE_ALLOCATE(st%ob_d%kweights(1:st%ob_d%nik))
      SAFE_ALLOCATE(st%ob_eigenval(1:st%ob_nst, 1:st%ob_d%nik))
      SAFE_ALLOCATE(st%ob_occ(1:st%ob_nst, 1:st%ob_d%nik))
      st%ob_d%kweights = M_ZERO
      st%ob_eigenval   = huge(st%ob_eigenval)
      st%ob_occ        = M_ZERO
      call read_ob_eigenval_and_occ()
    else
      st%ob_nst   = 0
      st%ob_d%nik = 0
      st%ob_d%dim = 0
    end if

    select case(st%d%ispin)
    case(UNPOLARIZED)
      st%d%dim = 1
      st%nst = int(st%qtot/2)
      if(st%nst*2 < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 1
      st%d%spin_channels = 1
    case(SPIN_POLARIZED)
      st%d%dim = 1
      st%nst = int(st%qtot/2)
      if(st%nst*2 < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 2
      st%d%spin_channels = 2
    case(SPINORS)
      st%d%dim = 2
      st%nst = int(st%qtot)
      if(st%nst < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 4
      st%d%spin_channels = 2
    end select
     
    if(ntot > 0) then
      if(ntot < st%nst) then
        message(1) = 'TotalStates is smaller than the number of states required by the system.'
        call messages_fatal(1)
      end if

      st%nst = ntot
    end if

    st%nst = st%nst + nempty


    ! FIXME: For now, open-boundary calculations are only possible for
    ! continuum states, i.e. for those states treated by the Lippmann-
    ! Schwinger approach during SCF.
    ! Bound states should be done with extra states, without k-points.
    if(gr%ob_grid%open_boundaries) then
      if(st%nst.ne.st%ob_nst .or. st%d%nik.ne.st%ob_d%nik) then
        message(1) = 'Open-boundary calculations for possibly bound states'
        message(2) = 'are not possible yet. You have to match your number'
        message(3) = 'of states to the number of free states of your previous'
        message(4) = 'periodic run.'
        write(message(5), '(a,i5,a)') 'Your central region contributes ', st%nst, ' states,'
        write(message(6), '(a,i5,a)') 'while your lead calculation had ', st%ob_nst, ' states.'
        write(message(7), '(a,i5,a)') 'Your central region contributes ', st%d%nik, ' k-points,'
        write(message(8), '(a,i5,a)') 'while your lead calculation had ', st%ob_d%nik, ' k-points.'
        call messages_fatal(8)
      end if
    end if

    !%Variable CurrentDFT
    !%Type logical
    !%Default false
    !%Section Hamiltonian
    !%Description
    !% (experimental) If set to yes, Current-DFT will be used. This is the
    !% extension to DFT that should be used when external magnetic fields are
    !% present. The current-dependent part of the XC functional is set using the
    !% <tt>JFunctional</tt> variable. The default is no.
    !%End
    call parse_logical(datasets_check('CurrentDFT'), .false., st%d%cdft)
    if (st%d%cdft) then
      call messages_experimental('Current DFT')

      ! Use of CDFT requires complex wavefunctions
      st%priv%wfs_type = TYPE_CMPLX

      if(st%d%ispin == SPINORS) then
        message(1) = "Sorry, current DFT not working yet for spinors."
        call messages_fatal(1)
      end if
      message(1) = "Info: Using current DFT"
      call messages_info(1)
    end if

    !%Variable ComplexScaling
    !%Type logical
    !%Default false
    !%Section Hamiltonian
    !%Description
    !% (experimental) If set to yes, a complex scaled Hamiltonian will be used. 
    !% When <tt>TheoryLevel=DFT</tt> Density functional resonance theory DFRT is employed.  
    !% In order to reveal resonances <tt>ComplexScalingAngle</tt> bigger than zero should be set.
    !% D. L. Whitenack and A. Wasserman, Phys. Rev. Lett. 107, 163002 (2011).
    !%End
    call parse_logical(datasets_check('ComplexScaling'), .false., st%d%cmplxscl)


    if (st%d%cmplxscl) then
      call messages_experimental('Complex Scaling')
      call messages_print_var_value(stdout, "ComplexScaling", st%d%cmplxscl)

      !Even for gs calculations it requires complex wavefunctions
      st%priv%wfs_type = TYPE_CMPLX
      !Allocate imaginary parts of the eigenvalues
      SAFE_ALLOCATE(st%zeigenval%Im(1:st%nst, 1:st%d%nik))
      st%zeigenval%Im = M_ZERO
    end if
    SAFE_ALLOCATE(st%zeigenval%Re(1:st%nst, 1:st%d%nik))
    st%zeigenval%Re = huge(st%zeigenval%Re)
    st%eigenval => st%zeigenval%Re(1:st%nst, 1:st%d%nik) 


    ! Periodic systems require complex wavefunctions
    ! but not if it is Gamma-point only
    if(simul_box_is_periodic(gr%sb)) then
      if(.not. (kpoints_number(gr%sb%kpoints) == 1 .and. kpoints_point_is_gamma(gr%sb%kpoints, 1))) then
        st%priv%wfs_type = TYPE_CMPLX
      endif
    endif

    ! Calculations with open boundaries require complex wavefunctions.
    if(gr%ob_grid%open_boundaries) st%priv%wfs_type = TYPE_CMPLX

    !%Variable OnlyUserDefinedInitialStates
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% If true, then only user-defined states from the block <tt>UserDefinedStates</tt>
    !% will be used as initial states for a time-propagation. No attempt is made
    !% to load ground-state orbitals from a previous ground-state run.
    !%End
    call parse_logical(datasets_check('OnlyUserDefinedInitialStates'), .false., st%only_userdef_istates)

    !%Variable CurrentInTau
    !%Type logical
    !%Default yes
    !%Section States
    !%Description
    !% If true, a term including the (paramagnetic or total) current is included in the calculation of the kinetic-energy density.
    !%End
    call parse_logical(datasets_check('CurrentInTau'), .true., st%current_in_tau)


    ! we now allocate some arrays
    SAFE_ALLOCATE(st%occ     (1:st%nst, 1:st%d%nik))
    st%occ      = M_ZERO
    ! allocate space for formula strings that define user-defined states
    SAFE_ALLOCATE(st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%d%nik))
    if(st%d%ispin == SPINORS) then
      SAFE_ALLOCATE(st%spin(1:3, 1:st%nst, 1:st%d%nik))
    else
      nullify(st%spin)
    end if

    ! initially we mark all 'formulas' as undefined
    st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%d%nik) = 'undefined'

    call states_read_initial_occs(st, excess_charge)
    call states_read_initial_spins(st)

    nullify(st%zphi)

    st%st_start = 1
    st%st_end = st%nst
    st%lnst = st%nst
    SAFE_ALLOCATE(st%node(1:st%nst))
    st%node(1:st%nst) = 0

    call mpi_grp_init(st%mpi_grp, -1)
    st%parallel_in_states = .false.

    nullify(st%dpsi, st%zpsi)

    call distributed_nullify(st%d%kpt, st%d%nik)

    call modelmb_particles_init (st%modelmbparticles,gr)

    !%Variable SymmetrizeDensity
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% When enabled the density is symmetrized. Currently, this can
    !% only be done for periodic systems.
    !%
    !% It is enabled by default when symmetries are used to reduce the
    !% k-point grid (KPointsUseSymmetries = yes), otherwise it is
    !% disabled by default.
    !%End
    call parse_logical(datasets_check('SymmetrizeDensity'), gr%sb%kpoints%use_symmetries, st%symmetrize_density)
    call messages_print_var_value(stdout, 'SymmetrizeDensity', st%symmetrize_density)

#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_nullify(st%dom_st_proc_grid)
#endif

    st%packed = .false.

    POP_SUB(states_init)

  contains

    !> Reads the eigenvalues, occupations and k-point weights of the lead
    !! ground state from the 'occs' file located in the ground-state restart
    !! directory of the left lead, and stores them in st%ob_eigenval,
    !! st%ob_occ and st%ob_d%kweights.
    !!
    !! The first two lines of the file are skipped. Reading stops at the
    !! first line starting with '%', or as soon as a read fails (e.g. because
    !! the end of the file is reached without a terminating '%' line).
    subroutine read_ob_eigenval_and_occ()
      integer            :: occs, ist, ik, idim, idir, err
      FLOAT              :: flt, eigenval, occ, kweights
      character          :: char
      character(len=256) :: restart_dir, line, chars

      PUSH_SUB(states_init.read_ob_eigenval_and_occ)

      restart_dir = trim(gr%ob_grid%lead(LEFT)%info%restart_dir)//'/'//GS_DIR

      occs = io_open(trim(restart_dir)//'/occs', action='read', is_tmp=.true., grp=mpi_world)
      if(occs .lt. 0) then
        message(1) = 'Could not read '//trim(restart_dir)//'/occs.'
        call messages_fatal(1)
      end if

      ! Skip two lines.
      call iopar_read(mpi_world, occs, line, err)
      call iopar_read(mpi_world, occs, line, err)

      do
        call iopar_read(mpi_world, occs, line, err)

        ! A failed read (e.g. end of file) must terminate the loop; otherwise
        ! a file without the closing '%' line would be processed forever.
        if(err /= 0) exit

        ! The '%' line marks the end of the table.
        read(line, '(a)') char
        if(char .eq. '%') exit

        ! Extract eigenvalue. The line is already in memory, so there is no
        ! need to rewind the file and read it a second time.
        ! # occupations | eigenvalue[a.u.] | k-points | k-weights | filename | ik | ist | idim
        read(line, *) occ, char, eigenval, char, (flt, char, idir = 1, gr%sb%dim), kweights, &
           char, chars, char, ik, char, ist, char, idim

        if(st%d%ispin .eq. SPIN_POLARIZED) then
          call messages_not_implemented('Spin-Transport')

          if(is_spin_up(ik)) then
            !FIXME
            !              st%ob_eigenval(jst, SPIN_UP) = eigenval
            !              st%ob_occ(jst, SPIN_UP)      = occ
          else
            !              st%ob_eigenval(jst, SPIN_DOWN) = eigenval
            !              st%ob_occ(jst, SPIN_DOWN)      = occ
          end if
        else
          st%ob_eigenval(ist, ik) = eigenval
          st%ob_occ(ist, ik)      = occ
          st%ob_d%kweights(ik)    = kweights
        end if
      end do

      ! The unit was opened through the mpi_world group, so close it the same way.
      call io_close(occs, grp = mpi_world)

      POP_SUB(states_init.read_ob_eigenval_and_occ)
    end subroutine read_ob_eigenval_and_occ
  end subroutine states_init

  ! ---------------------------------------------------------
  !> Reads the state stored in directory "dir", and finds out
  !! the kpoints, dim, and nst contained in it.
  ! ---------------------------------------------------------
  subroutine states_look(dir, mpi_grp, kpoints, dim, nst, ierr)
    character(len=*),  intent(in)    :: dir
    type(mpi_grp_t),   intent(in)    :: mpi_grp
    integer,           intent(out)   :: dim, ierr
    integer,           intent(inout) :: nst, kpoints

    character(len=256) :: line
    character(len=12)  :: filename
    character(len=1)   :: char
    integer :: wfns_unit, occs_unit, err, read_stat, ist, idim, ik
    FLOAT :: occ, eigenval

    PUSH_SUB(states_look)

    ierr = 0

    ! Both index files are needed; ierr = -1 signals that one of them is missing.
    wfns_unit = io_open(trim(dir)//'/wfns', action='read', status='old', die=.false., is_tmp=.true., grp=mpi_grp)
    if(wfns_unit < 0) then
      ierr = -1
      POP_SUB(states_look)
      return
    end if

    occs_unit = io_open(trim(dir)//'/occs', action='read', status='old', die=.false., is_tmp=.true., grp=mpi_grp)
    if(occs_unit < 0) then
      call io_close(wfns_unit, grp = mpi_grp)
      ierr = -1
      POP_SUB(states_look)
      return
    end if

    ! Each file starts with two header lines that carry no data.
    call iopar_read(mpi_grp, wfns_unit, line, err)
    call iopar_read(mpi_grp, wfns_unit, line, err)
    call iopar_read(mpi_grp, occs_unit, line, err)
    call iopar_read(mpi_grp, occs_unit, line, err)

    ! The results are the largest indices found in the table.
    kpoints = 1
    dim = 1
    nst = 1

    do
      call iopar_read(mpi_grp, wfns_unit, line, read_stat)
      read(line, '(a)') char
      ! stop on a failed read or on the '%' line that closes the table
      if(read_stat /= 0) exit
      if(char == '%') exit

      read(line, *) ik, char, ist, char, idim, char, filename
      if(idim == 2) dim = 2

      ! the occs file is advanced in step with the wfns file
      call iopar_read(mpi_grp, occs_unit, line, err)
      read(line, *) occ, char, eigenval

      kpoints = max(kpoints, ik)
      nst     = max(nst, ist)
    end do

    call io_close(wfns_unit, grp = mpi_grp)
    call io_close(occs_unit, grp = mpi_grp)

    POP_SUB(states_look)
  end subroutine states_look

  ! ---------------------------------------------------------
  !> Allocate the lead densities.
  subroutine states_lead_densities_init(st, gr)
    type(states_t), intent(inout) :: st
    type(grid_t),   intent(in)    :: gr

    integer :: il

    PUSH_SUB(states_lead_densities_init)

    ! Lead densities only exist in open-boundaries calculations.
    if(.not. gr%ob_grid%open_boundaries) then
      POP_SUB(states_lead_densities_init)
      return
    end if

    ! One zero-initialized density per lead, defined on the mesh of that lead.
    do il = 1, NLEADS
      SAFE_ALLOCATE(st%ob_lead(il)%rho(1:gr%ob_grid%lead(il)%mesh%np, 1:st%d%nspin))
      st%ob_lead(il)%rho = M_ZERO
    end do

    POP_SUB(states_lead_densities_init)
  end subroutine states_lead_densities_init


  ! ---------------------------------------------------------
  !> Deallocate the lead density.
  subroutine states_lead_densities_end(st, gr)
    type(states_t), intent(inout) :: st
    type(grid_t),   intent(in)    :: gr

    integer :: il

    PUSH_SUB(states_lead_densities_end)

    ! Nothing was allocated unless the calculation uses open boundaries.
    if(.not. gr%ob_grid%open_boundaries) then
      POP_SUB(states_lead_densities_end)
      return
    end if

    ! Release the density of every lead.
    do il = 1, NLEADS
      SAFE_DEALLOCATE_P(st%ob_lead(il)%rho)
    end do

    POP_SUB(states_lead_densities_end)
  end subroutine states_lead_densities_end


  ! ---------------------------------------------------------
  !> Reads from the input file the initial occupations, if the
  !! block "Occupations" is present. Otherwise, it makes an initial
  !! guess for the occupations, maybe using the "Smearing"
  !! variable.
  !!
  !! The resulting occupations are placed on the st\%occ variable. The
  !! boolean st\%fixed_occ is also set to .true., if the occupations are
  !! set by the user through the "Occupations" block; false otherwise.
  subroutine states_read_initial_occs(st, excess_charge)
    type(states_t), intent(inout) :: st             !< st%occ, st%fixed_occ, st%smear, ... are set here
    FLOAT,          intent(in)    :: excess_charge  !< enters the total charge when no Occupations block is given

    integer :: ik, ist, ispin, nspin, ncols, el_per_state, icol, start_pos
    type(block_t) :: blk
    FLOAT :: rr, charge
    logical :: integral_occs
    FLOAT, allocatable :: read_occs(:, :)
    FLOAT :: charge_in_block

    PUSH_SUB(states_read_initial_occs)

    !%Variable RestartFixedOccupations
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% Setting this variable will make the restart proceed as
    !% if the occupations from the previous calculation had been set via the <tt>Occupations</tt> block,
    !% <i>i.e.</i> fixed. Otherwise, occupations will be determined by smearing.
    !%End
    call parse_logical(datasets_check('RestartFixedOccupations'), .false., st%restart_fixed_occ)
    ! we will turn on st%fixed_occ if restart_read is ever called

    !%Variable Occupations
    !%Type block
    !%Section States
    !%Description
    !% The occupation numbers of the orbitals can be fixed through the use of this
    !% variable. For example:
    !%
    !% <tt>%Occupations
    !% <br>&nbsp;&nbsp;2.0 | 2.0 | 2.0 | 2.0 | 2.0
    !% <br>%</tt>
    !%
    !% would fix the occupations of the five states to <i>2.0</i>. There can be
    !% at most as many columns as states in the calculation. If there are fewer columns
    !% than states, then the code will assume that the user is indicating the occupations
    !% of the uppermost states, assigning maximum occupation (i.e. 2 for spin-unpolarized
    !% calculations, 1 otherwise) to the lower states. If <tt>SpinComponents == polarized</tt>
    !% this block should contain two lines, one for each spin channel.
    !% This variable is very useful when dealing with highly symmetric small systems
    !% (like an open-shell atom), for it allows us to fix the occupation numbers
    !% of degenerate states in order to help <tt>octopus</tt> to converge. This is to
    !% be used in conjunction with <tt>ExtraStates</tt>. For example, to calculate the
    !% carbon atom, one would do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp;2 | 2/3 | 2/3 | 2/3
    !% <br>%</tt>
    !%
    !% If you want the calculation to be spin-polarized (which makes more sense), you could do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp; 2/3 | 2/3 | 2/3
    !% <br>&nbsp;&nbsp; 0   |   0 |   0
    !% <br>%</tt>
    !%
    !% Note that in this case the first state is absent, the code will calculate four states
    !% (two because there are four electrons, plus two because <tt>ExtraStates</tt> = 2), and since
    !% it finds only three columns, it will occupy the first state with one electron for each
    !% of the spin options.
    !%
    !% If the sum of occupations is not equal to the total charge set by <tt>ExcessCharge</tt>,
    !% an error message is printed.
    !% If <tt>FromScratch = no</tt> and <tt>RestartFixedOccupations = yes</tt>,
    !% this block will be ignored.
    !%End

    integral_occs = .true.

    if(st%open_boundaries) then
      ! With open boundaries the occupations and k-weights are taken from the lead data.
      st%fixed_occ = .true.
      st%occ  = st%ob_occ
      st%d%kweights = st%ob_d%kweights
      st%qtot = M_ZERO
      do ist = 1, st%nst
        st%qtot = st%qtot + sum(st%occ(ist, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
      end do

    else
      occ_fix: if(parse_block(datasets_check('Occupations'), blk)==0) then
        ! read in occupations
        st%fixed_occ = .true.

        ! Reads the number of columns in the first row. This assumes that all rows
        ! have the same column number; otherwise the code will stop with an error.
        ncols = parse_block_cols(blk, 0)
        if(ncols > st%nst) then
          message(1) = "Too many columns in block Occupations."
          call messages_warning(1)
          call input_error("Occupations")
        end if
        ! Now we fill all the "missing" states with the maximum occupation.
        if(st%d%ispin == UNPOLARIZED) then
          el_per_state = M_TWO
        else
          el_per_state = M_ONE
        endif
     
        SAFE_ALLOCATE(read_occs(1:ncols, 1:st%d%nik))
 
        do ik = 1, st%d%nik
          do icol = 1, ncols
            call parse_block_float(blk, ik - 1, icol - 1, read_occs(icol, ik))
          end do
        end do

        charge_in_block = sum(read_occs)

        ! Number of fully occupied states that must precede the block columns
        ! in order to reach the total charge.
        start_pos = int((st%qtot - charge_in_block)/(el_per_state*st%d%nik))

        ! The block columns are placed on states start_pos+1 ... start_pos+ncols.
        ! Make sure that this range lies inside 1:st%nst before writing to
        ! st%occ; otherwise the assignments below would go out of bounds.
        if(start_pos < 0) then
          message(1) = "The charge in block Occupations is larger than the total charge of the system."
          call messages_fatal(1)
        end if
        if(start_pos + ncols > st%nst) then
          message(1) = "Block Occupations has too many columns for the number of states in the calculation."
          write(message(2), '(a,i6,a)') "At least ", start_pos + ncols - st%nst, " more states are needed (see ExtraStates)."
          call messages_fatal(2)
        end if

        do ik = 1, st%d%nik
          do ist = 1, start_pos
            st%occ(ist, ik) = el_per_state
          end do
        end do

        do ik = 1, st%d%nik
          do ist = start_pos + 1, start_pos + ncols
            st%occ(ist, ik) = read_occs(ist - start_pos, ik)
            ! occupations are "integral" if every one is either zero or the maximum
            integral_occs = integral_occs .and. &
              abs((st%occ(ist, ik) - el_per_state) * st%occ(ist, ik)) .le. M_EPSILON
          end do
        end do

        do ik = 1, st%d%nik
          do ist = start_pos + ncols + 1, st%nst
             st%occ(ist, ik) = M_ZERO
          end do
        end do
        
        call parse_block_end(blk)

        SAFE_DEALLOCATE_A(read_occs)

      else
        st%fixed_occ = .false.
        integral_occs = .false.

        ! first guess for occupation...paramagnetic configuration
        rr = M_ONE
        if(st%d%ispin == UNPOLARIZED) rr = M_TWO

        st%occ  = M_ZERO
        st%qtot = -(st%val_charge + excess_charge)

        nspin = 1
        if(st%d%nspin == 2) nspin = 2

        ! Fill the states from the bottom with at most rr electrons each
        ! until the total charge is exhausted.
        do ik = 1, st%d%nik, nspin
          charge = M_ZERO
          do ispin = ik, ik + nspin - 1
            do ist = 1, st%nst
              st%occ(ist, ispin) = min(rr, -(st%val_charge + excess_charge) - charge)
              charge = charge + st%occ(ist, ispin)
            end do
          end do
        end do

      end if occ_fix
    end if

    !%Variable RestartReorderOccs
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% Consider doing a ground-state calculation, and then restarting with new occupations set
    !% with the <tt>Occupations</tt> block, in an attempt to populate the orbitals of the original
    !% calculation. However, the eigenvalues may reorder as the density changes, in which case the
    !% occupations will now be referring to different orbitals. Setting this variable to yes will
    !% try to solve this issue when the restart data is being read, by reordering the occupations
    !% according to the order of the expectation values of the restart wavefunctions.
    !%End
    if(st%fixed_occ) then
      call parse_logical(datasets_check('RestartReorderOccs'), .false., st%restart_reorder_occs)
    else
      st%restart_reorder_occs = .false.
    endif

    call smear_init(st%smear, st%d%ispin, st%fixed_occ, integral_occs)

    if(.not. smear_is_semiconducting(st%smear) .and. .not. st%smear%method == SMEAR_FIXED_OCC) then
      if((st%d%ispin /= SPINORS .and. st%nst * 2 .le. st%qtot) .or. &
         (st%d%ispin == SPINORS .and. st%nst .le. st%qtot)) then
        call messages_write('Smearing needs unoccupied states (via ExtraStates) to be useful.')
        call messages_warning()
      endif
    endif

    ! sanity check
    charge = M_ZERO
    do ist = 1, st%nst
      charge = charge + sum(st%occ(ist, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
    end do
    if(abs(charge - st%qtot) > CNST(1e-6)) then
      message(1) = "Initial occupations do not integrate to total charge."
      write(message(2), '(6x,f12.6,a,f12.6)') charge, ' != ', st%qtot
      call messages_fatal(2, only_root_writes = .true.)
    end if

    POP_SUB(states_read_initial_occs)
  end subroutine states_read_initial_occs


  ! ---------------------------------------------------------
  !> Reads, if present, the "InitialSpins" block. This is only
  !! done in spinors mode; otherwise the routine does nothing. The
  !! resulting spins are placed onto the st\%spin pointer. The boolean
  !! st\%fixed_spins is set to true if (and only if) the InitialSpins
  !! block is present.
  subroutine states_read_initial_spins(st)
    type(states_t), intent(inout) :: st  !< st%fixed_spins is always set; st%spin only if the block is present

    integer :: i, j
    type(block_t) :: blk

    PUSH_SUB(states_read_initial_spins)

    st%fixed_spins = .false.
    if(st%d%ispin .ne. SPINORS) then
      POP_SUB(states_read_initial_spins)
      return
    end if

    !%Variable InitialSpins
    !%Type block
    !%Section States
    !%Description
    !% The spin character of the initial random guesses for the spinors can
    !% be fixed by making use of this block. Note that this will not "fix"
    !% the spins during the calculation (this cannot be done in spinors mode, in
    !% being able to change the spins is why the spinors mode exists in the first
    !% place).
    !%
    !% This block is meaningless and ignored if the run is not in spinors mode
    !% (<tt>SpinComponents = spinors</tt>).
    !%
    !% The structure of the block is very simple: each column contains the desired
    !% &lt;<i>S_x</i>&gt;, &lt;<i>S_y</i>&gt;, &lt;<i>S_z</i>&gt; for each spinor.
    !% If the calculation is for a periodic system
    !% and there is more than one <i>k</i>-point, the spins of all the <i>k</i>-points are
    !% the same.
    !%
    !% For example, if we have two spinors, and we want one in the <i>Sx</i> "down" state,
    !% and another one in the <i>Sx</i> "up" state:
    !%
    !% <tt>%InitialSpins
    !% <br>&nbsp;&nbsp;  0.5 | 0.0 | 0.0
    !% <br>&nbsp;&nbsp; -0.5 | 0.0 | 0.0
    !% <br>%</tt>
    !%
    !% WARNING: if the calculation is for a system described by pseudopotentials (as
    !% opposed to user-defined potentials or model systems), this option is
    !% meaningless since the random spinors are overwritten by the atomic orbitals.
    !%
    !% There are a couple of physical constraints that have to be fulfilled:
    !%
    !% (A) | &lt;<i>S_i</i>&gt; | &lt;= 1/2
    !%
    !% (B) &lt;<i>S_x</i>&gt;^2 + &lt;<i>S_y</i>&gt;^2 + &lt;<i>S_z</i>&gt;^2 = 1/4
    !%
    !%End
    spin_fix: if(parse_block(datasets_check('InitialSpins'), blk)==0) then
      ! Row i of the block holds the three spin components of state i;
      ! they are read into the first k-point only.
      do i = 1, st%nst
        do j = 1, 3
          call parse_block_float(blk, i-1, j-1, st%spin(j, i, 1))
        end do
        ! This checks (B).
        if( abs(sum(st%spin(1:3, i, 1)**2) - M_FOURTH) > CNST(1.0e-6)) call input_error('InitialSpins')
      end do
      call parse_block_end(blk)
      ! This checks (A). In fact (A) follows from (B), so maybe this is not necessary...
      ! Only the first k-point has been filled at this point, so the test must
      ! be restricted to it: the other k-points still hold undefined values.
      if(any(abs(st%spin(:, :, 1)) > M_HALF)) then
        call input_error('InitialSpins')
      end if
      st%fixed_spins = .true.
      ! All the k-points get the same spins.
      do i = 2, st%d%nik
        st%spin(:, :, i) = st%spin(:, :, 1)
      end do
    end if spin_fix

    POP_SUB(states_read_initial_spins)
  end subroutine states_read_initial_spins


  ! ---------------------------------------------------------
  !> Allocates the KS wavefunctions defined within a states_t structure.
  subroutine states_allocate_wfns(st, mesh, wfs_type, alloc_zphi)
    type(states_t),         intent(inout)   :: st
    type(mesh_t),           intent(in)      :: mesh       !< mesh%np_part fixes the first dimension of the arrays
    type(type_t), optional, intent(in)      :: wfs_type   !< if present, must be TYPE_FLOAT or TYPE_CMPLX
    logical,      optional, intent(in)      :: alloc_zphi ! only needed for gs transport

    integer :: st1, st2, k1, k2, np_part
    logical :: force

    PUSH_SUB(states_allocate_wfns)

    if(associated(st%dpsi).or.associated(st%zpsi)) then
      call messages_write('Trying to allocate wavefunctions that are already allocated.')
      call messages_fatal()
    end if

    if (present(wfs_type)) then
      ASSERT(wfs_type == TYPE_FLOAT .or. wfs_type == TYPE_CMPLX)
      st%priv%wfs_type = wfs_type
    end if

    !%Variable ForceComplex
    !%Type logical
    !%Default no
    !%Section Execution::Debug
    !%Description
    !% Normally <tt>Octopus</tt> determines automatically the type necessary
    !% for the wavefunctions. When set to yes this variable will
    !% force the use of complex wavefunctions.
    !%
    !% Warning: This variable is designed for testing and
    !% benchmarking and normal users need not use it.
    !%
    !%End
    call parse_logical(datasets_check('ForceComplex'), .false., force)

    if(force) call states_set_complex(st)

    ! Only the states and k-points assigned to this node are stored.
    st1 = st%st_start
    st2 = st%st_end
    k1 = st%d%kpt%start
    k2 = st%d%kpt%end
    np_part = mesh%np_part

    ! When the states are packed, no contiguous array is allocated here.
    if(.not. st%d%pack_states) then

      if (states_are_real(st)) then
        SAFE_ALLOCATE(st%dpsi(1:np_part, 1:st%d%dim, st1:st2, k1:k2))
      else        
        SAFE_ALLOCATE(st%psi%zR(1:np_part, 1:st%d%dim, st1:st2, k1:k2))  
        st%zpsi => st%psi%zR
        if(st%d%cmplxscl) then
          SAFE_ALLOCATE(st%psi%zL(1:np_part, 1:st%d%dim, st1:st2, k1:k2))  
        else
          ! without complex scaling the left states are just an alias of the right ones
          st%psi%zL => st%psi%zR  
        end if          
      end if
      
      if(optional_default(alloc_zphi, .false.)) then
        SAFE_ALLOCATE(st%zphi(1:np_part, 1:st%ob_d%dim, st1:st2, k1:k2))
        ! Zero the whole array: its second dimension is st%ob_d%dim, which
        ! need not coincide with st%d%dim.
        st%zphi = M_Z0
      else
        nullify(st%zphi)
      end if

    end if

    call states_init_block(st, mesh)
    call states_set_zero(st)

    POP_SUB(states_allocate_wfns)
  end subroutine states_allocate_wfns

  ! ---------------------------------------------------------
  !> Allocates the interface wavefunctions defined within a states_t structure.
  subroutine states_allocate_intf_wfns(st, ob_mesh)
    type(states_t),         intent(inout)   :: st
    type(mesh_t),           intent(in)      :: ob_mesh(:)

    integer :: ilead, ist1, ist2, ik1, ik2

    PUSH_SUB(states_allocate_intf_wfns)

    ! interface wavefunctions only make sense with open boundaries
    ASSERT(st%open_boundaries)

    ! local ranges of states and k-points
    ik1  = st%d%kpt%start
    ik2  = st%d%kpt%end
    ist1 = st%st_start
    ist2 = st%st_end

    ! one zero-initialized array per lead, which must not exist yet
    do ilead = 1, NLEADS
      ASSERT(.not.associated(st%ob_lead(ilead)%intf_psi))
      SAFE_ALLOCATE(st%ob_lead(ilead)%intf_psi(1:ob_mesh(ilead)%np, 1:st%d%dim, ist1:ist2, ik1:ik2))
      st%ob_lead(ilead)%intf_psi(:, :, :, :) = M_z0
    end do

    ! TODO: write states_init_block for intf_psi
!    call states_init_block(st)

    POP_SUB(states_allocate_intf_wfns)
  end subroutine states_allocate_intf_wfns
  ! -----------------------------------------------------


  !---------------------------------------------------------------------
  !> Initializes the data components in st that describe how the states
  !! are distributed in blocks:
  !!
  !! st\%nblocks: this is the number of blocks in which the states are divided. Note that
  !!   this number is the total number of blocks, regardless of how many are actually stored
  !!   in each node.
  !! block_start: in each node, the index of the first block.
  !! block_end: in each node, the index of the last block.
  !!   If the states are not parallelized, then block_start is 1 and block_end is st\%nblocks.
  !! st\%iblock(1:st\%nst, 1:st\%d\%nik): it points, for each state, to the block that contains it.
  !! st\%block_is_local(1:st\%nblocks, 1:st\%d\%nik): st\%block_is_local(ib, iqn) is .true. if block ib of k-point iqn is stored in the running node.
  !! st\%block_range(1:st\%nblocks, 1:2): Block ib contains states from st\%block_range(ib, 1) to st\%block_range(ib, 2)
  !! st\%block_size(1:st\%nblocks): Block ib contains a number st\%block_size(ib) of states.
  !! st\%block_initialized: it should be .false. on entry, and .true. after exiting this routine.
  !!
  !! The set of batches st\%psib(1:st\%nblocks, 1:st\%d\%nik) contains the blocks themselves.
  subroutine states_init_block(st, mesh)
    type(states_t),           intent(inout) :: st
    type(mesh_t),   optional, intent(in)    :: mesh

    integer :: ib, iqn, ist, nfilled
    logical :: close_block
    integer, allocatable :: bstart(:), bend(:)

    PUSH_SUB(states_init_block)

    ! There can never be more blocks than states.
    SAFE_ALLOCATE(bstart(1:st%nst))
    SAFE_ALLOCATE(bend(1:st%nst))
    SAFE_ALLOCATE(st%iblock(1:st%nst, 1:st%d%nik))
    st%iblock = 0

    ! First pass: walk over the states, assign each one to a block and record
    ! the first and last state of every block.
    nfilled = 0
    st%nblocks = 0
    bstart(1) = 1
    do ist = 1, st%nst
      INCR(nfilled, 1)

      st%iblock(ist, st%d%kpt%start:st%d%kpt%end) = st%nblocks + 1

      ! The current block is closed when it is full or the last state is reached...
      close_block = (nfilled == st%d%block_size) .or. (ist == st%nst)
      if(st%parallel_in_states .and. ist /= st%nst) then
        ! ...and also when the next state lives in a different node, since
        ! states from different nodes must not share a block.
        close_block = close_block .or. (st%node(ist + 1) /= st%node(ist))
      end if

      if(close_block) then
        nfilled = 0
        INCR(st%nblocks, 1)
        bend(st%nblocks) = ist
        if(ist /= st%nst) bstart(st%nblocks + 1) = ist + 1
      end if
    end do

    SAFE_ALLOCATE(st%psib(1:st%nblocks, 1:st%d%nik))
    SAFE_ALLOCATE(st%block_is_local(1:st%nblocks, 1:st%d%nik))
    st%block_is_local = .false.
    st%block_start  = -1
    st%block_end    = -2  ! this will make that loops block_start:block_end do not run if not initialized

    ! Second pass: create the batches of the blocks that lie entirely within
    ! the range of states stored in this node.
    do ib = 1, st%nblocks
      if(bstart(ib) < st%st_start .or. bend(ib) > st%st_end) cycle

      if(st%block_start == -1) st%block_start = ib
      st%block_end = ib

      do iqn = st%d%kpt%start, st%d%kpt%end
        st%block_is_local(ib, iqn) = .true.

        if (states_are_real(st)) then
          if(associated(st%dpsi)) then
            ! the batch wraps the section of the existing wavefunctions array
            call batch_init(st%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), st%dpsi(:, :, bstart(ib):bend(ib), iqn))
          else
            ! no array to wrap: the batch gets its own storage, which needs the mesh
            ASSERT(present(mesh))
            call batch_init(st%psib(ib, iqn), st%d%dim, bend(ib) - bstart(ib) + 1)
            call dbatch_new(st%psib(ib, iqn), bstart(ib), bend(ib), mesh%np_part)
          end if
        else
          if(associated(st%zpsi)) then
            call batch_init(st%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), st%zpsi(:, :, bstart(ib):bend(ib), iqn))
          else
            ASSERT(present(mesh))
            call batch_init(st%psib(ib, iqn), st%d%dim, bend(ib) - bstart(ib) + 1)
            call zbatch_new(st%psib(ib, iqn), bstart(ib), bend(ib), mesh%np_part)
          end if
        end if
      end do
    end do

    ! Publish the block ranges and sizes in st.
    SAFE_ALLOCATE(st%block_range(1:st%nblocks, 1:2))
    SAFE_ALLOCATE(st%block_size(1:st%nblocks))

    st%block_range(1:st%nblocks, 1) = bstart(1:st%nblocks)
    st%block_range(1:st%nblocks, 2) = bend(1:st%nblocks)
    st%block_size(1:st%nblocks) = bend(1:st%nblocks) - bstart(1:st%nblocks) + 1

    st%block_initialized = .true.

!!$!!!!DEBUG
!!$    ! some debug output that I will keep here for the moment
!!$    if(mpi_grp_is_root(mpi_world)) then
!!$      print*, "NST       ", st%nst
!!$      print*, "BLOCKSIZE ", st%d%block_size
!!$      print*, "NBLOCKS   ", st%nblocks
!!$
!!$      print*, "==============="
!!$      do ist = 1, st%nst
!!$        print*, st%node(ist), ist, st%iblock(ist, 1)
!!$      end do
!!$      print*, "==============="
!!$
!!$      do ib = 1, st%nblocks
!!$        print*, ib, bstart(ib), bend(ib)
!!$      end do
!!$
!!$    end if
!!$!!!!ENDOFDEBUG

    SAFE_DEALLOCATE_A(bstart)
    SAFE_DEALLOCATE_A(bend)
    POP_SUB(states_init_block)
  end subroutine states_init_block


  ! ---------------------------------------------------------
  !> Deallocates the KS wavefunctions defined within a states_t structure.
  subroutine states_deallocate_wfns(st)
    type(states_t), intent(inout) :: st

    integer :: ilead, iblk, iqn

    PUSH_SUB(states_deallocate_wfns)

    ! First the block structure: end the local batches, then free the bookkeeping arrays.
    if (st%block_initialized) then
      do iblk = 1, st%nblocks
        do iqn = st%d%kpt%start, st%d%kpt%end
          if(.not. st%block_is_local(iblk, iqn)) cycle
          call batch_end(st%psib(iblk, iqn))
        end do
      end do

      SAFE_DEALLOCATE_P(st%psib)
      SAFE_DEALLOCATE_P(st%iblock)
      SAFE_DEALLOCATE_P(st%block_range)
      SAFE_DEALLOCATE_P(st%block_size)
      SAFE_DEALLOCATE_P(st%block_is_local)
      st%block_initialized = .false.
    end if

    ! Then the wavefunction arrays themselves.
    if (.not. states_are_real(st)) then
      ! zpsi is only an alias, it owns no memory
      nullify(st%zpsi)
      if(associated(st%psi%zL, target=st%psi%zR)) then
        ! zL is an alias of zR: just drop it
        nullify(st%psi%zL)
      else
        ! zL has its own storage (cmplxscl)
        SAFE_DEALLOCATE_P(st%psi%zL)
      end if
      SAFE_DEALLOCATE_P(st%psi%zR)
    else
      SAFE_DEALLOCATE_P(st%dpsi)
    end if

    ! Finally the interface wavefunctions of the leads.
    if(st%open_boundaries) then
      do ilead = 1, NLEADS
        SAFE_DEALLOCATE_P(st%ob_lead(ilead)%intf_psi)
      end do
    end if

    POP_SUB(states_deallocate_wfns)
  end subroutine states_deallocate_wfns


  ! ---------------------------------------------------------
  !> Allocates and zeroes the density arrays of st: the density (plus its
  !! imaginary part with complex scaling), the current if st%d%cdft is set,
  !! and the core density if geo%nlcc is set. It also prints the size of a
  !! block of states.
  subroutine states_densities_init(st, gr, geo)
    type(states_t), target, intent(inout) :: st
    type(grid_t),           intent(in)    :: gr
    type(geometry_t),       intent(in)    :: geo

    FLOAT :: block_bytes

    PUSH_SUB(states_densities_init)

    ! The density lives on the fine mesh; st%rho is an alias of its real part.
    SAFE_ALLOCATE(st%zrho%Re(1:gr%fine%mesh%np_part, 1:st%d%nspin))
    st%zrho%Re = M_ZERO
    st%rho => st%zrho%Re
    if(st%d%cmplxscl) then
      SAFE_ALLOCATE(st%zrho%Im(1:gr%fine%mesh%np_part, 1:st%d%nspin))
      st%zrho%Im = M_ZERO
    end if

    if(st%d%cdft) then
      SAFE_ALLOCATE(st%current(1:gr%mesh%np_part, 1:gr%mesh%sb%dim, 1:st%d%nspin))
      st%current = M_ZERO
    end if

    if(geo%nlcc) then
      SAFE_ALLOCATE(st%rho_core(1:gr%fine%mesh%np))
      st%rho_core = M_ZERO
      if(st%d%cmplxscl) then
        SAFE_ALLOCATE(st%Imrho_core(1:gr%fine%mesh%np))
        st%Imrho_core = M_ZERO
      end if
    end if

    ! 8 bytes per mesh point for each state of the block
    block_bytes = gr%mesh%np_part*CNST(8.0)*st%d%block_size

    call messages_write('Info: states-block size = ')
    call messages_write(block_bytes, fmt = '(f10.1)', align_left = .true., units = unit_megabytes, print_units = .true.)
    call messages_info()

    POP_SUB(states_densities_init)
  end subroutine states_densities_init



  !---------------------------------------------------------------------
  !> This subroutine: (i) Fills in the block size (st\%d\%block_size);
  !! (ii) Finds out whether or not to pack the states (st\%d\%pack_states);
  !! (iii) Finds out the orthogonalization method (st\%d\%orth_method).
  subroutine states_exec_init(st, mc)
    type(states_t),    intent(inout) :: st
    type(multicomm_t), intent(in)    :: mc

    integer :: orth_default

    PUSH_SUB(states_exec_init)

    !%Variable StatesPack
    !%Type logical
    !%Default no
    !%Section Execution::Optimization
    !%Description
    !% (Experimental) When set to yes, states are stored in packed
    !% mode, which improves performance considerably. However this
    !% is not fully implemented and it might give wrong results. The
    !% default is no.
    !%
    !% If OpenCL is used and this variable is set to yes, Octopus
    !% will store the wave-functions in device (GPU) memory. If
    !% there is not enough memory to store all the wave-functions,
    !% execution will stop with an error.
    !%End

    call parse_logical(datasets_check('StatesPack'), .false., st%d%pack_states)
    if(st%d%pack_states) then
      call messages_experimental('StatesPack')
    end if

    !%Variable StatesOrthogonalization
    !%Type integer
    !%Section Execution::Optimization
    !%Description
    !% The full orthogonalization method used by some
    !% eigensolvers. The default is gram_schmidt. With state
    !% parallelization the default is par_gram_schmidt.
    !%Option gram_schmidt 1
    !% Cholesky decomposition (despite the name) implemented using
    !% BLAS/LAPACK. Can be used with domain parallelization but not
    !% state parallelization.
    !%Option par_gram_schmidt 2
    !% Cholesky decomposition (despite the name) implemented using
    !% ScaLAPACK. Compatible with states parallelization.
    !%Option mgs 3
    !% Modified Gram-Schmidt orthogonalization.
    !% Can be used with domain parallelization but not state parallelization.
    !%Option qr 4
    !% (Experimental) Orthogonalization is performed based on a QR
    !% decomposition with LAPACK or ScaLAPACK.
    !%Compatible with states parallelization.
    !%End

    ! the default method depends on whether the states are distributed
    orth_default = ORTH_GS
    if(multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) orth_default = ORTH_PAR_GS

    call parse_integer(datasets_check('StatesOrthogonalization'), orth_default, st%d%orth_method)

    ! reject values that are not among the options listed above
    if(.not.varinfo_valid_option('StatesOrthogonalization', st%d%orth_method)) then
      call input_error('StatesOrthogonalization')
    end if
    call messages_print_var_option(stdout, 'StatesOrthogonalization', st%d%orth_method)

    if(st%d%orth_method == ORTH_QR) then
      call messages_experimental("QR Orthogonalization")
    end if


    !%Variable StatesCLDeviceMemory
    !%Type float
    !%Section Execution::Optimization
    !%Default -512
    !%Description
    !% This variable selects the amount of OpenCL device memory that
    !% will be used by Octopus to store the states. 
    !%
    !% A positive number smaller than 1 indicates a fraction of the total
    !% device memory. A number larger than one indicates an absolute
    !% amount of memory in megabytes. A negative number indicates an
    !% amount of memory in megabytes that would be substracted from
    !% the total device memory.
    !%End
    call parse_float(datasets_check('StatesCLDeviceMemory'), CNST(-512.0), st%d%cl_states_mem)

    POP_SUB(states_exec_init)
  end subroutine states_exec_init



  !---------------------------------------------------------------------
  !> Increases the number of states of st by nus. The occupations of the
  !! existing states are kept and the new states get zero occupation. The
  !! eigenvalue arrays (and, in spinors mode, the spin array) are
  !! re-allocated with the new size; their previous contents are discarded.
  subroutine states_resize_unocc(st, nus)
    type(states_t), intent(inout) :: st
    integer,        intent(in)    :: nus  !< number of states to be added

    FLOAT, pointer :: new_occ(:,:)

    PUSH_SUB(states_resize_unocc)

    ! Resize st%occ, retaining current values
    SAFE_ALLOCATE(new_occ(1:st%nst + nus, 1:st%d%nik))
    new_occ(1:st%nst,:) = st%occ(1:st%nst,:)
    new_occ(st%nst+1:,:) = M_ZERO
    SAFE_DEALLOCATE_P(st%occ)
    st%occ => new_occ

    ! fix states: THIS IS NOT OK
    st%nst    = st%nst + nus
    st%st_end = st%nst

    !cmplxscl
    SAFE_DEALLOCATE_P(st%zeigenval%Re)
    SAFE_DEALLOCATE_P(st%zeigenval%Im)
    nullify(st%eigenval)
    if (st%d%cmplxscl) then      
      SAFE_ALLOCATE(st%zeigenval%Im(1:st%nst, 1:st%d%nik))
      st%zeigenval%Im = huge(st%zeigenval%Im)      
    end if
    SAFE_ALLOCATE(st%zeigenval%Re(1:st%nst, 1:st%d%nik))
    st%zeigenval%Re = huge(st%zeigenval%Re)
    st%eigenval => st%zeigenval%Re 
  

    if(st%d%ispin == SPINORS) then
      ! Release the array of the old size first: allocating a pointer that is
      ! still associated would leak the memory it points to.
      SAFE_DEALLOCATE_P(st%spin)
      SAFE_ALLOCATE(st%spin(1:3, 1:st%nst, 1:st%d%nik))
      st%spin = M_ZERO
    end if
    
    POP_SUB(states_resize_unocc)
    
  end subroutine states_resize_unocc


  ! ---------------------------------------------------------
  !> Makes stout a copy of stin. stout is first reset with states_null; the
  !! scalar components are then assigned and the pointer components are
  !! duplicated with loct_pointer_copy. If the blocks of stin are
  !! initialized, the block structure of stout is rebuilt with
  !! states_init_block. Some of the open-boundaries components are not copied.
  subroutine states_copy(stout, stin)
    type(states_t), target, intent(inout) :: stout
    type(states_t),         intent(in)    :: stin

    PUSH_SUB(states_copy)

    call states_null(stout)

    call states_dim_copy(stout%d, stin%d)
    call modelmb_particles_copy(stout%modelmbparticles, stin%modelmbparticles)
    stout%priv%wfs_type = stin%priv%wfs_type
    stout%nst           = stin%nst

    stout%only_userdef_istates = stin%only_userdef_istates
    call loct_pointer_copy(stout%dpsi, stin%dpsi)

    !cmplxscl
    ! The aliases (zpsi, rho, eigenval) must point to the storage of stout,
    ! not to that of stin.
    call loct_pointer_copy(stout%psi%zR, stin%psi%zR)         
    stout%zpsi => stout%psi%zR
    call loct_pointer_copy(stout%zrho%Re, stin%zrho%Re)           
    stout%rho => stout%zrho%Re
    call loct_pointer_copy(stout%zeigenval%Re, stin%zeigenval%Re) 
    stout%eigenval => stout%zeigenval%Re
    if(stin%d%cmplxscl) then
      call loct_pointer_copy(stout%psi%zL, stin%psi%zL)         
      call loct_pointer_copy(stout%zrho%Im, stin%zrho%Im)           
      call loct_pointer_copy(stout%zeigenval%Im, stin%zeigenval%Im) 
      call loct_pointer_copy(stout%Imrho_core, stin%Imrho_core)
      call loct_pointer_copy(stout%Imfrozen_rho, stin%Imfrozen_rho)
    else
      ! Without complex scaling the left states are an alias of the right
      ! ones, exactly as set up by states_allocate_wfns.
      stout%psi%zL => stout%psi%zR
    end if

    
    ! the call to init_block is done at the end of this subroutine
    ! it allocates iblock, psib, block_is_local
    stout%nblocks = stin%nblocks

    stout%open_boundaries = stin%open_boundaries
    ! Warning: some of the "open boundaries" variables are not copied.

    call loct_pointer_copy(stout%user_def_states, stin%user_def_states)

    call loct_pointer_copy(stout%current, stin%current)

    call loct_pointer_copy(stout%rho_core, stin%rho_core)
    stout%current_in_tau = stin%current_in_tau

    call loct_pointer_copy(stout%frozen_rho, stin%frozen_rho)

    stout%fixed_occ = stin%fixed_occ
    stout%restart_fixed_occ = stin%restart_fixed_occ
    stout%restart_reorder_occs = stin%restart_reorder_occs

    call loct_pointer_copy(stout%occ, stin%occ)
    stout%fixed_spins = stin%fixed_spins

    call loct_pointer_copy(stout%spin, stin%spin)

    stout%qtot       = stin%qtot
    stout%val_charge = stin%val_charge

    call smear_copy(stout%smear, stin%smear)

    stout%parallel_in_states = stin%parallel_in_states
    call mpi_grp_copy(stout%mpi_grp, stin%mpi_grp)
    stout%dom_st_kpt_mpi_grp = stin%dom_st_kpt_mpi_grp
    stout%st_kpt_mpi_grp     = stin%st_kpt_mpi_grp

#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_copy(stin%dom_st_proc_grid, stout%dom_st_proc_grid)
#endif

    stout%lnst       = stin%lnst
    stout%st_start   = stin%st_start
    stout%st_end     = stin%st_end
    call loct_pointer_copy(stout%node, stin%node)
    call loct_pointer_copy(stout%st_range, stin%st_range)
    call loct_pointer_copy(stout%st_num, stin%st_num)

    if(stin%parallel_in_states) call multicomm_all_pairs_copy(stout%ap, stin%ap)

    stout%symmetrize_density = stin%symmetrize_density

    ! The batches of stout have to refer to the arrays of stout, so the block
    ! structure is built again instead of being copied.
    stout%block_initialized = .false.
    if(stin%block_initialized) then
      call states_init_block(stout)
    end if

    stout%packed = stin%packed

    POP_SUB(states_copy)
  end subroutine states_copy


  ! ---------------------------------------------------------
  !> Release what a states_t object holds: the two dimension descriptors
  !! (st%d and st%ob_d), the model-many-body particle data, the
  !! wavefunctions, and the pointer components deallocated one by one below.
  !!
  !! st  states object to be finalized (modified in place).
  subroutine states_end(st)
    type(states_t), intent(inout) :: st

    integer :: il

    PUSH_SUB(states_end)

    call states_dim_end(st%d)
    call modelmb_particles_end(st%modelmbparticles)

    ! this deallocates dpsi, zpsi, psib, iblock, st%ob_lead(:)%intf_psi
    call states_deallocate_wfns(st)

    ! open-boundaries related storage
    SAFE_DEALLOCATE_P(st%zphi)
    SAFE_DEALLOCATE_P(st%ob_eigenval)
    call states_dim_end(st%ob_d)
    SAFE_DEALLOCATE_P(st%ob_occ)
    ! one self-energy per lead; the loop runs over 2*MAX_DIM lead slots
    do il = 1, 2*MAX_DIM
      SAFE_DEALLOCATE_P(st%ob_lead(il)%self_energy)
    end do

    SAFE_DEALLOCATE_P(st%user_def_states)

    !cmplxscl
    !NOTE: sometimes these objects are allocated outside this module
    ! and therefore the correspondence with val => val%Re is broken.
    ! In this case we check if the pointer val is associated to zval%Re.
    ! When st%rho aliases st%zrho%Re the memory is freed only once, through
    ! zrho%Re, and the alias is merely nullified; otherwise st%rho owns its
    ! own storage and is deallocated directly.
    if(associated(st%rho, target=st%zrho%Re)) then 
      nullify(st%rho)
      SAFE_DEALLOCATE_P(st%zrho%Re)       
    else
      SAFE_DEALLOCATE_P(st%rho)
    end if
    ! same aliasing logic for the eigenvalues
    if(associated(st%eigenval, target=st%zeigenval%Re)) then 
      nullify(st%eigenval)
      SAFE_DEALLOCATE_P(st%zeigenval%Re)
    else
      SAFE_DEALLOCATE_P(st%eigenval)
    end if
    ! NOTE(review): st%d was already passed to states_dim_end above, yet
    ! st%d%cmplxscl is read here; presumably states_dim_end leaves this
    ! flag untouched -- TODO confirm.
    if(st%d%cmplxscl) then
      SAFE_DEALLOCATE_P(st%zrho%Im)
      SAFE_DEALLOCATE_P(st%zeigenval%Im)
      SAFE_DEALLOCATE_P(st%Imrho_core)
      SAFE_DEALLOCATE_P(st%Imfrozen_rho)
    end if
    

    SAFE_DEALLOCATE_P(st%current)
    SAFE_DEALLOCATE_P(st%rho_core)
    SAFE_DEALLOCATE_P(st%frozen_rho)

    SAFE_DEALLOCATE_P(st%occ)
    SAFE_DEALLOCATE_P(st%spin)

#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_end(st%dom_st_proc_grid)
#endif
    ! bookkeeping of the distribution of states over processes
    SAFE_DEALLOCATE_P(st%node)
    SAFE_DEALLOCATE_P(st%st_range)
    SAFE_DEALLOCATE_P(st%st_num)

    ! the all-pairs schedule is only released when running parallel in states
    if(st%parallel_in_states) then
      SAFE_DEALLOCATE_P(st%ap%schedule)
    end if

    POP_SUB(states_end)
  end subroutine states_end

  ! ---------------------------------------------------------
  !> Fill the states held by this process with random trial functions
  !! obtained from dmf_random / zmf_random and reset their eigenvalues.
  !! (Original note: generate a hydrogen s-wavefunction around a random
  !! point.)
  !!
  !! st          states to fill; wavefunctions and eigenvalues are overwritten.
  !! mesh        mesh on which the wavefunctions are defined.
  !! ist_start_  optional lower state index; clipped from below by st%st_start.
  !! ist_end_    optional upper state index; clipped from above by st%st_end.
  !!
  !! The seed offset handed to the random generator is built from the rank
  !! in the states group and the rank in the k-point group, so that every
  !! (states rank, k-point rank) pair gets a distinct offset. The spinor
  !! branches call zmf_random without a seed, as before.
  subroutine states_generate_random(st, mesh, ist_start_, ist_end_)
    type(states_t),    intent(inout) :: st
    type(mesh_t),      intent(in)    :: mesh
    integer, optional, intent(in)    :: ist_start_, ist_end_

    integer :: ist, ik, id, ist_start, ist_end, jst, seed, rank_stride
    CMPLX   :: alpha, beta
    FLOAT, allocatable :: dpsi(:,  :)
    CMPLX, allocatable :: zpsi(:,  :), zpsi2(:)

    PUSH_SUB(states_generate_random)

    ist_start = st%st_start
    if(present(ist_start_)) ist_start = max(ist_start, ist_start_)
    ist_end = st%st_end
    if(present(ist_end_)) ist_end = min(ist_end, ist_end_)

    ! Previously the states-parallel seed was unconditionally overwritten by
    ! the k-point branch (which reset it to zero when k-points were not
    ! parallel), so processes that differ only in their states rank used the
    ! same offset. Combine both ranks instead.
    seed = 0
    rank_stride = 1
    if(st%parallel_in_states) then
      seed = st%mpi_grp%rank
      rank_stride = st%mpi_grp%size
    end if
    if(st%d%kpt%parallel) then
      seed = seed + rank_stride*st%d%kpt%mpi_grp%rank
    end if

    if (states_are_real(st)) then
      SAFE_ALLOCATE(dpsi(1:mesh%np, 1:st%d%dim))
    else
      SAFE_ALLOCATE(zpsi(1:mesh%np, 1:st%d%dim))
    end if

    select case(st%d%ispin)
    case(UNPOLARIZED, SPIN_POLARIZED)

      do ik = st%d%kpt%start, st%d%kpt%end
        do ist = ist_start, ist_end
          if (states_are_real(st)) then
            call dmf_random(mesh, dpsi(:, 1), seed)
            call states_set_state(st, mesh, ist,  ik, dpsi)
          else
            call zmf_random(mesh, zpsi(:, 1), seed)
            call states_set_state(st, mesh, ist,  ik, zpsi)
          end if
          st%eigenval(ist, ik) = M_ZERO
        end do
      end do

    case(SPINORS)

      ASSERT(states_are_complex(st))

      if(st%fixed_spins) then

        ! scratch buffer for the previously generated states; allocated once
        ! instead of once per (ik, ist) pair
        SAFE_ALLOCATE(zpsi2(1:mesh%np))

        do ik = st%d%kpt%start, st%d%kpt%end
          do ist = ist_start, ist_end
            call zmf_random(mesh, zpsi(:, 1))
            ! In this case, the spinors are made of a spatial part times a vector [alpha beta]^T in
            ! spin space (i.e., same spatial part for each spin component). So (alpha, beta)
            ! determines the spin values. The values of (alpha, beta) can be obtained
            ! with simple formulae from <Sx>, <Sy>, <Sz>.
            !
            ! Note that here we orthonormalize the orbital part. This ensures that the spinors
            ! are untouched later in the general orthonormalization, and therefore the spin values
            ! of each spinor remain the same.
            do jst = ist_start, ist - 1
              call states_get_state(st, mesh, 1, jst, ik, zpsi2)
              zpsi(1:mesh%np, 1) = zpsi(1:mesh%np, 1) - zmf_dotp(mesh, zpsi(:, 1), zpsi2)*zpsi2(1:mesh%np)
            end do

            zpsi(1:mesh%np, 1) = zpsi(1:mesh%np, 1)/zmf_nrm2(mesh, zpsi(:, 1))
            zpsi(1:mesh%np, 2) = zpsi(1:mesh%np, 1)

            ! spin-space coefficients from the requested spin expectation values
            alpha = TOCMPLX(sqrt(M_HALF + st%spin(3, ist, ik)), M_ZERO)
            beta  = TOCMPLX(sqrt(M_ONE - abs(alpha)**2), M_ZERO)
            if(abs(alpha) > M_ZERO) then
              beta = TOCMPLX(st%spin(1, ist, ik) / abs(alpha), st%spin(2, ist, ik) / abs(alpha))
            end if
            zpsi(1:mesh%np, 1) = alpha*zpsi(1:mesh%np, 1)
            zpsi(1:mesh%np, 2) = beta*zpsi(1:mesh%np, 2)
            st%eigenval(ist, ik) = M_ZERO

            call states_set_state(st, mesh, ist,  ik, zpsi)
          end do
        end do

        SAFE_DEALLOCATE_A(zpsi2)
      else
        do ik = st%d%kpt%start, st%d%kpt%end
          do ist = ist_start, ist_end
            do id = 1, st%d%dim
              call zmf_random(mesh, zpsi(:, id))
            end do
            call states_set_state(st, mesh, ist,  ik, zpsi)
            st%eigenval(ist, ik) = M_HUGE
          end do
        end do
      end if

    end select

    SAFE_DEALLOCATE_A(dpsi)
    SAFE_DEALLOCATE_A(zpsi)

    POP_SUB(states_generate_random)
  end subroutine states_generate_random

  ! ---------------------------------------------------------
  !> Obtain the Fermi level and the occupations from the smear_m routines,
  !! check that the occupations add up to st%qtot and, for spinor runs,
  !! recompute st%spin from the wavefunctions currently stored in st.
  !!
  !! st    states; st%occ, st%smear and (for spinors) st%spin are updated.
  !! mesh  mesh used to evaluate the spin expectation values.
  subroutine states_fermi(st, mesh)
    type(states_t), intent(inout) :: st
    type(mesh_t),   intent(in)    :: mesh

    integer            :: jst, iq
    FLOAT              :: occ_sum
    CMPLX, allocatable :: spinor(:, :)
#if defined(HAVE_MPI)
    integer            :: icomp, recv_count
    FLOAT, allocatable :: spin_loc(:), spin_all(:) !< To exchange spin.
#endif

    PUSH_SUB(states_fermi)

    if(.not. st%d%cmplxscl) then

      call smear_find_fermi_energy(st%smear, st%eigenval, st%occ, st%qtot, &
        st%d%nik, st%nst, st%d%kweights)

      call smear_fill_occupations(st%smear, st%eigenval, st%occ, &
        st%d%nik, st%nst)

    else

      ! complex scaling: the imaginary parts of the eigenvalues are passed too
      call smear_find_fermi_energy(st%smear, st%zeigenval%Re, st%occ, st%qtot, &
        st%d%nik, st%nst, st%d%kweights, st%zeigenval%Im)

      call smear_fill_occupations(st%smear, st%eigenval, st%occ, &
        st%d%nik, st%nst, st%zeigenval%Im)

    end if

    ! sanity check: the k-weighted occupations must reproduce the total charge
    occ_sum = M_ZERO
    do jst = 1, st%nst
      occ_sum = occ_sum + sum(st%occ(jst, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
    end do
    if(abs(occ_sum-st%qtot) > CNST(1e-6)) then
      message(1) = 'Occupations do not integrate to total charge.'
      write(message(2), '(6x,f12.8,a,f12.8)') occ_sum, ' != ', st%qtot
      call messages_warning(2)
      if(occ_sum < M_EPSILON) then
        message(1) = "There don't seem to be any electrons at all!"
        call messages_fatal(1)
      end if
    end if

    ! nothing else to do unless the states are spinors
    if(st%d%ispin /= SPINORS) then
      POP_SUB(states_fermi)
      return
    end if

    ASSERT(states_are_complex(st))

    SAFE_ALLOCATE(spinor(1:mesh%np, st%d%dim))
    do iq = st%d%kpt%start, st%d%kpt%end
      do jst = st%st_start, st%st_end
        call states_get_state(st, mesh, jst, iq, spinor)
        st%spin(1:3, jst, iq) = state_spin(mesh, spinor)
      end do
#if defined(HAVE_MPI)
      if(st%parallel_in_states) then
        ! gather the locally computed spins so every process knows all states
        SAFE_ALLOCATE(spin_loc(1:st%lnst))
        SAFE_ALLOCATE(spin_all(1:st%nst))
        do icomp = 1, 3
          spin_loc = st%spin(icomp, st%st_start:st%st_end, iq)
          call lmpi_gen_allgatherv(st%lnst, spin_loc, recv_count, spin_all, st%mpi_grp)
          st%spin(icomp, 1:st%nst, iq) = spin_all(1:st%nst)
        end do
        SAFE_DEALLOCATE_A(spin_loc)
        SAFE_DEALLOCATE_A(spin_all)
      end if
#endif
    end do
    SAFE_DEALLOCATE_A(spinor)

    POP_SUB(states_fermi)
  end subroutine states_fermi


  ! ---------------------------------------------------------
  !> Sum of the eigenvalues weighted by the occupations and the k-point
  !! weights, accumulated over the local states and k-points and then
  !! summed over the states/k-point group when either is parallel.
  !! If alt_eig is given it replaces st%eigenval in the sum.
  function states_eigenvalues_sum(st, alt_eig) result(tot)
    type(states_t), intent(in)  :: st
    FLOAT, optional, intent(in) :: alt_eig(st%st_start:, :) !< (st%st_start:st%st_end, 1:st%d%nik)
    FLOAT                       :: tot

    integer :: iq, first, last

    PUSH_SUB(states_eigenvalues_sum)

    first = st%st_start
    last  = st%st_end

    tot = M_ZERO
    if(present(alt_eig)) then
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(iq) * sum(st%occ(first:last, iq) * &
          alt_eig(first:last, iq))
      end do
    else
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(iq) * sum(st%occ(first:last, iq) * &
          st%eigenval(first:last, iq))
      end do
    end if

    if(st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp%comm, tot)
    end if

    POP_SUB(states_eigenvalues_sum)
  end function states_eigenvalues_sum

  ! ---------------------------------------------------------
  !> Complex counterpart of states_eigenvalues_sum, suitable for cmplxscl:
  !! the eigenvalues are st%zeigenval%Re + i*st%zeigenval%Im unless
  !! alt_eig is supplied, in which case alt_eig is used instead.
  function zstates_eigenvalues_sum(st, alt_eig) result(tot)
    type(states_t), intent(in)  :: st
    CMPLX, optional, intent(in) :: alt_eig(st%st_start:, :) !< (st%st_start:st%st_end, 1:st%d%nik)
    CMPLX                       :: tot

    integer :: iq, first, last

    PUSH_SUB(zstates_eigenvalues_sum)

    first = st%st_start
    last  = st%st_end

    tot = M_ZERO
    if(present(alt_eig)) then
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(iq) * sum(st%occ(first:last, iq) * &
          (alt_eig(first:last, iq)))
      end do
    else
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(iq) * sum(st%occ(first:last, iq) * &
          (st%zeigenval%Re(first:last, iq) + M_zI * st%zeigenval%Im(first:last, iq)))
      end do
    end if

    if(st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp%comm, tot)
    end if

    POP_SUB(zstates_eigenvalues_sum)
  end function zstates_eigenvalues_sum

  ! -------------------------------------------------------
  !> Map a spin mode and k-point index to a spin channel.
  !!   ispin = 1 : always channel 1
  !!   ispin = 2 : channels alternate with ik (odd ik -> 1, even ik -> 2)
  !!   ispin = 3 : the value of dim is returned
  !!   otherwise : -1
  pure function states_spin_channel(ispin, ik, dim) result(channel)
    integer, intent(in) :: ispin, ik, dim
    integer             :: channel

    if(ispin == 1) then
      channel = 1
    else if(ispin == 2) then
      channel = mod(ik+1, 2)+1
    else if(ispin == 3) then
      channel = dim
    else
      channel = -1
    end if

  end function states_spin_channel


  ! ---------------------------------------------------------
  !> Decide which states each process of the states group manages.
  !! Serial defaults are set first (every state on rank 0, local range
  !! 1:st%nst); when the states strategy is parallel the range is divided
  !! with multicomm_divide_range and st%node, st%st_range, st%st_num,
  !! st%st_start, st%st_end and st%lnst are filled accordingly.
  !!
  !! st  states object; distribution-related components are overwritten.
  !! mc  multicommunicator providing the communicators for each strategy.
  subroutine states_distribute_nodes(st, mc)
    type(states_t),    intent(inout) :: st
    type(multicomm_t), intent(in)    :: mc

#ifdef HAVE_MPI
    integer :: irank, my_rank
#endif

    PUSH_SUB(states_distribute_nodes)

    ! Defaults: no parallelization in states.
    st%parallel_in_states = .false.
    st%node(:)            = 0
    st%st_start           = 1
    st%st_end             = st%nst
    st%lnst               = st%nst

    call mpi_grp_init(st%mpi_grp, mc%group_comm(P_STRATEGY_STATES))
    call mpi_grp_init(st%dom_st_kpt_mpi_grp, mc%dom_st_kpt_comm)
    call mpi_grp_init(st%dom_st_mpi_grp, mc%dom_st_comm)
    call mpi_grp_init(st%st_kpt_mpi_grp, mc%st_kpt_comm)

#ifdef HAVE_SCALAPACK
    if(calc_mode_scalapack_compat() .and. .not. st%d%kpt%parallel) then
      call blacs_proc_grid_init(st%dom_st_proc_grid, st%dom_st_mpi_grp)
    end if
#endif

#if defined(HAVE_MPI)
    if(multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) then
      st%parallel_in_states = .true.

      call multicomm_create_all_pairs(st%mpi_grp, st%ap)

      ! every process of the group needs at least one state
      if(st%nst < st%mpi_grp%size) then
        message(1) = "Have more processors than necessary"
        write(message(2),'(i4,a,i4,a)') st%mpi_grp%size, " processors and ", st%nst, " states."
        call messages_fatal(2)
      end if

      SAFE_ALLOCATE(st%st_range(1:2, 0:st%mpi_grp%size-1))
      SAFE_ALLOCATE(st%st_num(0:st%mpi_grp%size-1))

      call multicomm_divide_range(st%nst, st%mpi_grp%size, st%st_range(1, :), st%st_range(2, :), &
        lsize = st%st_num, scalapack_compat = calc_mode_scalapack_compat())

      message(1) = "Info: Parallelization in states"
      call messages_info(1)

      do irank = 0, st%mpi_grp%size - 1
        write(message(1),'(a,i4,a,i5,a)') &
          'Info: Nodes in states-group ', irank, ' will manage ', st%st_num(irank), ' states'
        if(st%st_num(irank) > 0) then
          write(message(1),'(a,a,i6,a,i6)') trim(message(1)), ':', &
            st%st_range(1, irank), " - ", st%st_range(2, irank)
        end if
        call messages_info(1)

        ! record the owner of every state in this range
        st%node(st%st_range(1, irank):st%st_range(2, irank)) = irank
      end do

      if(any(st%st_num(:) == 0)) then
        message(1) = "Cannot run with empty states-groups. Select a smaller number of processors so none are idle."
        call messages_fatal(1, only_root_writes = .true.)
      end if

      ! finally, the range handled by this very process
      my_rank     = st%mpi_grp%rank
      st%st_start = st%st_range(1, my_rank)
      st%st_end   = st%st_range(2, my_rank)
      st%lnst     = st%st_num(my_rank)

    end if
#endif

    POP_SUB(states_distribute_nodes)
  end subroutine states_distribute_nodes


  ! ---------------------------------------------------------
  !> Flag the wavefunctions of st as complex by setting
  !! st%priv%wfs_type to TYPE_CMPLX. No data is converted here.
  subroutine states_set_complex(st)
    type(states_t), intent(inout) :: st

    PUSH_SUB(states_set_complex)

    st%priv%wfs_type = TYPE_CMPLX

    POP_SUB(states_set_complex)
  end subroutine states_set_complex

  ! ---------------------------------------------------------
  !> .true. when the wavefunction type stored in st is TYPE_CMPLX.
  pure function states_are_complex(st) result(is_complex)
    type(states_t), intent(in) :: st
    logical                    :: is_complex

    is_complex = st%priv%wfs_type == TYPE_CMPLX

  end function states_are_complex


  ! ---------------------------------------------------------
  !> .true. when the wavefunction type stored in st is TYPE_FLOAT.
  pure function states_are_real(st) result(is_real)
    type(states_t), intent(in) :: st
    logical                    :: is_real

    is_real = st%priv%wfs_type == TYPE_FLOAT

  end function states_are_real

  ! ---------------------------------------------------------
  !
  !> This function can calculate several quantities that depend on
  !! derivatives of the orbitals from the states and the density.
  !! The quantities to be calculated depend on the arguments passed;
  !! at least one of the optional outputs must be present.
  !!
  !! All quantities are accumulated over the local states and k-points
  !! and then summed over the states/k-point group. The gauge-invariant
  !! kinetic energy density is assembled only after that reduction, as
  !!   tau(:, is) - sum_i jp(:, i, is)**2 / rho(:, is)
  !! (the current term is included only for complex states with
  !! st%current_in_tau set). It is not available for spinors.
  subroutine states_calc_quantities(der, st, &
    kinetic_energy_density, paramagnetic_current, density_gradient, density_laplacian, gi_kinetic_energy_density)
    type(derivatives_t),     intent(in)    :: der
    type(states_t),          intent(in)    :: st
    FLOAT, optional, target, intent(out)   :: kinetic_energy_density(:,:)       !< The kinetic energy density.
    FLOAT, optional, target, intent(out)   :: paramagnetic_current(:,:,:)       !< The paramagnetic current.
    FLOAT, optional,         intent(out)   :: density_gradient(:,:,:)           !< The gradient of the density.
    FLOAT, optional,         intent(out)   :: density_laplacian(:,:)            !< The Laplacian of the density.
    FLOAT, optional,         intent(out)   :: gi_kinetic_energy_density(:,:)    !< The gauge-invariant kinetic energy density.

    FLOAT, pointer :: jp(:, :, :)
    FLOAT, pointer :: tau(:, :)
    CMPLX, allocatable :: wf_psi(:,:), gwf_psi(:,:,:), lwf_psi(:,:)
    CMPLX   :: c_tmp
    integer :: is, ik, ist, i_dim, st_dim, ii
    FLOAT   :: ww, kpoint(1:MAX_DIM)
    logical :: something_to_do

    PUSH_SUB(states_calc_quantities)

    something_to_do = present(kinetic_energy_density) .or. present(gi_kinetic_energy_density) .or. &
      present(paramagnetic_current) .or. present(density_gradient) .or. present(density_laplacian)
    ASSERT(something_to_do)

    SAFE_ALLOCATE( wf_psi(1:der%mesh%np_part, 1:st%d%dim))
    SAFE_ALLOCATE(gwf_psi(1:der%mesh%np, 1:der%mesh%sb%dim, 1:st%d%dim))
    if(present(density_laplacian)) then
      SAFE_ALLOCATE(lwf_psi(1:der%mesh%np, 1:st%d%dim))
    endif

    ! tau and jp point to the caller arrays when those are present;
    ! otherwise they may be allocated below as private work arrays.
    nullify(tau)
    if(present(kinetic_energy_density)) tau => kinetic_energy_density

    nullify(jp)
    if(present(paramagnetic_current)) jp => paramagnetic_current

    ! for the gauge-invariant kinetic energy density we need the
    ! current and the kinetic energy density
    if(present(gi_kinetic_energy_density)) then
      if(.not. present(paramagnetic_current) .and. states_are_complex(st)) then
        SAFE_ALLOCATE(jp(1:der%mesh%np, 1:der%mesh%sb%dim, 1:st%d%nspin))
      end if
      if(.not. present(kinetic_energy_density)) then
        SAFE_ALLOCATE(tau(1:der%mesh%np, 1:st%d%nspin))
      end if
    end if

    if(associated(tau)) tau = M_ZERO
    if(associated(jp)) jp = M_ZERO
    if(present(density_gradient)) density_gradient(:,:,:) = M_ZERO
    if(present(density_laplacian)) density_laplacian(:,:) = M_ZERO
    if(present(gi_kinetic_energy_density)) gi_kinetic_energy_density = M_ZERO

    do ik = st%d%kpt%start, st%d%kpt%end

      kpoint(1:der%mesh%sb%dim) = kpoints_get_point(der%mesh%sb%kpoints, states_dim_get_kpoint_index(st%d, ik))
      is = states_dim_get_spin_index(st%d, ik)

      do ist = st%st_start, st%st_end

        ! all calculations will be done with complex wavefunctions
        call states_get_state(st, der%mesh, ist, ik, wf_psi)

        ! calculate gradient of the wavefunction
        do st_dim = 1, st%d%dim
          call zderivatives_grad(der, wf_psi(:,st_dim), gwf_psi(:,:,st_dim))
        end do

        ! calculate the Laplacian of the wavefunction
        if (present(density_laplacian)) then
          do st_dim = 1, st%d%dim
            call zderivatives_lapl(der, wf_psi(:,st_dim), lwf_psi(:,st_dim))
          end do
        end if

        ! weight of this state: k-point weight times occupation
        ww = st%d%kweights(ik)*st%occ(ist, ik)

        if(present(density_laplacian)) then
          density_laplacian(1:der%mesh%np, is) = density_laplacian(1:der%mesh%np, is) + &
               ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 1))*lwf_psi(1:der%mesh%np, 1))
          if(st%d%ispin == SPINORS) then
            density_laplacian(1:der%mesh%np, 2) = density_laplacian(1:der%mesh%np, 2) + &
                 ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 2))*lwf_psi(1:der%mesh%np, 2))
            density_laplacian(1:der%mesh%np, 3) = density_laplacian(1:der%mesh%np, 3) + &
                 ww*real (lwf_psi(1:der%mesh%np, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                 wf_psi(1:der%mesh%np, 1)*conjg(lwf_psi(1:der%mesh%np, 2)))
            density_laplacian(1:der%mesh%np, 4) = density_laplacian(1:der%mesh%np, 4) + &
                 ww*aimag(lwf_psi(1:der%mesh%np, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                 wf_psi(1:der%mesh%np, 1)*conjg(lwf_psi(1:der%mesh%np, 2)))
          end if
        end if
        
        do i_dim = 1, der%mesh%sb%dim
          if(present(density_gradient)) &
               density_gradient(1:der%mesh%np, i_dim, is) = density_gradient(1:der%mesh%np, i_dim, is) + &
               ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 1))*gwf_psi(1:der%mesh%np, i_dim, 1))
          if(present(density_laplacian)) &
               density_laplacian(1:der%mesh%np, is) = density_laplacian(1:der%mesh%np, is)         + &
               ww*M_TWO*real(conjg(gwf_psi(1:der%mesh%np, i_dim, 1))*gwf_psi(1:der%mesh%np, i_dim, 1))

          if(associated(jp)) then
            if (.not.(states_are_real(st))) then
              jp(1:der%mesh%np, i_dim, is) = jp(1:der%mesh%np, i_dim, is) + &
                   ww*aimag(conjg(wf_psi(1:der%mesh%np, 1))*gwf_psi(1:der%mesh%np, i_dim, 1) - &
                   M_zI*(wf_psi(1:der%mesh%np, 1))**2*kpoint(i_dim ) )
            else
              ! real wavefunctions carry no paramagnetic current
              jp(1:der%mesh%np, i_dim, is) = M_ZERO
            end if
          end if

          if (associated(tau)) then
            tau (1:der%mesh%np, is)   = tau (1:der%mesh%np, is)        + &
                 ww*abs(gwf_psi(1:der%mesh%np, i_dim, 1))**2  &
                 + ww*abs(kpoint(i_dim))**2*abs(wf_psi(1:der%mesh%np, 1))**2  &
                 - ww*M_TWO*aimag(conjg(wf_psi(1:der%mesh%np, 1))*kpoint(i_dim)*gwf_psi(1:der%mesh%np, i_dim, 1) )
          end if

          ! NOTE: the gauge-invariant kinetic energy density is no longer
          ! built here. It depends non-linearly on jp, so it has to be formed
          ! from the fully accumulated (and MPI-reduced) tau and jp; see
          ! the code after the loops.

          if(st%d%ispin == SPINORS) then
            if(present(density_gradient)) then
              density_gradient(1:der%mesh%np, i_dim, 2) = density_gradient(1:der%mesh%np, i_dim, 2) + &
                   ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              density_gradient(1:der%mesh%np, i_dim, 3) = density_gradient(1:der%mesh%np, i_dim, 3) + ww* &
                   real (gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                   wf_psi(1:der%mesh%np, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
              density_gradient(1:der%mesh%np, i_dim, 4) = density_gradient(1:der%mesh%np, i_dim, 4) + ww* &
                   aimag(gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                   wf_psi(1:der%mesh%np, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
            end if

            if(present(density_laplacian)) then
              density_laplacian(1:der%mesh%np, 2) = density_laplacian(1:der%mesh%np, 2)         + &
                   ww*M_TWO*real(conjg(gwf_psi(1:der%mesh%np, i_dim, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              density_laplacian(1:der%mesh%np, 3) = density_laplacian(1:der%mesh%np, 3)         + &
                   ww*M_TWO*real (gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
              density_laplacian(1:der%mesh%np, 4) = density_laplacian(1:der%mesh%np, 4)         + &
                   ww*M_TWO*aimag(gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
            end if

            ! the expression for the paramagnetic current with spinors is
            !     j = ( jp(1)             jp(3) + i jp(4) )
            !         (-jp(3) + i jp(4)   jp(2)           )
            if(associated(jp)) then
              jp(1:der%mesh%np, i_dim, 2) = jp(1:der%mesh%np, i_dim, 2) + &
                   ww*aimag(conjg(wf_psi(1:der%mesh%np, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              do ii = 1, der%mesh%np
                c_tmp = conjg(wf_psi(ii, 1))*gwf_psi(ii, i_dim, 2) - wf_psi(ii, 2)*conjg(gwf_psi(ii, i_dim, 1))
                jp(ii, i_dim, 3) = jp(ii, i_dim, 3) + ww* real(c_tmp)
                jp(ii, i_dim, 4) = jp(ii, i_dim, 4) + ww*aimag(c_tmp)
              end do
            end if

            ! the expression for the kinetic energy density with spinors is
            !     t = ( tau(1)              tau(3) + i tau(4) )
            !         ( tau(3) - i tau(4)   tau(2)            )
            if(associated(tau)) then
              tau (1:der%mesh%np, 2) = tau (1:der%mesh%np, 2) + ww*abs(gwf_psi(1:der%mesh%np, i_dim, 2))**2
              do ii = 1, der%mesh%np
                c_tmp = conjg(gwf_psi(ii, i_dim, 1))*gwf_psi(ii, i_dim, 2)
                tau(ii, 3) = tau(ii, 3) + ww* real(c_tmp)
                tau(ii, 4) = tau(ii, 4) + ww*aimag(c_tmp)
              end do
            end if

            ASSERT(.not. present(gi_kinetic_energy_density))

          end if !SPINORS

        end do

      end do
    end do

    SAFE_DEALLOCATE_A(wf_psi)
    SAFE_DEALLOCATE_A(gwf_psi)
    SAFE_DEALLOCATE_A(lwf_psi)

    ! Sum the partial results of all processes while the private tau/jp work
    ! arrays are still alive: they are needed, fully reduced, just below.
    if(st%parallel_in_states .or. st%d%kpt%parallel) call reduce_all(st%st_kpt_mpi_grp)

    if(present(gi_kinetic_energy_density)) then
      ASSERT(associated(tau))
      do is = 1, st%d%nspin
        gi_kinetic_energy_density(1:der%mesh%np, is) = tau(1:der%mesh%np, is)
      end do
      if(states_are_complex(st) .and. st%current_in_tau) then
        ASSERT(associated(jp))
        ! subtract |jp|**2/rho for each spin channel, summing over all
        ! Cartesian components; points with no density are left untouched
        ! to avoid dividing by zero
        do is = 1, st%d%nspin
          do i_dim = 1, der%mesh%sb%dim
            do ii = 1, der%mesh%np
              if(st%rho(ii, is) > M_ZERO) then
                gi_kinetic_energy_density(ii, is) = gi_kinetic_energy_density(ii, is) - &
                     jp(ii, i_dim, is)**2/st%rho(ii, is)
              end if
            end do
          end do
        end do
      end if
    end if

    ! release the work arrays that were allocated in this routine
    if(.not. present(paramagnetic_current)) then
      SAFE_DEALLOCATE_P(jp)
    end if

    if(.not. present(kinetic_energy_density)) then
      SAFE_DEALLOCATE_P(tau)
    end if

    POP_SUB(states_calc_quantities)

  contains

    !> Sum tau, jp, the density gradient and the density Laplacian over
    !! the processes of grp. The gauge-invariant kinetic energy density
    !! is not reduced here since it is computed from the reduced tau and jp.
    subroutine reduce_all(grp)
      type(mpi_grp_t), intent(in)  :: grp

      PUSH_SUB(states_calc_quantities.reduce_all)

      if(associated(tau)) call comm_allreduce(grp%comm, tau, dim = (/der%mesh%np, st%d%nspin/))

      if (present(density_laplacian)) call comm_allreduce(grp%comm, density_laplacian, dim = (/der%mesh%np, st%d%nspin/))

      do is = 1, st%d%nspin
        if(associated(jp)) call comm_allreduce(grp%comm, jp(:, :, is), dim = (/der%mesh%np, der%mesh%sb%dim/))

        if(present(density_gradient)) &
          call comm_allreduce(grp%comm, density_gradient(:, :, is), dim = (/der%mesh%np, der%mesh%sb%dim/))
      end do

      POP_SUB(states_calc_quantities.reduce_all)
    end subroutine reduce_all

  end subroutine states_calc_quantities


  ! ---------------------------------------------------------
  !> Spin expectation values of a two-component spinor f1, built from
  !! the overlap of its two components and from their norms.
  !!
  !! mesh  mesh on which f1 is defined.
  !! f1    spinor; columns 1 and 2 are the two spin components.
  !!
  !! NOTE(review): the result is complex and component 1 keeps the full
  !! complex overlap (not only its real part); callers that assign the
  !! result to a real array implicitly keep the real part -- confirm this
  !! is the intent.
  function state_spin(mesh, f1) result(spin)
    type(mesh_t), intent(in) :: mesh
    CMPLX,        intent(in) :: f1(:, :)
    CMPLX                    :: spin(1:3)

    CMPLX :: overlap
    FLOAT :: norm_up, norm_down

    PUSH_SUB(state_spin)

    overlap   = zmf_dotp(mesh, f1(:, 1) , f1(:, 2))
    norm_up   = zmf_nrm2(mesh, f1(:, 1))
    norm_down = zmf_nrm2(mesh, f1(:, 2))

    spin(1) = M_TWO*overlap
    spin(2) = M_TWO*aimag(overlap)
    spin(3) = norm_up**2 - norm_down**2

    ! spin is half the sigma matrix.
    spin = M_HALF*spin

    POP_SUB(state_spin)
  end function state_spin

  ! ---------------------------------------------------------
  !> .true. when state index ist lies within the range
  !! st%st_start:st%st_end handled by this process.
  function state_is_local(st, ist) result(is_local)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ist
    logical                    :: is_local

    PUSH_SUB(state_is_local)

    is_local = (ist >= st%st_start) .and. (ist <= st%st_end)

    POP_SUB(state_is_local)
  end function state_is_local


  ! ---------------------------------------------------------

  !> Estimate of the memory taken by the orbitals: REAL_PRECISION times
  !! the global number of mesh points (np_part_global), the spinor
  !! dimension, the number of states and the global number of k-points.
  !! NOTE(review): the factor is the same for real and complex states;
  !! presumably REAL_PRECISION is a size in bytes -- TODO confirm.
  function states_wfns_memory(st, mesh) result(memory)
    type(states_t), intent(in) :: st
    type(mesh_t),   intent(in) :: mesh
    real(8)                    :: memory

    PUSH_SUB(states_wfns_memory)

    ! orbitals
    memory = REAL_PRECISION*dble(mesh%np_part_global)*st%d%dim*dble(st%nst)*st%d%kpt%nglobal

    POP_SUB(states_wfns_memory)
  end function states_wfns_memory

  ! ---------------------------------------------------------

  !> Choose the block sizes used for the ScaLAPACK decomposition of the
  !! states. Without ScaLAPACK support all outputs are set to zero.
  !!
  !! st         states (distribution over processes is read from here).
  !! mesh       mesh (local and global point counts are read from here).
  !! blocksize  (1): block size along the points, (2): along the states.
  !! total_np   blocksize(1) times the number of process rows of the grid.
  subroutine states_blacs_blocksize(st, mesh, blocksize, total_np)
    type(states_t),  intent(in)    :: st
    type(mesh_t),    intent(in)    :: mesh
    integer,         intent(out)   :: blocksize(2)
    integer,         intent(out)   :: total_np

    PUSH_SUB(states_blacs_blocksize)

#ifdef HAVE_SCALAPACK
    ! We need to select the block size of the decomposition. This is
    ! tricky, since not all processors have the same number of
    ! points.
    !
    ! What we do for now is to use the maximum of the number of
    ! points and we set to zero the remaining points.

    if (.not. mesh%parallel_in_domains) then
      blocksize(1) = mesh%np + (st%d%dim - 1)*mesh%np_part
    else
      blocksize(1) = maxval(mesh%vp%np_local) + &
        (st%d%dim - 1)*maxval(mesh%vp%np_local + mesh%vp%np_bndry + mesh%vp%np_ghost)
    end if

    ! all states in one block unless they are distributed
    blocksize(2) = st%nst
    if (st%parallel_in_states) blocksize(2) = maxval(st%st_num)

    total_np = blocksize(1)*st%dom_st_proc_grid%nprow

    ASSERT(st%d%dim*mesh%np_part >= blocksize(1))
#else
    blocksize(1:2) = 0
    total_np = 0
#endif

    POP_SUB(states_blacs_blocksize)
  end subroutine states_blacs_blocksize

  ! ------------------------------------------------------------

  !> Pack the wavefunction batches of st, block by block, until the
  !! memory limit is reached. With OpenCL enabled the limit is derived
  !! from st%d%cl_states_mem: a value above one is taken as megabytes,
  !! a negative value as megabytes to leave free on the device, and
  !! anything else as a fraction of the device memory. Without OpenCL
  !! there is no limit. st%packed is set in all cases.
  !!
  !! st    states to pack; must not be packed already.
  !! copy  optional flag forwarded unchanged to batch_pack.
  subroutine states_pack(st, copy)
    type(states_t),    intent(inout) :: st
    logical, optional, intent(in)    :: copy

    integer :: ik, iblock
    integer(8) :: mem_limit, mem_used
#ifdef HAVE_OPENCL
    FLOAT, parameter :: mem_frac = 0.75
#endif

    PUSH_SUB(states_pack)

    ASSERT(.not. st%packed)

    st%packed = .true.

    if(.not. opencl_is_enabled()) then
      mem_limit = HUGE(mem_limit)
    else
#ifdef HAVE_OPENCL
      call clGetDeviceInfo(opencl%device, CL_DEVICE_GLOBAL_MEM_SIZE, mem_limit, cl_status)
#endif
      if(st%d%cl_states_mem > CNST(1.0)) then
        mem_limit = int(st%d%cl_states_mem, 8)*(1024_8)**2
      else if(st%d%cl_states_mem < CNST(0.0)) then
        mem_limit = mem_limit + int(st%d%cl_states_mem, 8)*(1024_8)**2
      else
        mem_limit = int(st%d%cl_states_mem*real(mem_limit, REAL_PRECISION), 8)
      end if
    end if

    mem_used = 0
    all_blocks: do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%block_start, st%block_end

        mem_used = mem_used + batch_pack_size(st%psib(iblock, ik))

        ! stop packing as soon as the next batch would not fit
        if(mem_used > mem_limit) then
          call messages_write('Not enough CL device memory to store all states simultaneously.', new_line = .true.)
          call messages_write('Only ')
          call messages_write(iblock - st%block_start)
          call messages_write(' of ')
          call messages_write(st%block_end - st%block_start + 1)
          call messages_write(' blocks will be stored in device memory.', new_line = .true.)
          call messages_warning()
          exit all_blocks
        end if

        call batch_pack(st%psib(iblock, ik), copy)
      end do
    end do all_blocks

    POP_SUB(states_pack)
  end subroutine states_pack

  ! ------------------------------------------------------------

  !> Unpack every local wavefunction batch of st that is currently
  !! packed and clear st%packed.
  !!
  !! st    states to unpack; must be flagged as packed.
  !! copy  optional flag forwarded unchanged to batch_unpack.
  subroutine states_unpack(st, copy)
    type(states_t),    intent(inout) :: st
    logical, optional, intent(in)    :: copy

    integer :: ik, iblock

    PUSH_SUB(states_unpack)

    ASSERT(st%packed)

    st%packed = .false.

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%block_start, st%block_end
        ! some batches may have been left unpacked by states_pack
        if(batch_is_packed(st%psib(iblock, ik))) then
          call batch_unpack(st%psib(iblock, ik), copy)
        end if
      end do
    end do

    POP_SUB(states_unpack)
  end subroutine states_unpack

  ! ------------------------------------------------------------

  !> Call batch_sync on every local wavefunction batch of st.
  !! Nothing is done when the states are not packed.
  subroutine states_sync(st)
    type(states_t),    intent(inout) :: st

    integer :: ik, iblock

    PUSH_SUB(states_sync)

    if(.not. states_are_packed(st)) then
      POP_SUB(states_sync)
      return
    end if

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%block_start, st%block_end
        call batch_sync(st%psib(iblock, ik))
      end do
    end do

    POP_SUB(states_sync)
  end subroutine states_sync

  ! -----------------------------------------------------------

  !> Print a short summary of st to stdout: total electronic charge,
  !! number of states and the states block size.
  subroutine states_write_info(st)
    type(states_t), intent(in) :: st

    PUSH_SUB(states_write_info)

    ! opening banner
    call messages_print_stress(stdout, "States")

    write(message(3), '(a,i8)')    'States block-size        = ', st%d%block_size
    write(message(2), '(a,i8)')    'Number of states         = ', st%nst
    write(message(1), '(a,f12.3)') 'Total electronic charge  = ', st%qtot
    call messages_info(3)

    ! closing banner
    call messages_print_stress(stdout)

    POP_SUB(states_write_info)
  end subroutine states_write_info
 
  ! -----------------------------------------------------------

  !> Value of the st%packed flag.
  pure function states_are_packed(st) result(packed)
    type(states_t), intent(in) :: st
    logical                    :: packed

    packed = st%packed

  end function states_are_packed

  ! ------------------------------------------------------------

  !> Call batch_set_zero on every local wavefunction batch of st
  !! (all local k-points and all local blocks).
  subroutine states_set_zero(st)
    type(states_t), intent(inout) :: st

    integer :: ik, iblock

    PUSH_SUB(states_set_zero)

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%block_start, st%block_end
        call batch_set_zero(st%psib(iblock, ik))
      end do
    end do

    POP_SUB(states_set_zero)
  end subroutine states_set_zero

  ! ------------------------------------------------------------

  !> First entry of the range stored for block ib, st%block_range(ib, 1).
  pure function states_block_min(st, ib) result(first_state)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ib
    integer                    :: first_state

    first_state = st%block_range(ib, 1)

  end function states_block_min

  ! ------------------------------------------------------------

  !> Second entry of the range stored for block ib, st%block_range(ib, 2).
  pure function states_block_max(st, ib) result(last_state)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ib
    integer                    :: last_state

    last_state = st%block_range(ib, 2)

  end function states_block_max

  ! ------------------------------------------------------------

  !> Size stored for block ib, st%block_size(ib).
  pure function states_block_size(st, ib) result(nst_in_block)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ib
    integer                    :: nst_in_block

    nst_in_block = st%block_size(ib)

  end function states_block_size

#include "undef.F90"
#include "real.F90"
#include "states_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "states_inc.F90"
#include "undef.F90"

end module states_m


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
