!! Copyright (C) 2011-2012 M. Oliveira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
!! $Id: states_batch.F90 784 2013-08-02 15:21:03Z micael $

#include "global.h"

module states_batch_m
  use global_m
  use oct_parser_m
  use messages_m
  use units_m
  use io_m
  use output_m
  use quantum_numbers_m
  use mesh_m
  use ps_io_m
  use potentials_m
  use wave_equations_m
  use eigensolver_m
  use states_m
  implicit none


                    !---Interfaces---!

  ! Overloads "=" for states_batch_t: an assignment between two batches is
  ! carried out by states_batch_copy.
  interface assignment (=)
    module procedure states_batch_copy
  end interface


                    !---Derived Data Types---!

  ! A batch is a counted list of pointers to states. Its components are
  ! private, so it can only be manipulated through the module procedures.
  type states_batch_t
    private
    integer :: n_states                   ! Number of states in the batch
    type(state_ptr), pointer :: states(:) ! Pointers to the states
  end type states_batch_t

  ! Wrapper holding one pointer to a state, so that an array of pointers
  ! can be declared (states_batch_t%states).
  type state_ptr
    type(state_t), pointer :: ptr ! Pointer to the state
  end type state_ptr


                    !---Global Variables---!

  ! Sorting criteria accepted by states_batch_sort:
  !  SORT_EV - order by eigenvalue
  !  SORT_QN - order by quantum numbers
  integer, parameter :: SORT_EV = 1, &
                        SORT_QN = 2

  ! Occupation schemes accepted by states_batch_smearing:
  !  OCC_FIXED           - occupations are left untouched
  !  OCC_SEMICONDUCTING  - states are filled one after the other
  !  OCC_AVERILL_PAINTER - iterative charge transfer between states
  integer, parameter :: OCC_FIXED           = 0, &
                        OCC_SEMICONDUCTING  = 1, &
                        OCC_AVERILL_PAINTER = 2


                    !---Public/Private Statements---!

  ! Everything is private by default; only the names listed below are
  ! visible to code that uses this module.
  private
  public :: states_batch_t, &
            states_batch_null, &
            states_batch_init, &
            states_batch_end, &
            states_batch_deallocate, &
            assignment(=), &
            states_batch_size, &
            states_batch_add, &
            states_batch_get, &
            states_batch_number_of_folds, &
            states_batch_eigenvalues, &
            states_batch_split_folds, &
            states_batch_density, &
            states_batch_density_grad, &
            states_batch_density_lapl, &
            states_batch_charge_density, &
            states_batch_magnetization_density, &
            states_batch_tau, &
            states_batch_density_moment, &
            states_batch_charge, &
            states_batch_max_charge, &
            states_batch_ip, &
            states_batch_eigensolve, &
            states_batch_sort, &
            states_batch_smearing, &
            states_batch_psp_generation, &
            states_batch_ld_test, &
            states_batch_output_configuration, &
            states_batch_output_eigenvalues, &
            states_batch_output_density, &
            states_batch_ps_io_set, &
            OCC_FIXED, OCC_SEMICONDUCTING, OCC_AVERILL_PAINTER, &
            SORT_EV, SORT_QN


contains

  subroutine states_batch_null(batch)
    !-----------------------------------------------------------------------!
    ! Puts a batch into its empty state: the array of state pointers is
    ! nullified and the number of states is set to zero.
    !
    !  batch - the batch to reset; nothing it pointed to before is freed
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(out) :: batch

    call push_sub("states_batch_null")

    nullify(batch%states)
    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_null

  subroutine states_batch_init(batch, n_states, states)
    !-----------------------------------------------------------------------!
    ! Initializes a batch from an array of states. The batch only stores
    ! pointers: the states themselves are not copied, so the array must
    ! outlive the batch.
    !
    !  batch    - the batch to be initialized (any pointer array it held
    !             before is not freed here)
    !  n_states - the number of states to take from the array; must lie
    !             between 0 and size(states)
    !  states   - the array of states
    !-----------------------------------------------------------------------!
    type(states_batch_t),  intent(inout) :: batch
    integer,               intent(in)    :: n_states
    type(state_t), target, intent(in)    :: states(:)

    integer :: i

    call push_sub("states_batch_init")

    !Pointing past the end of the array would silently create dangling
    !pointers, so refuse inconsistent sizes
    ASSERT(n_states >= 0 .and. n_states <= size(states))

    batch%n_states = n_states
    allocate(batch%states(n_states))
    do i = 1, n_states
      batch%states(i)%ptr => states(i)
    end do

    call pop_sub()
  end subroutine states_batch_init

  subroutine states_batch_end(batch)
    !-----------------------------------------------------------------------!
    ! Releases the array of pointers held by a batch and marks the batch as
    ! empty. The states that were pointed to are left untouched.
    !
    !  batch - the batch to be emptied
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch

    call push_sub("states_batch_end")

    if (associated(batch%states)) then
      deallocate(batch%states)
    end if
    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_end

  subroutine states_batch_deallocate(batch)
    !-----------------------------------------------------------------------!
    ! Calls state_end on every state the batch points to, then releases the
    ! array of pointers and marks the batch as empty.
    !
    !  batch - the batch whose states and pointer array are to be released
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch

    integer :: ist

    call push_sub("states_batch_deallocate")

    if (associated(batch%states)) then
      !End the states first, while the pointers to them are still available
      ist = 0
      do while (ist < batch%n_states)
        ist = ist + 1
        call state_end(batch%states(ist)%ptr)
      end do
      deallocate(batch%states)
    end if
    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_deallocate

  subroutine states_batch_copy(batch_a, batch_b)
    !-----------------------------------------------------------------------!
    ! Copy batch_b to batch_a. Only the pointers are copied: afterwards
    ! both batches refer to the same states.
    !
    !  batch_a - destination; its previous pointer array is released
    !  batch_b - source
    !-----------------------------------------------------------------------!
    type(states_batch_t),  intent(inout) :: batch_a
    type(states_batch_t),  intent(in)    :: batch_b

    integer :: i

    call push_sub("states_batch_copy")

    !Self-assignment: both batches share the same pointer array, so freeing
    !batch_a first would leave nothing valid to copy from. There is nothing
    !to do in that case.
    if (associated(batch_a%states, batch_b%states)) then
      call pop_sub()
      return
    end if

    call states_batch_end(batch_a)

    batch_a%n_states = batch_b%n_states
    allocate(batch_a%states(batch_a%n_states))
    do i = 1, batch_a%n_states
      batch_a%states(i)%ptr => batch_b%states(i)%ptr
    end do

    call pop_sub()
  end subroutine states_batch_copy

  function states_batch_size(batch) result(n_states)
    !-----------------------------------------------------------------------!
    ! Returns how many states the batch currently holds.
    !
    !  batch - the batch
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer :: n_states

    call push_sub("states_batch_size")

    n_states = batch%n_states

    call pop_sub()
  end function states_batch_size

  subroutine states_batch_add(batch, state)
    !-----------------------------------------------------------------------!
    ! Add a state to a batch. The state is appended after the states that
    ! are already in the batch; only a pointer to it is stored.
    !
    !  batch - the batch
    !  state - the state to be added to the batch
    !-----------------------------------------------------------------------!
    type(states_batch_t),  intent(inout) :: batch
    type(state_t), target, intent(in)    :: state

    integer :: n_old
    type(state_ptr), pointer :: grown(:)

    call push_sub("states_batch_add")

    n_old = batch%n_states

    !Build the enlarged array first, copying the existing pointers once
    allocate(grown(n_old + 1))
    if (n_old > 0) grown(1:n_old) = batch%states(1:n_old)
    grown(n_old + 1)%ptr => state

    !Release the previous array. This also covers an empty batch that still
    !owns a zero-length array (e.g. one initialized or copied with zero
    !states), which would otherwise be leaked.
    if (associated(batch%states)) deallocate(batch%states)

    batch%states => grown
    batch%n_states = n_old + 1

    call pop_sub()
  end subroutine states_batch_add

  function states_batch_get(batch, i) result(state)
    !-----------------------------------------------------------------------!
    ! Gives access to one state of the batch.
    !
    !  batch - the batch
    !  i     - position of the requested state, from 1 to the batch size
    !
    ! The result is a pointer to the state stored at position i.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: i
    type(state_t), pointer :: state

    call push_sub("states_batch_get")

    ASSERT(i > 0 .and. i <= batch%n_states)
    state => batch%states(i)%ptr

    call pop_sub()
  end function states_batch_get

  function states_batch_eigenvalues(batch) result(ev)
    !-----------------------------------------------------------------------!
    ! Collects the eigenvalues of all the states of the batch.
    !
    !  batch - the batch
    !
    ! The result has one entry per state, in the order in which the states
    ! are stored in the batch.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    real(R8) :: ev(batch%n_states)

    integer :: ist

    call push_sub("states_batch_eigenvalues")

    do ist = 1, batch%n_states
      ev(ist) = state_eigenvalue(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_eigenvalues

  function states_batch_number_of_folds(batch, polarized)
    !-----------------------------------------------------------------------!
    ! Given a batch, return the number of quantum number folds that exist
    ! in the batch (see quantum_numbers_m for the definition of a fold).
    !
    !  batch     - the batch
    !  polarized - passed on to qn_equal_fold when comparing quantum numbers
    !
    ! An empty batch has zero folds.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    logical,              intent(in) :: polarized  
    integer :: states_batch_number_of_folds

    integer :: i
    type(qn_t) :: qn
    logical, allocatable :: checked(:)

    call push_sub("states_batch_number_of_folds")

    states_batch_number_of_folds = 0

    !The empty batch is handled by skipping the search rather than by an
    !early return, so that pop_sub is always reached
    if (batch%n_states > 0) then

      allocate(checked(batch%n_states))
      checked = .false.
      do
        if (all(checked)) exit

        !The first state not yet assigned to a fold defines a new fold
        do i = 1, batch%n_states
          if (.not. checked(i)) then
            states_batch_number_of_folds = states_batch_number_of_folds + 1
            qn = state_qn(batch%states(i)%ptr)
            exit
          end if
        end do

        !Mark every remaining state that belongs to that fold
        do i = 1, batch%n_states
          if (.not. checked(i) .and. qn_equal_fold(qn, state_qn(batch%states(i)%ptr), polarized)) then
            checked(i) = .true.
          end if
        end do
      end do

      ASSERT(all(checked))
      deallocate(checked)

    end if

    call pop_sub()
  end function states_batch_number_of_folds

  function states_batch_density(batch, nspin, m) result(rho)
    !-----------------------------------------------------------------------!
    ! Sums the densities (state_density) of all the states of the batch.
    !
    !  batch - the batch
    !  nspin - number of spin channels (second dimension of the result)
    !  m     - mesh; m%np gives the first dimension of the result
    !
    ! An empty batch yields a density that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    type(mesh_t),         intent(in) :: m
    real(R8) :: rho(m%np, nspin)

    integer :: ist

    call push_sub("states_batch_density")

    rho = M_ZERO
    do ist = 1, batch%n_states
      rho = rho + state_density(nspin, batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_density

  function states_batch_density_grad(batch, nspin, m) result(grad)
    !-----------------------------------------------------------------------!
    ! Sums the density gradients (state_density_grad) of all the states of
    ! the batch.
    !
    !  batch - the batch
    !  nspin - number of spin channels (second dimension of the result)
    !  m     - mesh; m%np gives the first dimension of the result
    !
    ! An empty batch yields a result that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    type(mesh_t),         intent(in) :: m
    real(R8) :: grad(m%np, nspin)

    integer :: ist

    call push_sub("states_batch_density_grad")

    grad = M_ZERO
    do ist = 1, batch%n_states
      grad = grad + state_density_grad(nspin, batch%states(ist)%ptr, m)
    end do

    call pop_sub()
  end function states_batch_density_grad

  function states_batch_density_lapl(batch, nspin, m) result(lapl)
    !-----------------------------------------------------------------------!
    ! Sums the density laplacians (state_density_lapl) of all the states of
    ! the batch.
    !
    !  batch - the batch
    !  nspin - number of spin channels (second dimension of the result)
    !  m     - mesh; m%np gives the first dimension of the result
    !
    ! An empty batch yields a result that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    type(mesh_t),         intent(in) :: m
    real(R8) :: lapl(m%np, nspin)

    integer :: ist

    call push_sub("states_batch_density_lapl")

    lapl = M_ZERO
    do ist = 1, batch%n_states
      lapl = lapl + state_density_lapl(nspin, batch%states(ist)%ptr, m)
    end do

    call pop_sub()
  end function states_batch_density_lapl

  function states_batch_charge_density(batch, m) result(rho)
    !-----------------------------------------------------------------------!
    ! Sums the charge densities (state_charge_density) of all the states of
    ! the batch.
    !
    !  batch - the batch
    !  m     - mesh; m%np gives the size of the result
    !
    ! An empty batch yields a result that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    type(mesh_t),         intent(in) :: m
    real(R8) :: rho(m%np)

    integer :: ist

    call push_sub("states_batch_charge_density")

    rho = M_ZERO
    do ist = 1, batch%n_states
      rho = rho + state_charge_density(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_charge_density

  function states_batch_magnetization_density(batch, m) result(mag)
    !-----------------------------------------------------------------------!
    ! Sums the magnetization densities (state_magnetization_density) of all
    ! the states of the batch.
    !
    !  batch - the batch
    !  m     - mesh; m%np gives the size of the result
    !
    ! An empty batch yields a result that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    type(mesh_t),         intent(in) :: m
    real(R8) :: mag(m%np)

    integer :: ist

    call push_sub("states_batch_magnetization_density")

    mag = M_ZERO
    do ist = 1, batch%n_states
      mag = mag + state_magnetization_density(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_magnetization_density

  function states_batch_tau(batch, nspin, m) result(tau)
    !-----------------------------------------------------------------------!
    ! Sums the kinetic energy densities (state_tau) of all the states of
    ! the batch.
    !
    !  batch - the batch
    !  nspin - number of spin channels (second dimension of the result)
    !  m     - mesh; m%np gives the first dimension of the result
    !
    ! An empty batch yields a result that is zero everywhere.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    type(mesh_t),         intent(in) :: m
    real(R8) :: tau(m%np, nspin)

    integer :: ist

    call push_sub("states_batch_tau")

    tau = M_ZERO
    do ist = 1, batch%n_states
      tau = tau + state_tau(nspin, batch%states(ist)%ptr, m)
    end do

    call pop_sub()
  end function states_batch_tau

  function states_batch_density_moment(batch, m, order)
    !-----------------------------------------------------------------------!
    ! Computes the density moment sum_i < R_i(r) | r**order | R_i(r)> of a
    ! given batch of states, by adding up state_density_moment over all the
    ! states.
    !
    !  batch - the batch
    !  m     - mesh, passed on to state_density_moment
    !  order - power of r in the moment
    !
    ! An empty batch has a moment of zero.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    type(mesh_t),         intent(in) :: m
    integer,              intent(in) :: order
    real(R8) :: states_batch_density_moment

    integer :: i

    call push_sub("states_batch_density_moment")

    states_batch_density_moment = M_ZERO
    do i = 1, batch%n_states
      states_batch_density_moment = states_batch_density_moment + &
                             state_density_moment(batch%states(i)%ptr, m, order)
    end do

    !Balance the push_sub above
    call pop_sub()
  end function states_batch_density_moment

  subroutine states_batch_split_folds(batch, folds, polarized)
    !-----------------------------------------------------------------------!
    ! Given a batch, splits the states into folds (see quantum_numbers_m
    ! for the definition of a fold).
    !
    !  batch     - the batch to split; must not be empty
    !  folds     - on exit, folds(n) holds the states of the n-th fold. Each
    !              fold that gets used is emptied first. The array must have
    !              at least as many elements as there are folds (see
    !              states_batch_number_of_folds).
    !  polarized - passed on to qn_equal_fold when comparing quantum numbers
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in)    :: batch
    type(states_batch_t), intent(inout) :: folds(:)
    logical,              intent(in)    :: polarized

    integer :: i, n
    type(qn_t) :: qn
    logical, allocatable :: checked(:)

    call push_sub("states_batch_split_folds")

    ASSERT(batch%n_states > 0)

    n = 0
    allocate(checked(batch%n_states))
    checked = .false.
    do
      if (all(checked)) exit
      n = n + 1
      !Do not write past the end of the array supplied by the caller
      ASSERT(n <= size(folds))
      call states_batch_end(folds(n))

      !The first state not yet assigned to a fold defines the new fold
      do i = 1, batch%n_states
        if (.not. checked(i)) then
          qn = state_qn(batch%states(i)%ptr)
          exit
        end if
      end do

      !Move every remaining state that belongs to that fold into folds(n)
      do i = 1, batch%n_states
        if (.not. checked(i) .and. qn_equal_fold(qn, state_qn(batch%states(i)%ptr), polarized)) then
          checked(i) = .true.
          call states_batch_add(folds(n), batch%states(i)%ptr)
        end if
      end do
    end do

    ASSERT(all(checked))
    deallocate(checked)

    call pop_sub()
  end subroutine states_batch_split_folds

  function states_batch_charge(batch) result(charge)
    !-----------------------------------------------------------------------!
    ! Returns the number of electrons in a batch, i.e. the sum of
    ! state_charge over all its states.
    !
    !  batch - the batch
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    real(R8) :: charge

    integer :: ist

    call push_sub("states_batch_charge")

    charge = M_ZERO
    do ist = 1, batch%n_states
      charge = charge + state_charge(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_charge

  function states_batch_max_charge(batch) result(max_charge)
    !-----------------------------------------------------------------------!
    ! Returns the maximum number of electrons the batch can hold, i.e. the
    ! sum over its states of qn_max_occ applied to the quantum numbers of
    ! each state.
    !
    !  batch - the batch
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    real(R8) :: max_charge

    integer :: ist
    type(qn_t) :: qn

    call push_sub("states_batch_max_charge")

    max_charge = M_ZERO
    do ist = 1, batch%n_states
      qn = state_qn(batch%states(ist)%ptr)
      max_charge = max_charge + qn_max_occ(qn)
    end do

    call pop_sub()

  end function states_batch_max_charge

  function states_batch_ip(batch, nspin) result(ip)
    !-----------------------------------------------------------------------!
    ! Returns the ionization potential of a batch for each spin channel,
    ! taken as minus the highest eigenvalue found among the states that
    ! carry charge.
    !
    !  batch - the batch
    !  nspin - number of spin channels (size of the result)
    !
    ! States with spin quantum number s = 1/2 are assigned to channel 2 and
    ! all the others to channel 1. A channel without any charged state is
    ! returned as zero.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    real(R8) :: ip(nspin)

    integer :: ist, ispin
    real(R8) :: eigenvalue, occupation
    type(qn_t) :: qn

    call push_sub("states_batch_ip")

    ip = M_ZERO
    do ist = 1, batch%n_states
      qn = state_qn(batch%states(ist)%ptr)
      eigenvalue = state_eigenvalue(batch%states(ist)%ptr)
      occupation = state_charge(batch%states(ist)%ptr)

      if (qn%s == M_HALF) then
        ispin = 2
      else
        ispin = 1
      end if

      !Only states that carry charge can define the ionization potential
      if (occupation /= M_ZERO) then
        !A value of zero means that the channel has not been set yet
        if (ip(ispin) == M_ZERO .or. -ip(ispin) < eigenvalue) then
          ip(ispin) = -eigenvalue
        end if
      end if
    end do

    call pop_sub()
  end function states_batch_ip

  subroutine states_batch_eigensolve(batch, m, wave_eq, potential, integrator_dp, integrator_sp, eigensolver)
    !-----------------------------------------------------------------------!
    ! Finds the eigenvalues of a batch of states for a given potential and
    ! updates every state accordingly (state_update).
    !
    !  batch         - batch of states; it is sorted by quantum numbers as a
    !                  side effect
    !  m             - mesh
    !  wave_eq       - wave-equation to solve
    !  potential     - potential to use in the wave-equation
    !  integrator_dp - integrator used to refine the eigenvalues and to
    !                  update the states
    !  integrator_sp - integrator used to count the bound states and to
    !                  bracket the eigenvalues
    !  eigensolver   - information about the eigensolver
    !
    ! States that turn out to be unbound get a zero eigenvalue when they are
    ! empty; an unbound state that carries charge is a fatal error.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch
    type(mesh_t),         intent(in)    :: m
    integer,              intent(in)    :: wave_eq
    type(potential_t),    intent(in)    :: potential
    type(integrator_t),   intent(inout) :: integrator_dp
    type(integrator_t),   intent(inout) :: integrator_sp
    type(eigensolver_t),  intent(in)    :: eigensolver

    integer :: ist, n_bound_max
    real(R8) :: eigenvalue
    type(qn_t) :: qn
    type(states_batch_t) :: bound
    logical, allocatable :: bracketed(:)
    real(R8), allocatable :: brackets(:,:)
    type(qn_t), allocatable :: qns(:)

    call push_sub("states_batch_eigensolve")

    !The states have to be ordered by increasing quantum numbers
    call states_batch_sort(batch, SORT_QN)

    !Upper limit for the number of bound states, obtained from the quantum
    !numbers of the first state
    qn = state_qn(batch%states(1)%ptr)
    n_bound_max = wavefunctions_n_bound_states(qn, wave_eq, m, potential, integrator_sp)

    !Keep only the states that may be bound; the others are dealt with now
    call states_batch_null(bound)
    do ist = 1, batch%n_states
      qn = state_qn(batch%states(ist)%ptr)

      if (qn%n - qn%l <= n_bound_max) then
        call states_batch_add(bound, batch%states(ist)%ptr)
      else if (state_charge(batch%states(ist)%ptr) == M_ZERO) then
        !Empty unbound state
        call state_update(batch%states(ist)%ptr, m, wave_eq, potential, integrator_dp, ev=M_ZERO)
      else
        !An occupied state cannot be unbound
        write(message(1),'("State: ",A," is unbound")') &
             trim(state_label(batch%states(ist)%ptr, .true.))
        call write_fatal(1)
      end if
    end do

    !Bracket the eigenvalues of the candidate bound states
    allocate(brackets(2, bound%n_states))
    allocate(qns(bound%n_states))
    allocate(bracketed(bound%n_states))
    do ist = 1, bound%n_states
      qns(ist) = state_qn(bound%states(ist)%ptr)
    end do
    call eigensolver_bracket(bound%n_states, qns, wave_eq, eigensolver, &
                             potential, integrator_sp, brackets, bracketed)

    !Refine each eigenvalue and store it in the corresponding state
    do ist = 1, bound%n_states
      if (bracketed(ist)) then
        call eigensolver_find_ev(qns(ist), wave_eq, eigensolver, potential, &
                                 integrator_dp, brackets(:,ist), eigenvalue)
      else if (state_charge(bound%states(ist)%ptr) == M_ZERO) then
        !No bracket and no charge: treat the state as unbound after all
        eigenvalue = M_ZERO
      else
        !No bracket for an occupied state
        message(1) = "Unable to bracket eigenvalues for state:"
        write(message(2),'(2X,A)') trim(qn_label(qns(ist), .true.))
        call write_fatal(2)
      end if

      call state_update(bound%states(ist)%ptr, m, wave_eq, potential, integrator_dp, eigenvalue)
    end do

    !Clean up
    deallocate(brackets, bracketed, qns)
    call states_batch_end(bound)

    call pop_sub()
  end subroutine states_batch_eigensolve

  subroutine states_batch_sort(batch, criteria)
    !-----------------------------------------------------------------------!
    ! Reorders the states of a batch in place.
    !
    !  batch    - batch of states
    !  criteria - how to order the states:
    !               SORT_EV - by decreasing absolute value of the eigenvalue
    !               SORT_QN - by increasing n, l, j, m, s and decreasing sg
    !
    ! Only the pointers stored in the batch are permuted; the states
    ! themselves are not modified.
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch
    integer,              intent(in)    :: criteria

    integer :: ist, n, loc(1)
    logical, allocatable :: pending(:), candidates(:)
    integer, allocatable :: order(:)
    real(R8), allocatable :: ev(:)
    type(qn_t), allocatable :: qn(:)
    type(state_ptr), allocatable :: sorted(:)

    call push_sub("states_batch_sort")

    n = batch%n_states
    allocate(pending(n), order(n), sorted(n))

    !Selection sort: order(i) is the current position of the state that
    !goes to position i; pending flags the states not yet selected
    pending = .true.
    select case (criteria)
    case (SORT_EV)
      allocate(ev(n))
      do ist = 1, n
        ev(ist) = state_eigenvalue(batch%states(ist)%ptr)
      end do

      do ist = 1, n
        loc = maxloc(abs(ev), mask=pending)
        order(ist) = loc(1)
        pending(loc(1)) = .false.
      end do
      deallocate(ev)

    case (SORT_QN)
      !Not an efficient algorithm, but the number of states is never
      !expected to be large
      allocate(qn(n))
      do ist = 1, n
        qn(ist) = state_qn(batch%states(ist)%ptr)
      end do

      allocate(candidates(n))
      do ist = 1, n
        !Narrow down the candidates one quantum number at a time
        candidates = qn%n == minval(qn%n, pending)
        candidates = candidates .and. qn%l == minval(qn%l, pending .and. candidates)
        candidates = candidates .and. qn%j == minval(qn%j, pending .and. candidates)
        candidates = candidates .and. qn%m == minval(qn%m, pending .and. candidates)
        candidates = candidates .and. qn%s == minval(qn%s, pending .and. candidates)
        loc = maxloc(qn%sg, mask = pending .and. candidates)
        order(ist) = loc(1)
        pending(loc(1)) = .false.
      end do
      deallocate(qn, candidates)

    end select

    !Apply the permutation through a temporary copy
    do ist = 1, n
      sorted(ist) = batch%states(order(ist))
    end do
    do ist = 1, n
      batch%states(ist) = sorted(ist)
    end do

    deallocate(pending, order, sorted)

    call pop_sub()
  end subroutine states_batch_sort

  subroutine states_batch_smearing(batch, smearing_function, tol, new_charge)
    !-----------------------------------------------------------------------!
    ! Distributes the electrons among the states.
    !
    !  batch             - batch of states; it is sorted (SORT_EV) as a side
    !                      effect and the occupations of its states are
    !                      overwritten
    !  smearing_function - how should the electrons be distributed:
    !                        OCC_FIXED           - nothing is changed
    !                        OCC_SEMICONDUCTING  - the states are filled one
    !                                              after the other, each up
    !                                              to its maximum occupation
    !                        OCC_AVERILL_PAINTER - charge is transferred
    !                                              iteratively between the
    !                                              partially filled states
    !  tol               - eigenvalues closer than tol are considered to be
    !                      degenerate (used by OCC_AVERILL_PAINTER only)
    !  new_charge        - total charge to distribute; defaults to the
    !                      current charge of the batch (used by
    !                      OCC_SEMICONDUCTING only)
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch
    integer,              intent(in)    :: smearing_function
    real(R8),             intent(in)    :: tol
    real(R8), optional,   intent(in)    :: new_charge

    integer :: n_states, is, i, qi, qf, k
    real(R8) :: total_charge, lambda, lambda_in, sum_evdiff
    integer,  allocatable :: order(:)
    real(R8), allocatable :: ev(:), occ(:), new_occ(:), max_occ(:), gamma(:)

    call push_sub("states_batch_smearing")

    if (smearing_function == OCC_FIXED) then
      ! Nothing to do
      call pop_sub()
      return
    end if

    n_states = batch%n_states

    !Sort states by eigenvalue
    call states_batch_sort(batch, SORT_EV)

    !Allocate working arrays
    allocate(ev(batch%n_states), order(n_states))
    allocate(occ(n_states), max_occ(n_states), new_occ(n_states))

    !eigenvalues, occupancies, and maximum occupancies for all the states in a single vector
    do i = 1, n_states
      ev(i) = state_eigenvalue(batch%states(i)%ptr)
      occ(i) = state_charge(batch%states(i)%ptr)
      max_occ(i) = qn_max_occ(state_qn(batch%states(i)%ptr))
      order(i) = i
    end do

    !Among states with exactly the same eigenvalue, order by decreasing
    !occupancy (the scan restarts from the beginning after each swap)
    if (n_states /= 1) then
      i = 1
      do
        i = i + 1
        if (ev(i-1) == ev(i) .and. occ(order(i-1)) < occ(order(i))) then
          is = order(i)
          order(i) = order(i-1)
          order(i-1) = is
          i = 1
        end if
        if (i == n_states) exit
      end do
    end if

    !Apply the new order to the occupancies and maximum occupancies
    do i = 1, n_states
      new_occ(i) = occ(order(i))
    end do
    occ = new_occ
    do i = 1, n_states
      new_occ(i) = max_occ(order(i))
    end do
    max_occ = new_occ

    !
    select case (smearing_function)
    case (OCC_SEMICONDUCTING)

      if (present(new_charge)) then
        total_charge = new_charge
      else
        total_charge = states_batch_charge(batch)
      end if

      !Fill the states in order until the charge is exhausted
      do i = 1, n_states
        if (total_charge == M_ZERO) then
          occ(i) = M_ZERO
        elseif (total_charge - max_occ(i) < M_ZERO) then
          occ(i) = total_charge
        else
          occ(i) = max_occ(i)
        end if
        total_charge = total_charge - occ(i)
      end do

    case (OCC_AVERILL_PAINTER)

      !Allocate memory and initialize some quantities
      allocate(gamma(n_states))
      lambda_in = M_HALF
      gamma = M_ONE
      total_charge = states_batch_charge(batch)

      !Remove states that will not participate in the charge transfer:
      !qi is the first state that is not full and qf the last state that is
      !not empty. qi starts as n_states + 1, meaning "no such state".
      qi = n_states + 1
      do i = 1, n_states
        if (occ(i) /= max_occ(i)) then
          qi = i
          exit
        end if
      end do
      qf = qi
      do i = n_states, qi, -1
        if (occ(i) /= M_ZERO) then
          qf = i
          exit
        end if
      end do
      !If every state is full there is nothing to transfer: make the window
      !qi:qf empty so that the main loop exits immediately
      if (qi > n_states) qf = n_states

      !Main loop
      do
        !Exit if we have less than two states to perform the charge transfer
        if (count(gamma(qi:qf) /= M_ZERO) < 2) exit

        !By default we dont change the occupancies
        new_occ = occ

        !Choose the reference state
        if (occ(qi)/max_occ(qi) < M_HALF) then
          k = qi
        else
          k = qf
        end if

        !Compute gamma
        sum_evdiff = M_ZERO
        do i = qi, qf
          if (i == k) cycle
          if (abs(ev(i) - ev(k)) < tol) then
            gamma(i) = M_ZERO
          else if (ev(i) - ev(k) > M_ZERO) then
            gamma(i) = occ(i)/max_occ(i)/(ev(i) - ev(k))
          elseif (ev(i) - ev(k) < M_ZERO) then
            gamma(i) = (occ(i)/max_occ(i) - M_ONE)/(ev(i) - ev(k))
          end if
          if (gamma(i) /= M_ZERO) sum_evdiff = sum_evdiff + (ev(i) - ev(k))*max_occ(i)
        end do

        if (sum_evdiff > M_ZERO) then
          gamma(k) = (M_ONE - occ(k)/max_occ(k))/sum_evdiff
        elseif (sum_evdiff < M_ZERO) then
          gamma(k) = -occ(k)/max_occ(k)/sum_evdiff
        else
          gamma(k) = M_ZERO
        end if

        !If gamma(k) is zero remove this state from the list and go back to the
        !beginning of the cycle
        if (abs(gamma(k)) == M_ZERO) then
          if (k == qi) then
            qi = qi + 1
          else
            qf = qf - 1
          end if
          cycle
        end if

        !Get lambda (ignore the states that have gamma = 0)
        lambda = min(lambda_in, minval(gamma(qi:qf), mask=gamma(qi:qf) /= M_ZERO))

        !Get the new occupancies (states with gamma = 0 are ignored)
        sum_evdiff = M_ZERO
        do i = qi, qf
          if (i /= k .and. gamma(i) /= M_ZERO) then

            new_occ(i) = occ(i) - lambda*(ev(i) - ev(k))*max_occ(i)

            if (new_occ(i) < M_EPSILON) then
              new_occ(i) = M_ZERO
              sum_evdiff = sum_evdiff + occ(i)/lambda

            else if ( abs(max_occ(i) - new_occ(i)) < M_EPSILON ) then
              new_occ(i) = max_occ(i)
              sum_evdiff = sum_evdiff + (occ(i) - max_occ(i))/lambda

            else
              sum_evdiff = sum_evdiff + (ev(i) - ev(k))*max_occ(i)

            end if

          end if
        end do

        !The reference state absorbs the charge taken from the other states
        new_occ(k) = occ(k) + lambda*sum_evdiff
        if (new_occ(k) < M_EPSILON) then
          new_occ(k) = M_ZERO
        else if (max_occ(k) - new_occ(k) < M_EPSILON) then
          new_occ(k) = max_occ(k)
        end if

        occ = new_occ

        !Check if we have reached convergence
        if (lambda == lambda_in) exit

        lambda_in = lambda
      end do

      deallocate(gamma)
    end select

    !Get the final occupancies in the correct order
    do i = 1, n_states 
      batch%states(order(i))%ptr%occ = occ(i)
    end do

    !Deallocate memory
    deallocate(ev, order, occ, max_occ, new_occ)

    call pop_sub()
  end subroutine states_batch_smearing

  subroutine states_batch_psp_generation(batch, m, scheme, wave_eq, tol, ae_potential, &
       integrator_sp, integrator_dp, eigensolver, rc, ps_v)
    !-----------------------------------------------------------------------!
    ! Generates the pseudopotential from the leading state(s) of the batch  !
    ! and then solves the wave-equation with that pseudopotential for the   !
    ! remaining states of the batch.                                        !
    !                                                                       !
    !  batch         - batch of states; state 1 (and state 2 for the MRPP   !
    !                  and RMRPP schemes) is used for the generation        !
    !  m             - mesh                                                 !
    !  scheme        - pseudopotential generation scheme                    !
    !  wave_eq       - wave-equation to use                                 !
    !  tol           - tolerance passed on to state_psp_generation          !
    !  ae_potential  - all-electron potential                               !
    !  integrator_sp - integrators passed on to the generation and          !
    !  integrator_dp   eigensolve routines                                  !
    !  eigensolver   - eigensolver object                                   !
    !  rc            - radius passed on to state_psp_generation (presumably !
    !                  the cut-off radius)                                  !
    !  ps_v          - on exit, the pseudopotential on the mesh points      !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(inout) :: batch
    type(mesh_t),         intent(in)    :: m
    integer,              intent(in)    :: scheme, wave_eq
    real(R8),             intent(in)    :: tol
    type(potential_t),    intent(in)    :: ae_potential
    type(integrator_t),   intent(inout) :: integrator_sp, integrator_dp
    type(eigensolver_t),  intent(in)    :: eigensolver
    real(R8),             intent(in)    :: rc
    real(R8),             intent(out)   :: ps_v(m%np)

    integer :: n_used, n_total, is
    logical :: two_states
    type(state_t), pointer :: first, second
    type(potential_t) :: ps_potential
    type(states_batch_t) :: remaining

    call push_sub("states_batch_psp_generation")

    !The MRPP and RMRPP schemes take a second state as an extra argument
    two_states = (scheme == MRPP .or. scheme == RMRPP)

    first => states_batch_get(batch, 1)
    if (.not. two_states) then
      n_used = 1
      call state_psp_generation(m, scheme, wave_eq, tol, ae_potential, &
           integrator_sp, integrator_dp, ps_v, first, rc)
    else
      n_used = 2
      second => states_batch_get(batch, 2)
      call state_psp_generation(m, scheme, wave_eq, tol, ae_potential, &
           integrator_sp, integrator_dp, ps_v, first, rc, second)
    end if

    !States that were not used in the generation still have to be solved
    !with the pseudopotential that was just obtained
    n_total = states_batch_size(batch)
    if (n_total > n_used) then
      call states_batch_null(remaining)
      call potential_null(ps_potential)
      call potential_init(ps_potential, m, ps_v)

      !Gather the states that come after the ones used in the generation
      do is = n_used + 1, n_total
        call states_batch_add(remaining, states_batch_get(batch, is))
      end do

      call states_batch_eigensolve(remaining, m, wave_eq, ps_potential, &
           integrator_dp, integrator_sp, eigensolver)

      !Release the temporary objects
      call potential_end(ps_potential)
      call states_batch_end(remaining)
    end if

    call pop_sub()
  end subroutine states_batch_psp_generation

  subroutine states_batch_ld_test(batch, m, ae_potential, ps_potential, integrator, r, de, emin, emax)
    !-----------------------------------------------------------------------!
    ! Computes the logarithmic derivatives for a given fold of states.      !
    ! Both the all-electron and the pseudopotential values are written to   !
    ! the file tests/ld-<label>, one line per energy.                       !
    !                                                                       !
    !  batch        - batch of states (they should belong to the same fold) !
    !  m            - mesh                                                  !
    !  ae_potential - all-electron potential                                !
    !  ps_potential - pseudopotential                                       !
    !  integrator   - integrator object                                     !
    !  r            - diagnostic radius at which to do the test             !
    !  de           - energy step (must not be negative); if zero, the      !
    !                 step is chosen adaptively from dldde, as returned by  !
    !                 state_ld for the pseudopotential                      !
    !  emin         - lower bound of the energy interval                    !
    !                 (default: lowest eigenvalue of the batch - M_ONE)     !
    !  emax         - upper bound of the energy interval                    !
    !                 (default: highest eigenvalue of the batch + M_ONE)    !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    type(mesh_t),         intent(in) :: m
    type(potential_t),    intent(in) :: ae_potential
    type(potential_t),    intent(in) :: ps_potential
    type(integrator_t),   intent(inout) :: integrator
    real(R8),             intent(in) :: r
    real(R8),             intent(in) :: de
    real(R8), optional,   intent(in) :: emin, emax

    integer :: unit
    real(R8) :: e, e1, e2, ae_ld, ps_ld, step, dldde
    character(len=10) :: label

    call push_sub("states_batch_ld_test")

    ASSERT(batch%n_states > 0)

    !Set energy interval
    if (present(emin)) then
      e1 = emin
    else
      e1 = minval(states_batch_eigenvalues(batch)) - M_ONE
    end if
    if (present(emax)) then
      e2 = emax
    else
      e2 = maxval(states_batch_eigenvalues(batch)) + M_ONE
    end if
    if (e1 > e2) then
      message(1) = "The minimum energy at which to calculate the logarithmic"
      message(2) = "derivatives must be smaller than the maximum energy."
      call write_fatal(2)
    end if

    !The energy loop below only stops once e reaches e2 from below. A
    !negative step would move e away from e2 and the loop would never end.
    if (de < M_ZERO) then
      message(1) = "The energy step used to calculate the logarithmic"
      message(2) = "derivatives must not be negative."
      call write_fatal(2)
    end if

    !Open file (the first character of the label is skipped in the file name)
    label = state_label(batch%states(1)%ptr)
    call io_open(unit, 'tests/ld-'//trim(label(2:5)))

    !Write some information
    write(message(1), '(2X,"Computing logarithmic derivative for states:",1X,A)') &
         trim(label(2:5))
    write(message(2), '(4X,"Minimum energy: ",F6.3,1X,A)') &
         e1/units_out%energy%factor, trim(units_out%energy%abbrev)
    write(message(3), '(4X,"Maximum energy: ",F6.3,1X,A)') &
         e2/units_out%energy%factor, trim(units_out%energy%abbrev)
    call write_info(3)
    call write_info(3,unit=info_unit("tests"))

    !Write header
    write(unit,'("# Logarithmic derivatives")')
    write(unit,'("# ")')
    write(unit,'("# Energy units: ",A)') trim(units_out%energy%name)
    write(unit,'("# Length units: ",A)') trim(units_out%length%name)
    write(unit,'("#")')
    write(unit,'("# ",53("-"))')
    write(unit,'("# |",7X,"e",7X,"|",5X,"ld_ae(e)",4X,"|",5X,"ld_pp(e)",4X,"|")')
    write(unit,'("# ",53("-"))')

    !Compute the logarithmic derivatives
    e = e1
    do
      ae_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ae_potential, m)
      if (de == M_ZERO) then
        !Adaptive step: also request dldde, which sets the next step size
        ps_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ps_potential, m, dldde)
      else
        ps_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ps_potential, m)
      end if

      write(unit,'(3X,ES14.7E2,3X,ES15.8E2,3X,ES15.8E2)') e/units_out%energy%factor, ae_ld, ps_ld

      !The exact comparison is intentional: e is clamped to e2 by the min()
      !below, so the last point is always computed exactly at e2
      if (e == e2) exit
      if (de /= M_ZERO) then
        step = de
      else
        !Smaller steps where the logarithmic derivative varies quickly, but
        !never smaller than 0.01
        step = max(abs(0.1_r8/dldde), 0.01_r8)
      end if
      e = min(e + step, e2)
    end do

    close(unit)

    call pop_sub()
  end subroutine states_batch_ld_test

  subroutine states_batch_output_configuration(batch, nspin, unit, verbose_limit)
    !-----------------------------------------------------------------------!
    ! Prints the label and charge of every state in the batch, one state    !
    ! per line, preceded by a header line.                                  !
    !                                                                       !
    !  batch         - batch of states                                      !
    !  nspin         - 1 or 2; selects the spin-dependent column layout     !
    !  unit          - if present, passed on to write_info as the unit      !
    !  verbose_limit - if present (and unit is absent), passed on to        !
    !                  write_info                                           !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    integer,              intent(in), optional :: unit, verbose_limit

    integer :: is, line, n_lines
    character(len=4) :: spin_tag
    type(qn_t) :: qn
    type(state_t), pointer :: state

    call push_sub("states_batch_output_configuration")

    !Header line: its columns depend on nspin and, for nspin = 2, on the
    !quantum numbers of the first state
    if (nspin == 1) then
      write(message(1),'(2X,"Configuration    : State   Occupation")')
    else if (nspin == 2) then
      qn = state_qn(batch%states(1)%ptr)
      if (qn%m /= M_ZERO) then
        write(message(1),'(2X,"Configuration    :  State   Sigma   Occupation")')
      else
        write(message(1),'(2X,"Configuration    : State   Spin   Occupation")')
      end if
    end if

    !One message line per state, stored right after the header
    do is = 1, batch%n_states
      line = is + 1
      state => batch%states(is)%ptr
      qn = state_qn(state)

      if (nspin == 1) then
        if (qn%j /= M_ZERO) then
          write(message(line),'(21X,A,5X,F5.2)') trim(state_label(state)), state_charge(state)
        else
          write(message(line),'(23X,A,6X,F5.2)') trim(state_label(state)), state_charge(state)
        end if

      else if (nspin == 2) then
        if (qn%m /= M_ZERO) then
          !Sigma column: anything other than +-1/2 is shown as '--'
          spin_tag = '--'
          if (qn%sg == M_HALF)  spin_tag = 'up'
          if (qn%sg == -M_HALF) spin_tag = 'dn'
          write(message(line),'(21X,A,5X,A,4X,F5.2)') trim(state_label(state)), &
                                                     spin_tag, state_charge(state)
        else
          !Spin column
          spin_tag = merge("-1/2", "+1/2", qn%s == -M_HALF)
          write(message(line),'(21X,A,4X,A,5X,F5.2)') trim(state_label(state)), &
                                                     spin_tag, state_charge(state)
        end if
      end if
    end do

    !Flush the header plus one line per state
    n_lines = batch%n_states + 1
    if (present(unit)) then
      call write_info(n_lines, unit=unit)
    else if (present(verbose_limit)) then
      call write_info(n_lines, verbose_limit)
    else
      call write_info(n_lines)
    end if

    call pop_sub()
  end subroutine states_batch_output_configuration

  subroutine states_batch_output_eigenvalues(batch, nspin, unit, verbose_limit)
    !-----------------------------------------------------------------------!
    ! Prints a table with the label, charge and eigenvalue of every state   !
    ! in the batch. Eigenvalues are converted to the output energy units.   !
    !                                                                       !
    !  batch         - batch of states                                      !
    !  nspin         - 1 or 2; selects the spin-dependent column layout     !
    !  unit          - if present, passed on to write_info as the unit      !
    !  verbose_limit - if present (and unit is absent), passed on to        !
    !                  write_info                                           !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    integer,              intent(in), optional :: unit, verbose_limit

    integer  :: is, n_lines
    real(R8) :: e_unit
    type(qn_t) :: qn
    character(len=4) :: spin_tag
    type(state_t), pointer :: state

    call push_sub("states_batch_output_eigenvalues")

    !Title and column headers take the first two message lines
    write(message(1),'(4X,"Eigenvalues [",A,"]")') trim(units_out%energy%abbrev)
    if (nspin == 1) then
      write(message(2),'(5X,"State   Occupation   Eigenvalue")')
    else if (nspin == 2) then
      qn = state_qn(batch%states(1)%ptr)
      if (qn%m /= M_ZERO) then
        write(message(2),'(5X," State   Sigma   Occupation   Eigenvalue")')
      else
        write(message(2),'(5X,"State   Spin   Occupation   Eigenvalue")')
      end if
    end if

    e_unit = units_out%energy%factor

    !One message line per state; n_lines counts the lines filled so far
    n_lines = 2
    do is = 1, batch%n_states
      n_lines = n_lines + 1
      state => batch%states(is)%ptr
      qn = state_qn(state)

      if (nspin == 1) then
        if (qn%j /= M_ZERO) then
          write(message(n_lines),'(5X,A,5X,F5.2,4X,F12.5)') trim(state_label(state)), &
                                state_charge(state), state_eigenvalue(state)/e_unit
        else
          write(message(n_lines),'(7X,A,6X,F5.2,4X,F12.5)') trim(state_label(state)), &
                                state_charge(state), state_eigenvalue(state)/e_unit
        end if

      else if (nspin == 2) then
        if (qn%m /= M_ZERO) then
          !Sigma column: anything other than +-1/2 is shown as '--'
          spin_tag = '--'
          if (qn%sg == M_HALF)  spin_tag = 'up'
          if (qn%sg == -M_HALF) spin_tag = 'dn'
          write(message(n_lines),'(5X,A,4X,A,5X,F5.2,4X,F12.5)')  trim(state_label(state)), spin_tag, &
                                                            state_charge(state), state_eigenvalue(state)/e_unit
        else
          !Spin column
          spin_tag = merge("-1/2", "+1/2", qn%s == -M_HALF)
          write(message(n_lines),'(7X,A,4X,A,5X,F5.2,4X,F12.5)') trim(state_label(state)), spin_tag, &
                                                           state_charge(state), state_eigenvalue(state)/e_unit
        end if
      end if
    end do

    !Flush everything that was written to the message buffer
    if (present(unit)) then
      call write_info(n_lines, unit=unit)
    else if (present(verbose_limit)) then
      call write_info(n_lines, verbose_limit)
    else
      call write_info(n_lines)
    end if

    call pop_sub()
  end subroutine states_batch_output_eigenvalues

  subroutine states_batch_output_density(batch, nspin, m, dir)
    !-----------------------------------------------------------------------!
    ! Writes the electronic density, its gradient and its laplacian, and    !
    ! the kinetic energy density to files suitable for plotting. The files  !
    ! are dir/density and dir/tau, or dir/density_dn, dir/density_up,       !
    ! dir/tau_dn and dir/tau_up when nspin is 2.                            !
    !                                                                       !
    !  batch - batch of states                                              !
    !  nspin - number of spin channels                                      !
    !  m     - mesh                                                         !
    !  dir   - directory where the files are written                        !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    integer,              intent(in) :: nspin
    type(mesh_t),         intent(in) :: m
    character(len=*),     intent(in) :: dir

    integer  :: i, p, unit
    real(R8) :: u
    !Sized from dir so that long directory names are never truncated; 11 is
    !the length of the longest suffix appended below ("/density_dn")
    character(len=len(dir)+11) :: filename
    real(R8), allocatable :: rho(:,:), rho_grad(:,:), rho_lapl(:,:), tau(:,:)

    call push_sub("states_batch_output_density")

    !Get the density and its derivatives
    allocate(rho(m%np, nspin), rho_grad(m%np, nspin))
    allocate(rho_lapl(m%np, nspin), tau(m%np, nspin))
    rho      = states_batch_density     (batch, nspin, m)
    rho_grad = states_batch_density_grad(batch, nspin, m)
    rho_lapl = states_batch_density_lapl(batch, nspin, m)
    tau      = states_batch_tau         (batch, nspin, m)

    !Length conversion factor for the output units
    u = units_out%length%factor

    !Density, gradient, and laplacian (one file per spin channel)
    do p = 1, nspin
      if (nspin == 1) then
        filename = trim(dir)//"/density"
      else
        if (p == 1) then
          filename = trim(dir)//"/density_dn"
        elseif (p == 2) then
          filename = trim(dir)//"/density_up"
        end if
      end if
      call io_open(unit, file=trim(filename))

      write(unit,'("#")')
      write(unit,'("# Radial density.")')
      write(unit,'("# Length units: ",A)') trim(units_out%length%name)
      write(unit,'("#")')

      write(unit,'("# ",71("-"))')
      write(unit,'("# |",7X,"r",7X,"|",6X,"n(r)",7X,"|",4X,"grad_n(r)",4X,"|",4X,"lapl_n(r)",4X,"|")')
      write(unit,'("# ",71("-"))')
      do i = 1, m%np
        write(unit,'(3X,ES14.8E2,3X,ES15.8E2,3X,ES15.8E2,3X,ES15.8E2)') &
             m%r(i)/u, rho(i, p)*u, rho_grad(i, p)*u**2, rho_lapl(i, p)*u**3
      end do

      close(unit)
    end do

    !Print kinetic energy density (one file per spin channel)
    do p = 1, nspin
      if (nspin == 1) then
        filename = trim(dir)//"/tau"
      else
        if (p == 1) then
          filename = trim(dir)//"/tau_dn"
        elseif (p == 2) then
          filename = trim(dir)//"/tau_up"
        end if
      end if
      call io_open(unit, file=trim(filename))

      write(unit,'("#")')
      write(unit,'("# Radial kinetic energy density.")')
      write(unit,'("# Length units: ",A)') trim(units_out%length%name)
      write(unit,'("# Energy units: ",A)') trim(units_out%energy%name)
      write(unit,'("#")')

      write(unit,'("# ",35("-"))')
      write(unit,'("# |",7X,"r",7X,"|",5X,"tau(r)",6X,"|")')
      write(unit,'("# ",35("-"))')
      do i = 1, m%np
        write(unit,'(3X,ES14.8E2,3X,ES15.8E2)') m%r(i)/u, &
                                              tau(i, p)*u/units_out%energy%factor
      end do

      close(unit)
    end do

    !Free memory
    deallocate(rho, rho_grad, rho_lapl, tau)

    call pop_sub()
  end subroutine states_batch_output_density

  subroutine states_batch_ps_io_set(batch, m, rc)
    !-----------------------------------------------------------------------!
    ! Collects the quantum numbers, charges, eigenvalues and wavefunctions  !
    ! of the states, together with the density, and hands them over to the  !
    ! ps_io module.                                                         !
    !                                                                       !
    !  batch - batch of states                                              !
    !  m     - mesh                                                         !
    !  rc    - one radius per state (presumably the cut-off radii)          !
    !-----------------------------------------------------------------------!
    type(states_batch_t), intent(in) :: batch
    type(mesh_t),         intent(in) :: m
    real(R8),             intent(in) :: rc(batch%n_states)

    integer :: is, ns
    character(len=10) :: label
    type(qn_t) :: qn
    integer, allocatable :: qn_n(:), qn_l(:)
    real(R8), allocatable :: qn_j(:), charge(:), eigenval(:), wfs(:,:), rho_val(:,:)
    type(state_t), pointer :: state

    call push_sub("states_batch_ps_io_set")

    ns = batch%n_states

    !Work arrays: one entry (or one column) per state
    allocate(qn_n(ns), qn_l(ns), qn_j(ns), charge(ns), eigenval(ns))
    allocate(wfs(m%np, ns), rho_val(m%np, 1))

    !Density computed with a single spin channel
    rho_val = states_batch_density(batch, 1, m)

    do is = 1, ns
      state => batch%states(is)%ptr

      !The first character of the label is read as the quantum number n
      label = state_label(state)
      read(label,'(I1)') qn_n(is)

      qn = state_qn(state)
      qn_l(is) = qn%l
      qn_j(is) = qn%j

      charge(is)   = state_charge(state)
      eigenval(is) = state_eigenvalue(state)

      !Only the first wavefunction component is passed on
      wfs(:, is) = state%wf(:,1)
    end do

    call ps_io_set_wfs(m%np, ns, qn_n, qn_l, qn_j, charge, eigenval, rc, wfs, &
                       sum(rho_val, dim=2))

    deallocate(qn_n, qn_l, qn_j, charge, eigenval, wfs, rho_val)

    call pop_sub()
  end subroutine states_batch_ps_io_set

end module states_batch_m
