!! Copyright (C) 2004 Xavier Andrade, M. Marques
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

! ---------------------------------------------------------
!> Solves the linear equation (H + shift) x = y for the state (ist, ik).
!! Typically shift = - eigenvalue + omega.
!!
!! The call arguments are first stored in the shared args structure; this is
!! what the argument-less callbacks X(linear_solver_operator_na),
!! X(linear_solver_operator_t_na), X(linear_solver_operator_sym_na) and
!! X(linear_solver_preconditioner) read when they are invoked by a solver.
!! The solver selected by this%solver is then called.
!!
!! The QMR variants only receive the first column, x(:, 1) and y(:, 1).
!! occ_response defaults to .true. and is only forwarded to BiCGSTAB.
!! An unknown value of this%solver is a fatal error.
! ---------------------------------------------------------
subroutine X(linear_solver_solve_HXeY) (this, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used, occ_response)
  type(linear_solver_t),    target, intent(inout) :: this
  type(namespace_t),        target, intent(in)    :: namespace
  type(hamiltonian_elec_t), target, intent(in)    :: hm
  type(mesh_t),             target, intent(in)    :: mesh
  type(states_elec_t),      target, intent(in)    :: st
  integer,                          intent(in)    :: ist
  integer,                          intent(in)    :: ik
  R_TYPE,                           intent(inout) :: x(:,:)   !< x(mesh%np_part, d%dim)
  R_TYPE,                           intent(in)    :: y(:,:)   !< y(mesh%np, d%dim)
  R_TYPE,                           intent(in)    :: shift
  FLOAT,                            intent(in)    :: tol
  FLOAT,                            intent(out)   :: residue
  integer,                          intent(out)   :: iter_used
  logical, optional,                intent(in)    :: occ_response

  logical :: occ_response_
  R_TYPE, allocatable :: z(:, :)

  PUSH_SUB(X(linear_solver_solve_HXeY))
  call profiling_in(prof, TOSTRING(X(LINEAR_SOLVER)))

  occ_response_ = .true.
  if (present(occ_response)) occ_response_ = occ_response

  ! Publish the arguments for the callbacks that take only (x, hx)
  args%ls        => this
  args%namespace => namespace
  args%hm        => hm
  args%mesh      => mesh
  args%st        => st
  args%ist       = ist
  args%ik        = ik
  args%X(shift)  = shift
  ! iter_used is preset to the iteration limit before dispatching.
  ! NOTE(review): presumably the QMR routines read it as the maximum number of
  ! iterations on entry -- confirm against their interfaces.
  iter_used = this%max_iter

  select case (this%solver)

  case (OPTION__LINEARSOLVER__CG)
    call X(linear_solver_cg) (this, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used)

  case (OPTION__LINEARSOLVER__IDRS)
    call X(linear_solver_idrs) (this, namespace, mesh, st, x, y, tol, residue, iter_used)

  case (OPTION__LINEARSOLVER__BICGSTAB)
    call X(linear_solver_bicgstab) (this, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used, occ_response_)

  case (OPTION__LINEARSOLVER__MULTIGRID)
    call X(linear_solver_multigrid)(this, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used)

  case (OPTION__LINEARSOLVER__QMR_SYMMETRIC)
    ! complex symmetric: for Sternheimer, only if wfns are real
    call X(qmr_sym_gen_dotu)(mesh%np, x(:, 1), y(:, 1), &
      X(linear_solver_operator_na), X(mf_dotu_aux), X(mf_nrm2_aux), X(linear_solver_preconditioner), &
      iter_used, residue = residue, threshold = tol, showprogress = .false.)

  case (OPTION__LINEARSOLVER__QMR_SYMMETRIZED)
    ! symmetrized equation: the right-hand side becomes z = A^T y and the
    ! operator handed to the solver is X(linear_solver_operator_sym_na)
    SAFE_ALLOCATE(z(1:mesh%np, 1))
    call X(linear_solver_operator_t_na)(y(:, 1), z(:, 1))
    call X(qmr_sym_gen_dotu)(mesh%np, x(:, 1), z(:, 1), &
      X(linear_solver_operator_sym_na), X(mf_dotu_aux), X(mf_nrm2_aux), X(linear_solver_preconditioner), &
      iter_used, residue = residue, threshold = tol, showprogress = .false.)
    ! release the work array through the same bookkeeping macro family that allocated it
    SAFE_DEALLOCATE_A(z)

  case (OPTION__LINEARSOLVER__QMR_DOTP)
    ! using conjugated dot product
    call X(qmr_sym_gen_dotu)(mesh%np, x(:, 1), y(:, 1), &
      X(linear_solver_operator_na), X(mf_dotp_aux), X(mf_nrm2_aux), X(linear_solver_preconditioner), &
      iter_used, residue = residue, threshold = tol, showprogress = .false.)

  case (OPTION__LINEARSOLVER__QMR_GENERAL)
    ! general algorithm: needs both the operator and its transpose
    call X(qmr_gen_dotu)(mesh%np, x(:, 1), y(:, 1), X(linear_solver_operator_na), X(linear_solver_operator_t_na), &
      X(mf_dotu_aux), X(mf_nrm2_aux), X(linear_solver_preconditioner), X(linear_solver_preconditioner), &
      iter_used, residue = residue, threshold = tol, showprogress = .false.)

  case (OPTION__LINEARSOLVER__SOS)
    ! this solver takes no tolerance
    call X(linear_solver_sos)(hm, namespace, mesh, st, ist, ik, x, y, shift, residue, iter_used)

  case default
    write(message(1), '(a,i2)') "Unknown linear-response solver", this%solver
    call messages_fatal(1, namespace=namespace)

  end select

  call profiling_out(prof)
  POP_SUB(X(linear_solver_solve_HXeY))

end subroutine X(linear_solver_solve_HXeY)

! ---------------------------------------------------------

!> Batch driver: solves (H + shift(ii)) x_ii = y_ii for every state of a batch.
!!
!! Only the QMR_DOTP solver is handled at batch level (and is the only one
!! that receives use_initial_guess). For any other solver the states of the
!! batch are solved one at a time through X(linear_solver_solve_HXeY), which
!! is the only path that receives occ_response.
subroutine X(linear_solver_solve_HXeY_batch) (this, namespace, hm, mesh, st, xb, yb, shift, tol, &
  residue, iter_used, occ_response, use_initial_guess)
  type(linear_solver_t),    target, intent(inout) :: this
  type(namespace_t),                intent(in)    :: namespace
  type(hamiltonian_elec_t), target, intent(in)    :: hm
  type(mesh_t),             target, intent(in)    :: mesh
  type(states_elec_t),      target, intent(in)    :: st
  type(wfs_elec_t),                 intent(inout) :: xb
  type(wfs_elec_t),                 intent(inout) :: yb
  R_TYPE,                           intent(in)    :: shift(:)
  FLOAT,                            intent(in)    :: tol
  FLOAT,                            intent(out)   :: residue(:)
  integer,                          intent(out)   :: iter_used(:)
  logical, optional,                intent(in)    :: occ_response
  logical, optional,                intent(in)    :: use_initial_guess

  integer :: ib

  PUSH_SUB(X(linear_solver_solve_HXeY_batch))

  if (this%solver == OPTION__LINEARSOLVER__QMR_DOTP) then

    ! true batch solver; only this path is timed with prof_batch
    call profiling_in(prof_batch, TOSTRING(X(LINEAR_SOLVER_BATCH)))

    ! pack both batches when the Hamiltonian is applied in packed form ...
    if (hamiltonian_elec_apply_packed(hm)) then
      call xb%do_pack()
      call yb%do_pack()
    end if

    call X(linear_solver_qmr_dotp)(this, namespace, hm, mesh, st, xb, yb, shift, iter_used, &
      residue, tol, use_initial_guess)

    ! ... and unpack them again, in reverse order, afterwards
    if (hamiltonian_elec_apply_packed(hm)) then
      call yb%do_unpack()
      call xb%do_unpack()
    end if

    call profiling_out(prof_batch)

  else

    ! fall back to the single-state solver, one batch member at a time
    do ib = 1, xb%nst
      call X(linear_solver_solve_HXeY) (this, namespace, hm, mesh, st, xb%ist(ib), xb%ik, xb%X(ff)(:,:,ib), &
        yb%X(ff)(:,:,ib), shift(ib), tol, residue(ib), iter_used(ib), occ_response)
    end do

  end if

  POP_SUB(X(linear_solver_solve_HXeY_batch))

end subroutine X(linear_solver_solve_HXeY_batch)

! ---------------------------------------------------------
!> Preconditioned conjugate gradients for (H + shift) x = y.
!!
!! x holds the initial guess on entry and the last iterate on exit.
!! The loop stops as soon as the residual norm has been below tol on two
!! consecutive iterations, or after ls%max_iter iterations.
!! residue is the norm of the last residual computed (the initial residual if
!! no iteration is performed).
!! iter_used is the final value of the loop counter, so it is
!! ls%max_iter + 1 when the loop runs to completion without hitting the exit.
!! A warning is issued if the residual of the last iteration is not below tol.
subroutine X(linear_solver_cg) (ls, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used)
  type(linear_solver_t),    intent(inout) :: ls
  type(namespace_t),        intent(in)    :: namespace
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: ist
  integer,                  intent(in)    :: ik
  R_TYPE,                   intent(inout) :: x(:,:)   !< x(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: y(:,:)   !< y(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: shift
  FLOAT,                    intent(in)    :: tol
  FLOAT,                    intent(out)   :: residue
  integer,                  intent(out)   :: iter_used

  R_TYPE, allocatable :: r(:,:), p(:,:), z(:,:), Hp(:,:)
  FLOAT  :: alpha, beta, gamma
  integer :: iter
  logical :: conv_last, conv

  PUSH_SUB(X(linear_solver_cg))

  SAFE_ALLOCATE( r(1:mesh%np_part, 1:st%d%dim))
  SAFE_ALLOCATE( p(1:mesh%np_part, 1:st%d%dim))
  SAFE_ALLOCATE( z(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(Hp(1:mesh%np, 1:st%d%dim))

  ! Initial residue
  call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, x, Hp)

  r(1:mesh%np, 1:st%d%dim) = y(1:mesh%np, 1:st%d%dim) - Hp(1:mesh%np, 1:st%d%dim)

  ! Give residue and conv defined values before the loop: both are read after
  ! it, and the loop body does not execute at all when ls%max_iter < 1.
  residue = X(mf_nrm2)(mesh, st%d%dim, r)
  conv = .false.

  call X(preconditioner_apply)(ls%pre, namespace, mesh, hm, r, z, 1, shift)

  ! Initial search direction
  p(1:mesh%np, 1:st%d%dim) = z(1:mesh%np, 1:st%d%dim)

  conv_last = .false.
  do iter = 1, ls%max_iter
    ! gamma = Re <r|z>, with z the preconditioned residual
    gamma = R_REAL(X(mf_dotp)(mesh, st%d%dim, r, z))

    call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, p, Hp)

    ! step length along the search direction
    alpha = gamma/R_REAL(X(mf_dotp) (mesh, st%d%dim, p, Hp))

    !r = r - alpha*Hp
    call lalg_axpy(mesh%np, st%d%dim, -alpha, Hp, r)
    !x = x + alpha*p
    call lalg_axpy(mesh%np, st%d%dim,  alpha,  p, x)

    residue = X(mf_nrm2)(mesh, st%d%dim, r)
    conv = (residue < tol)
    ! require convergence on two consecutive iterations before stopping
    if (conv .and. conv_last) exit
    conv_last = conv

    call X(preconditioner_apply)(ls%pre, namespace, mesh, hm, r, z, 1, shift)

    ! gamma still holds <r|z> of the previous residual
    beta = R_REAL(X(mf_dotp)(mesh, st%d%dim, r, z)) / gamma

    p(1:mesh%np, 1:st%d%dim) = z(1:mesh%np, 1:st%d%dim) + beta*p(1:mesh%np, 1:st%d%dim)

  end do

  iter_used = iter

  if (.not. conv) then
    write(message(1), '(a)') "CG solver not converged!"
    call messages_warning(1, namespace=namespace)
  else
    if (debug%info) then
      write(message(1), '(a,i4,a)') 'Debug: CG solver converged in ', iter, ' iterations.'
      call messages_info(1, namespace=namespace)
    end if
  end if

  SAFE_DEALLOCATE_A(r)
  SAFE_DEALLOCATE_A(p)
  SAFE_DEALLOCATE_A(z)
  SAFE_DEALLOCATE_A(Hp)

  POP_SUB(X(linear_solver_cg))
end subroutine X(linear_solver_cg)


! ---------------------------------------------------------
!> IDRS
!> This is the "Induced Dimension Reduction", IDR(s) (for s=4). IDR(s) is a robust and efficient short recurrence
!> Krylov subspace method for solving large nonsymmetric systems of linear equations. It is described in
!> [Peter Sonneveld and Martin B. van Gijzen, SIAM J. Sci. Comput. 31, 1035 (2008)]. We have adapted the code
!> released by M. B. van Gijzen [http://ta.twi.tudelft.nl/nw/users/gijzen/IDR.html].
!>
!> x and y are flattened into single-column arrays of length mesh%np*st%d%dim,
!> with x used as the starting vector x0; the result is reshaped back into x.
!> A warning is issued if the returned residue exceeds tol or the returned
!> flag is non-zero.
!> NOTE(review): neither the Hamiltonian nor the shift is passed here, so the
!> X(preconditioner) and X(matrixvector) callbacks presumably take them from
!> the shared args structure -- confirm in their definitions.
subroutine X(linear_solver_idrs) (ls, namespace, mesh, st, x, y, tol, residue, iter_used)
  type(linear_solver_t), intent(inout) :: ls
  type(namespace_t),     intent(in)    :: namespace
  type(mesh_t),          intent(in)    :: mesh
  type(states_elec_t),   intent(in)    :: st
  R_TYPE,                intent(inout) :: x(:,:)   !< x(mesh%np, st%d%dim)
  R_TYPE,                intent(in)    :: y(:,:)   !< y(mesh%np, st%d%dim)
  FLOAT,                 intent(in)    :: tol
  FLOAT,                 intent(out)   :: residue
  integer,               intent(out)   :: iter_used

  integer :: s, info
  R_TYPE, allocatable :: rhs(:, :), phi(:, :), x0(:, :)

  ! use the templated name so that the real and complex variants are traced
  ! under their own names, like every other routine in this file
  PUSH_SUB(X(linear_solver_idrs))

  SAFE_ALLOCATE(rhs(1:mesh%np*st%d%dim, 1))
  SAFE_ALLOCATE(phi(1:mesh%np*st%d%dim, 1))
  SAFE_ALLOCATE(x0(1:mesh%np*st%d%dim, 1))

  ! flatten the (point, component) arrays into one column each
  rhs(:, 1) = X(singledimarray)(mesh%np*st%d%dim, y)
  x0(:, 1)  = X(singledimarray)(mesh%np*st%d%dim, x)

  s = 4
  info = -1
  phi = X(idrs)(rhs, s, X(preconditioner), X(matrixvector), ddotproduct, zdotproduct, &
    tol, ls%max_iter, x0 = x0, iterations = iter_used, flag = info, relres = residue)

  ! back to the two-index layout
  x = X(doubledimarray)(mesh%np, st%d%dim, phi(:, 1))

  if (residue > tol .or. info .ne. 0) then
    write(message(1), '(a)')     "IDRS solver failed."
    write(message(2), '(a,i3)')  "Flag =,", info
    call messages_warning(2, namespace=namespace)
  end if

  SAFE_DEALLOCATE_A(x0)
  SAFE_DEALLOCATE_A(phi)
  SAFE_DEALLOCATE_A(rhs)
  POP_SUB(X(linear_solver_idrs))
end subroutine X(linear_solver_idrs)


! ---------------------------------------------------------
!> BICONJUGATE GRADIENTS STABILIZED
!! see http://math.nist.gov/iml++/bicgstab.h.txt
!!
!! Solves (H + shift) x = y with preconditioning; x holds the initial guess on
!! entry and the last iterate on exit.
!! Before iterating, the initial residual is orthogonalized: against the
!! state (ist, ik) itself when occ_response is .true., otherwise through
!! X(lr_orth_vector).
!! The loop ends when the intermediate residual s is below tol, when the
!! residual r has been below tol on two consecutive iterations, when rho_1 or
!! w is exactly zero, or after ls%max_iter iterations.
!! residue is the last residual norm computed; iter_used is the final value
!! of the loop counter (ls%max_iter + 1 if the loop runs to completion).
subroutine X(linear_solver_bicgstab) (ls, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used, occ_response)
  type(linear_solver_t),    intent(inout) :: ls
  type(namespace_t),        intent(in)    :: namespace
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: ist
  integer,                  intent(in)    :: ik
  R_TYPE,                   intent(inout) :: x(:,:)   !< x(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: y(:,:)   !< y(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: shift
  FLOAT,                    intent(in)    :: tol
  FLOAT,                    intent(out)   :: residue
  integer,                  intent(out)   :: iter_used
  logical,                  intent(in)    :: occ_response

  R_TYPE, allocatable :: r(:,:), Hp(:,:), rs(:,:), Hs(:,:), p(:,:), s(:,:), psi(:, :), phat(:,:), shat(:,:)
  R_TYPE  :: alpha, beta, w, rho_1, rho_2
  logical :: conv_last, conv
  integer :: iter, idim, ip
  FLOAT :: gamma

  PUSH_SUB(X(linear_solver_bicgstab))

  SAFE_ALLOCATE( r(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE( p(1:mesh%np_part, 1:st%d%dim))
  SAFE_ALLOCATE(rs(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE( s(1:mesh%np_part, 1:st%d%dim))
  SAFE_ALLOCATE(Hp(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(Hs(1:mesh%np, 1:st%d%dim))

  ! this will store the preconditioned functions
  SAFE_ALLOCATE(phat(1:mesh%np_part, 1:st%d%dim))
  SAFE_ALLOCATE(shat(1:mesh%np_part, 1:st%d%dim))

  ! Initial residue
  call X(linear_solver_operator) (hm, namespace, mesh, st, ist, ik, shift, x, Hp)

  do idim = 1, st%d%dim
    do ip = 1, mesh%np
      r(ip, idim) = y(ip, idim) - Hp(ip, idim)
    end do
  end do

  !re-orthogonalize r, this helps considerably with convergence
  if (occ_response) then
    SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))

    call states_elec_get_state(st, mesh, ist, ik, psi)

    ! remove the component of r along the state itself
    alpha = X(mf_dotp)(mesh, st%d%dim, psi, r)
    call lalg_axpy(mesh%np, st%d%dim,-alpha, psi, r)

    SAFE_DEALLOCATE_A(psi)
  else
    ! project RHS onto the unoccupied states
    call X(lr_orth_vector)(mesh, st, r, ist, ik, shift + st%eigenval(ist, ik))
  end if

  ! rs keeps a copy of the initial residual for the whole iteration
  call lalg_copy(mesh%np, st%d%dim, r, rs)

  gamma = X(mf_nrm2)(mesh, st%d%dim, r)

  ! conv is tested after the loop; define it here because the loop may not
  ! execute at all (ls%max_iter < 1) or may exit on rho_1 == 0 before the
  ! first convergence test
  conv = .false.
  conv_last = .false.
  do iter = 1, ls%max_iter

    rho_1 = X(mf_dotp) (mesh, st%d%dim, rs, r)

    ! Here we want an exact comparison to zero,
    ! as rho_1 enters only as the ratio of quantities
    ! This fails only if rho_1 is exactly zero.
    if (rho_1 == M_ZERO) exit

    if (iter == 1) then
      call lalg_copy(mesh%np, st%d%dim, r, p)
    else
      ! rho_2, alpha and w were set in the previous iteration
      beta = rho_1/rho_2*alpha/w
      do idim = 1, st%d%dim
        do ip = 1, mesh%np
          p(ip, idim) = r(ip, idim) + beta*(p(ip, idim) - w*Hp(ip, idim))
        end do
      end do
    end if

    ! preconditioning
    call X(preconditioner_apply)(ls%pre, namespace, mesh, hm, p, phat, 1, shift)
    call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, phat, Hp)

    alpha = rho_1/X(mf_dotp)(mesh, st%d%dim, rs, Hp)

    ! intermediate residual
    do idim = 1, st%d%dim
      do ip = 1, mesh%np
        s(ip, idim) = r(ip, idim) - alpha*Hp(ip, idim)
      end do
    end do

    gamma = X(mf_nrm2) (mesh, st%d%dim, s)

    ! early exit: only the alpha half-step is applied to x
    conv = (gamma < tol)
    if (conv) then
      call lalg_axpy(mesh%np, st%d%dim, alpha, phat, x)
      exit
    end if

    call X(preconditioner_apply)(ls%pre, namespace, mesh, hm, s, shat, 1, shift)
    call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, shat, Hs)

    w = X(mf_dotp)(mesh, st%d%dim, Hs, s)/X(mf_dotp) (mesh, st%d%dim, Hs, Hs)

    do idim = 1, st%d%dim
      do ip = 1, mesh%np
        x(ip, idim) = x(ip, idim) + alpha*phat(ip, idim) + w*shat(ip, idim)
        r(ip, idim) = s(ip, idim) - w*Hs(ip, idim)
      end do
    end do

    rho_2 = rho_1

    gamma = X(mf_nrm2)(mesh, st%d%dim, r)
    conv = (gamma < tol)
    ! require convergence on two consecutive iterations before stopping
    if (conv .and. conv_last) then
      exit
    end if
    conv_last = conv

    !This fails only if w is exactly zero
    if (w == M_ZERO) exit

  end do

  iter_used = iter
  residue = gamma

  if (.not. conv) then
    write(message(1), '(a)') "BiCGSTAB solver not converged!"
    call messages_warning(1, namespace=namespace)
  end if

  SAFE_DEALLOCATE_A(r)
  SAFE_DEALLOCATE_A(p)
  SAFE_DEALLOCATE_A(Hp)
  SAFE_DEALLOCATE_A(s)
  SAFE_DEALLOCATE_A(rs)
  SAFE_DEALLOCATE_A(Hs)
  SAFE_DEALLOCATE_A(phat)
  SAFE_DEALLOCATE_A(shat)

  POP_SUB(X(linear_solver_bicgstab))
end subroutine X(linear_solver_bicgstab)


! ---------------------------------------------------------
!> Iterative relaxation solver for (H + shift) x = y.
!!
!! Each outer iteration performs two calls to smoothing(3) and then evaluates
!! the residual norm |(H + shift) x - y|; the loop stops when it falls below
!! tol or after ls%max_iter iterations. A warning is issued if the final
!! residue is above tol.
!! iter_used is the final value of the loop counter (ls%max_iter + 1 if the
!! loop runs to completion).
!! NOTE(review): residue is only assigned inside the loop, so it is undefined
!! if ls%max_iter < 1 -- callers are presumably expected to pass a positive
!! limit.
subroutine X(linear_solver_multigrid) (ls, namespace, hm, mesh, st, ist, ik, x, y, shift, tol, residue, iter_used)
  type(linear_solver_t),    intent(inout) :: ls
  type(namespace_t),        intent(in)    :: namespace
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: ist
  integer,                  intent(in)    :: ik
  R_TYPE,                   intent(inout) :: x(:,:)   ! x(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: y(:,:)   ! y(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: shift
  FLOAT,                    intent(in)    :: tol
  FLOAT,                    intent(out)   :: residue
  integer,                  intent(out)   :: iter_used

  R_TYPE, allocatable :: diag(:,:), hx(:,:), res(:,:), psi(:, :)
  integer :: iter

  PUSH_SUB(X(linear_solver_multigrid))

  SAFE_ALLOCATE(diag(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(  hx(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE( res(1:mesh%np, 1:st%d%dim))

  ! diagonal of the shifted operator, used to scale the relaxation steps
  call X(hamiltonian_elec_diagonal)(hm, mesh, diag, ik)
  diag(1:mesh%np, 1:st%d%dim) = diag(1:mesh%np, 1:st%d%dim) + shift

  do iter = 1, ls%max_iter

    call smoothing(3)

    call smoothing(3)

    !calculate the residue
    call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, x, hx)
    res(1:mesh%np, 1:st%d%dim) = hx(1:mesh%np, 1:st%d%dim) - y(1:mesh%np, 1:st%d%dim)
    residue = X(mf_nrm2)(mesh, st%d%dim, res)

    if (residue < tol) exit

    if (debug%info) then

      SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))

      ! report the iteration, the residue and |<psi|x>| for the state (ist, ik)
      call states_elec_get_state(st, mesh, ist, ik, psi)
      write(message(1), *)  "Multigrid: iter ", iter,  residue, abs(X(mf_dotp)(mesh, st%d%dim, psi, x))
      call messages_info(1, namespace=namespace)

      SAFE_DEALLOCATE_A(psi)

    end if

  end do

  iter_used = iter

  if (residue > tol) then
    write(message(1), '(a)') "Multigrid solver not converged!"
    call messages_warning(1, namespace=namespace)
  end if

  ! release the work arrays through the same bookkeeping macro family that
  ! allocated them, as the other solvers in this file do
  SAFE_DEALLOCATE_A(diag)
  SAFE_DEALLOCATE_A(hx)
  SAFE_DEALLOCATE_A(res)

  POP_SUB(X(linear_solver_multigrid))

contains

  !> Performs the given number of relaxation sweeps
  !! x <- x - M_TWOTHIRD * ((H + shift) x - y) / diag
  !! on the host's x, and then applies X(lr_orth_vector) to x once.
  subroutine smoothing(steps)
    integer, intent(in) :: steps

    integer :: ii, ip, idim
    R_TYPE  :: rr

    PUSH_SUB(X(linear_solver_multigrid).smoothing)

    do ii = 1, steps

      call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, x, hx)

      do idim = 1, st%d%dim
        do ip = 1, mesh%np
          ! pointwise residual, scaled by the operator diagonal
          rr = hx(ip, idim) - y(ip, idim)
          x(ip, idim) = x(ip, idim) - M_TWOTHIRD * rr / diag(ip, idim)
        end do
      end do

    end do

    call X(lr_orth_vector)(mesh, st, x, ist, ik, shift + st%eigenval(ist, ik))

    POP_SUB(X(linear_solver_multigrid).smoothing)
  end subroutine smoothing

end subroutine X(linear_solver_multigrid)


! ---------------------------------------------------------
!> This routine applies the operator hx = [H (+ Q) + shift] x
!!
!! H x is computed with X(hamiltonian_elec_apply_single) and shift*x is added.
!! For semiconductor smearing or integral occupations the routine returns at
!! that point. Otherwise the Q term is added: for every state jst whose
!! lr_alpha_j is larger in magnitude than M_EPSILON,
!! hx <- hx + alpha_j <psi_jst|x> psi_jst.
!! The Q term asserts that the states are not distributed (parallel_in_states).
subroutine X(linear_solver_operator) (hm, namespace, mesh, st, ist, ik, shift, x, hx)
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(namespace_t),        intent(in)    :: namespace
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: ist
  integer,                  intent(in)    :: ik
  R_TYPE,                   intent(inout) :: x(:,:)   !<  x(mesh%np_part, st%d%dim)
  R_TYPE,                   intent(out)   :: Hx(:,:)  !< Hx(mesh%np, st%d%dim)
  R_TYPE,                   intent(in)    :: shift

  integer :: jst
  FLOAT   :: alpha_j
  R_TYPE  :: proj
  R_TYPE, allocatable :: psi(:, :)

  PUSH_SUB(X(linear_solver_operator))

  call X(hamiltonian_elec_apply_single)(hm, namespace, mesh, x, Hx, ist, ik)

  !Hx = Hx + shift*x
  call lalg_axpy(mesh%np, st%d%dim, shift, x, Hx)

  ! no Q term in these cases
  if (st%smear%method == SMEAR_SEMICONDUCTOR .or. st%smear%integral_occs) then
    POP_SUB(X(linear_solver_operator))
    return
  end if

  ! This is the Q term in Eq. (11) of PRB 51, 6773 (1995)
  ASSERT(.not. st%parallel_in_states)

  ! one work array serves all states: allocate it once instead of once per
  ! contributing state
  SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))

  do jst = 1, st%nst
    alpha_j = lr_alpha_j(st, jst, ik)
    ! states with negligible weight do not contribute
    if (abs(alpha_j) <= M_EPSILON) cycle

    call states_elec_get_state(st, mesh, jst, ik, psi)

    proj = X(mf_dotp)(mesh, st%d%dim, psi, x)
    call lalg_axpy(mesh%np, st%d%dim, alpha_j*proj, psi, Hx)

  end do

  SAFE_DEALLOCATE_A(psi)

  POP_SUB(X(linear_solver_operator))

end subroutine X(linear_solver_operator)

! ---------------------------------------------------------
!> Batch version of the operator: hxb = [H (+ Q) + shift] xb, with one shift
!! per state of the batch (shift(ii) belongs to batch member ii).
!!
!! For semiconductor smearing or integral occupations the Hamiltonian is
!! applied to the whole batch and the shifts are added with batch_axpy;
!! otherwise every batch member goes through X(linear_solver_operator).
subroutine X(linear_solver_operator_batch) (hm, namespace, mesh, st, shift, xb, hxb)
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(namespace_t),        intent(in)    :: namespace
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  R_TYPE,                   intent(in)    :: shift(:)
  type(wfs_elec_t),         intent(inout) :: xb
  type(wfs_elec_t),         intent(inout) :: hxb

  integer :: ib
  logical :: state_by_state
  R_TYPE, allocatable :: shift_by_state(:)

  PUSH_SUB(X(linear_solver_operator_batch))

  state_by_state = .not. (st%smear%method == SMEAR_SEMICONDUCTOR .or. st%smear%integral_occs)

  if (state_by_state) then

    ! apply the single-state operator to each member of the batch in turn
    do ib = 1, xb%nst
      call X(linear_solver_operator)(hm, namespace, mesh, st, xb%ist(ib), xb%ik, shift(ib), xb%X(ff)(:,:,ib), &
        hxb%X(ff)(:,:,ib))
    end do

  else

    call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, xb, hxb)

    ! batch_axpy is given the shifts indexed by state number rather than by
    ! position in the batch; only the entries of the batch states are set
    SAFE_ALLOCATE(shift_by_state(st%st_start:st%st_end))

    do ib = 1, xb%nst
      shift_by_state(xb%ist(ib)) = shift(ib)
    end do

    call batch_axpy(mesh%np, shift_by_state, xb, hxb)

    SAFE_DEALLOCATE_A(shift_by_state)

  end if

  POP_SUB(X(linear_solver_operator_batch))

end subroutine X(linear_solver_operator_batch)

! ---------------------------------------------------------
!> Callback form of X(linear_solver_operator) for solvers that only pass
!! (x, hx): every other argument is taken from the shared args structure.
!! The first args%mesh%np entries of x are copied into a one-column work
!! array, the operator is applied, and the first args%mesh%np entries of the
!! result are copied into hx.
subroutine X(linear_solver_operator_na) (x, hx)
  R_TYPE,                intent(in)    :: x(:)   !< input vector
  R_TYPE,                intent(out)   :: Hx(:)  !< operator applied to x

  R_TYPE, allocatable :: vec_in(:, :), vec_out(:, :)

  ! the operator needs room for args%mesh%np_part points in its input
  SAFE_ALLOCATE(vec_in(1:args%mesh%np_part, 1))
  SAFE_ALLOCATE(vec_out(1:args%mesh%np, 1))

  call lalg_copy(args%mesh%np, x, vec_in(:, 1))

  call X(linear_solver_operator)(args%hm, args%namespace, args%mesh, args%st, args%ist, args%ik, args%X(shift), vec_in, vec_out)

  call lalg_copy(args%mesh%np, vec_out(:, 1), hx)

  SAFE_DEALLOCATE_A(vec_out)
  SAFE_DEALLOCATE_A(vec_in)

end subroutine X(linear_solver_operator_na)


! ---------------------------------------------------------
!> Callback that applies the transpose of X(linear_solver_operator), with the
!! other arguments taken from the shared args structure.
!! The transpose is obtained by conjugating the input, applying the operator
!! with the conjugated shift, and conjugating the result:
!! \f$ (H + shift)^T = H^* + shift = (H + shift^*)^* \f$
subroutine X(linear_solver_operator_t_na) (x, hx)
  R_TYPE,                intent(in)    :: x(:)   !< input vector
  R_TYPE,                intent(out)   :: Hx(:)  !< transposed operator applied to x

  R_TYPE, allocatable :: vec_in(:, :), vec_out(:, :)

  ! the operator needs room for args%mesh%np_part points in its input
  SAFE_ALLOCATE(vec_in(1:args%mesh%np_part, 1))
  SAFE_ALLOCATE(vec_out(1:args%mesh%np, 1))

  ! conjugate on the way in ...
  call lalg_copy(args%mesh%np, R_CONJ(x), vec_in(:, 1))

  call X(linear_solver_operator)(args%hm, args%namespace, args%mesh, args%st, args%ist, args%ik, R_CONJ(args%X(shift)), vec_in, vec_out)

  ! ... and again on the way out
  call lalg_copy(args%mesh%np, R_CONJ(vec_out(:, 1)), hx)

  SAFE_DEALLOCATE_A(vec_out)
  SAFE_DEALLOCATE_A(vec_in)

end subroutine X(linear_solver_operator_t_na)


! ---------------------------------------------------------
!> Callback that applies the operator in symmetrized form, \f$ A^T A \f$:
!! first X(linear_solver_operator) with the arguments stored in args, then
!! X(linear_solver_operator_t_na) on the result.
subroutine X(linear_solver_operator_sym_na) (x, hx)
  R_TYPE,                intent(in)    :: x(:)   !< input vector
  R_TYPE,                intent(out)   :: Hx(:)  !< A^T A applied to x

  R_TYPE, allocatable :: vec_in(:, :), a_vec(:, :), ata_vec(:, :)

  SAFE_ALLOCATE(vec_in(1:args%mesh%np_part, 1))
  SAFE_ALLOCATE(a_vec(1:args%mesh%np_part, 1))
  SAFE_ALLOCATE(ata_vec(1:args%mesh%np_part, 1))

  call lalg_copy(args%mesh%np, x, vec_in(:, 1))

  ! a_vec = A x
  call X(linear_solver_operator)(args%hm, args%namespace, args%mesh, args%st, args%ist, args%ik, args%X(shift), vec_in, a_vec)

  ! ata_vec = A^T (A x)
  call X(linear_solver_operator_t_na)(a_vec(:, 1), ata_vec(:, 1))

  call lalg_copy(args%mesh%np, ata_vec(:, 1), hx)

  SAFE_DEALLOCATE_A(ata_vec)
  SAFE_DEALLOCATE_A(a_vec)
  SAFE_DEALLOCATE_A(vec_in)

end subroutine X(linear_solver_operator_sym_na)

! ---------------------------------------------------------
!> Callback form of the preconditioner for solvers that only pass (x, hx):
!! applies args%ls%pre through X(preconditioner_apply), with the mesh, the
!! Hamiltonian and the shift taken from the shared args structure.
subroutine X(linear_solver_preconditioner) (x, hx)
  R_TYPE,                intent(in)    :: x(:)   !< input vector
  R_TYPE,                intent(out)   :: hx(:)  !< preconditioned vector

  R_TYPE, allocatable :: vec_in(:, :), vec_out(:, :)

  PUSH_SUB(X(linear_solver_preconditioner))

  SAFE_ALLOCATE(vec_in(1:args%mesh%np_part, 1))
  SAFE_ALLOCATE(vec_out(1:args%mesh%np_part, 1))

  ! only the first args%mesh%np entries are copied in and out
  call lalg_copy(args%mesh%np, x, vec_in(:, 1))

  call X(preconditioner_apply)(args%ls%pre, args%namespace, args%mesh, args%hm, vec_in, vec_out, 1, args%X(shift))

  call lalg_copy(args%mesh%np, vec_out(:, 1), hx)

  SAFE_DEALLOCATE_A(vec_out)
  SAFE_DEALLOCATE_A(vec_in)

  POP_SUB(X(linear_solver_preconditioner))

end subroutine X(linear_solver_preconditioner)

! ---------------------------------------------------------
!> Direct (non-iterative) solution built as a sum over the states in st:
!! x = sum over jst /= ist of
!!     psi_jst <psi_jst|y> / (eigenval(jst, ik) + lr_alpha_j(st, jst, ik) + shift)
!!
!! Any initial content of x is discarded. residue is the norm of
!! [H (+ Q) + shift] x - y for the resulting x, and iter_used is always 1.
subroutine X(linear_solver_sos) (hm, namespace, mesh, st, ist, ik, x, y, shift, residue, iter_used)
  type(hamiltonian_elec_t),       intent(in)    :: hm
  type(namespace_t),              intent(in)    :: namespace
  type(mesh_t),                   intent(in)    :: mesh
  type(states_elec_t),            intent(in)    :: st
  integer,                        intent(in)    :: ist
  integer,                        intent(in)    :: ik
  R_TYPE,                         intent(inout) :: x(:,:)   !< x(mesh%np, st%d%dim)
  R_TYPE,                         intent(in)    :: y(:,:)   !< y(mesh%np, st%d%dim)
  R_TYPE,                         intent(in)    :: shift
  FLOAT,                          intent(out)   :: residue
  integer,                        intent(out)   :: iter_used

  integer :: jst
  R_TYPE  :: coeff, denom
  R_TYPE, allocatable :: resid(:, :), psi(:, :)

  PUSH_SUB(X(linear_solver_sos))

  x(1:mesh%np, 1:st%d%dim) = M_ZERO

  SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))

  do jst = 1, st%nst
    ! the state being perturbed is left out of the sum
    if (jst /= ist) then

      call states_elec_get_state(st, mesh, jst, ik, psi)

      ! Normally the expression in perturbation theory would have here
      ! denominator = st%eigenval(jst, ik) - st%eigenval(ist, ik)
      ! For solving this type of problem, -st%eigenval(ist, ik) is included in shift
      denom = st%eigenval(jst, ik) + lr_alpha_j(st, jst, ik) + shift
      coeff = X(mf_dotp)(mesh, st%d%dim, psi, y)
      coeff = coeff/denom

      call lalg_axpy(mesh%np, st%d%dim, coeff, psi, x)

    end if
  end do

  SAFE_DEALLOCATE_A(psi)

  ! calculate the residual: resid = [H (+ Q) + shift] x - y
  SAFE_ALLOCATE(resid(1:mesh%np, 1:st%d%dim))

  call X(linear_solver_operator)(hm, namespace, mesh, st, ist, ik, shift, x, resid)
  call lalg_axpy(mesh%np, st%d%dim, -M_ONE, y, resid)

  residue = X(mf_nrm2)(mesh, st%d%dim, resid)

  ! a single pass, no iteration
  iter_used = 1

  SAFE_DEALLOCATE_A(resid)
  POP_SUB(X(linear_solver_sos))

end subroutine X(linear_solver_sos)

! ---------------------------------------------------------
!> for complex symmetric matrices
!! W Chen and B Poirier, J Comput Phys 219, 198-209 (2006)
subroutine X(linear_solver_qmr_dotp)(this, namespace, hm, mesh, st, xb, bb, shift, iter_used, residue, threshold, &
  use_initial_guess)
  type(linear_solver_t),    intent(inout) :: this
  type(namespace_t),        intent(in)    :: namespace
  type(hamiltonian_elec_t), intent(in)    :: hm
  type(mesh_t),             intent(in)    :: mesh
  type(states_elec_t),      intent(in)    :: st
  type(wfs_elec_t),         intent(inout) :: xb
  type(wfs_elec_t),         intent(in)    :: bb
  R_TYPE,                   intent(in)    :: shift(:)
  integer,                  intent(out)   :: iter_used(:)
  FLOAT,                    intent(out)   :: residue(:)   !< the residue = abs(Ax-b)
  FLOAT,                    intent(in)    :: threshold    !< convergence threshold
  logical, optional,        intent(in)    :: use_initial_guess

  type(wfs_elec_t) :: vvb, res, zzb, qqb, ppb, deltax, deltar
  integer             :: ii, iter
  FLOAT, allocatable  :: rho(:), oldrho(:), norm_b(:), xsi(:), gamma(:), alpha(:), theta(:), oldtheta(:), saved_res(:)
  FLOAT, allocatable  :: oldgamma(:)
  R_TYPE, allocatable :: eta(:), beta(:), delta(:), eps(:), exception_saved(:, :, :)
  integer, allocatable :: status(:), saved_iter(:)

  integer, parameter ::        &
    QMR_NOT_CONVERGED    = 0,  &
    QMR_CONVERGED        = 1,  &
    QMR_RES_ZERO         = 2,  &
    QMR_B_ZERO           = 3,  &
    QMR_BREAKDOWN_PB     = 4,  &
    QMR_BREAKDOWN_VZ     = 5,  &
    QMR_BREAKDOWN_QP     = 6,  &
    QMR_BREAKDOWN_GAMMA  = 7

  PUSH_SUB(X(linear_solver_qmr_dotp))

  SAFE_ALLOCATE(rho(1:xb%nst))
  SAFE_ALLOCATE(oldrho(1:xb%nst))
  SAFE_ALLOCATE(norm_b(1:xb%nst))
  SAFE_ALLOCATE(xsi(1:xb%nst))
  SAFE_ALLOCATE(gamma(1:xb%nst))
  SAFE_ALLOCATE(oldgamma(1:xb%nst))
  SAFE_ALLOCATE(alpha(1:xb%nst))
  SAFE_ALLOCATE(eta(1:xb%nst))
  SAFE_ALLOCATE(theta(1:xb%nst))
  SAFE_ALLOCATE(oldtheta(1:xb%nst))
  SAFE_ALLOCATE(beta(1:xb%nst))
  SAFE_ALLOCATE(delta(1:xb%nst))
  SAFE_ALLOCATE(eps(1:xb%nst))
  SAFE_ALLOCATE(saved_res(1:xb%nst))

  SAFE_ALLOCATE(status(1:xb%nst))
  SAFE_ALLOCATE(saved_iter(1:xb%nst))

  SAFE_ALLOCATE(exception_saved(1:mesh%np, 1:st%d%dim, 1:xb%nst))

  call xb%copy_to(vvb)
  call xb%copy_to(res)
  call xb%copy_to(zzb)
  call xb%copy_to(qqb)
  call xb%copy_to(ppb)
  call xb%copy_to(deltax)
  call xb%copy_to(deltar)

  if (optional_default(use_initial_guess, .true.)) then
    ! Compared to the original algorithm, we assume that we have an initial guess
    ! This means that instead of setting x^(0)=0, we have x^(0)=xb
    ! TODO: We should implement here the proper recursion for x^(0) /= 0
    ! as published in "Preconditioning of Symmetric, but Highly Indefinite Linear Systems"
    ! R. W. Freund, 15th IMACS World Congress on Scientific Computation, Modelling and Applied Mathematics
    ! 2, 551 (1997)
    call X(linear_solver_operator_batch)(hm, namespace, mesh, st, shift, xb, vvb)

    call batch_xpay(mesh%np, bb, CNST(-1.0), vvb)
    call vvb%copy_data_to(mesh%np, res)

    ! Norm of the right-hand side
    call mesh_batch_nrm2(mesh, vvb, rho)
    call mesh_batch_nrm2(mesh, bb, norm_b)

    do ii = 1, xb%nst
      ! If the initial guess is a good enough solution
      if (abs(rho(ii)) <= threshold) then
        status(ii) = QMR_RES_ZERO
        residue(ii) = rho(ii)
        call batch_get_state(xb, ii, mesh%np, exception_saved(:, :, ii))
        saved_iter(ii) = 0
        saved_res(ii) = residue(ii)
      end if
    end do

  else ! If we don't know any guess, let's stick to the original algorithm

    call batch_set_zero(xb)
    call bb%copy_data_to(mesh%np, vvb)
    call vvb%copy_data_to(mesh%np, res)
    call mesh_batch_nrm2(mesh, bb, norm_b)
    rho = norm_b

  end if

  status = QMR_NOT_CONVERGED

  iter = 0

  do ii = 1, xb%nst

    residue(ii) = rho(ii)
    ! if b is zero, the solution is trivial
    if (status(ii) == QMR_NOT_CONVERGED .and. abs(norm_b(ii)) <= M_EPSILON) then
      exception_saved = M_ZERO
      status(ii) = QMR_B_ZERO
      residue(ii) = norm_b(ii)
      saved_iter(ii) = iter
      saved_res(ii) = residue(ii)
    end if

  end do

  ! We compute z = Pb and compute its norm \xsi^(1)
  call X(preconditioner_apply_batch)(this%pre, namespace, mesh, hm, vvb, zzb, 1, omega = shift)
  call mesh_batch_nrm2(mesh, zzb, xsi)

  gamma = M_ONE
  oldgamma = gamma
  eta   = CNST(-1.0)
  alpha = M_ONE
  theta = M_ZERO

  do while(iter < this%max_iter)
    iter = iter + 1

    ! Exit condition
    if (all(status /= QMR_NOT_CONVERGED)) exit

    ! Failure of the algorithm
    do ii = 1, xb%nst
      if (status(ii) == QMR_NOT_CONVERGED .and. (abs(rho(ii)) < M_EPSILON .or. abs(xsi(ii)) < M_EPSILON)) then
        call batch_get_state(xb, ii, mesh%np, exception_saved(:, :, ii))
        status(ii) = QMR_BREAKDOWN_PB
        saved_iter(ii) = iter
        saved_res(ii) = residue(ii)
      end if

      alpha(ii) = alpha(ii)*xsi(ii)/rho(ii)
    end do

    ! v^(i) = v^(i) / \rho_i
    call batch_scal(mesh%np, M_ONE/rho, vvb, a_full = .false.)
    ! z^(i) = z^(i) / \xsi_i
    call batch_scal(mesh%np, M_ONE/xsi, zzb, a_full = .false.)
    ! \delta_i = v^(i)\ldotp z^(i)
    call X(mesh_batch_dotp_vector)(mesh, vvb, zzb, delta)

    !If \delta_i = 0, method fails
    do ii = 1, xb%nst
      if (status(ii) == QMR_NOT_CONVERGED .and. abs(delta(ii)) < M_EPSILON) then
        call batch_get_state(xb, ii, mesh%np, exception_saved(:, :, ii))
        status(ii) = QMR_BREAKDOWN_VZ
        saved_iter(ii) = iter
        saved_res(ii) = residue(ii)
      end if
    end do

    if (iter == 1) then
      ! q^(1) = z^(1)
      call zzb%copy_data_to(mesh%np, qqb)
    else
      ! q^(i) = z^(i) - (\rho_i\delta_i)/(\eps_{i-1})q^(i-1)
      call batch_xpay(mesh%np, zzb, -rho*delta/eps, qqb, a_full = .false.)
    end if

    ! ppb = (H-shift)*qqb
    call X(linear_solver_operator_batch)(hm, namespace, mesh, st, shift, qqb, ppb)
    ! p^(i) = \alpha_{i+1} (H-shift)*q^(i)
    call batch_scal(mesh%np, alpha, ppb, a_full = .false.)

    ! \eps_i = q^{(i)}\ldotp p^{(i)}
    call X(mesh_batch_dotp_vector)(mesh, qqb, ppb, eps)

    ! If \eps_i == 0, method fails
    do ii = 1, xb%nst
      if (status(ii) == QMR_NOT_CONVERGED .and. abs(eps(ii)) < M_EPSILON) then
        call batch_get_state(xb, ii, mesh%np, exception_saved(:, :, ii))
        status(ii) = QMR_BREAKDOWN_QP
        saved_iter(ii) = iter
        saved_res(ii) = residue(ii)
      end if

      beta(ii) = eps(ii)/delta(ii)
    end do

    ! v^(i+1) = p^(i) - \beta_i v^(i)
    call batch_xpay(mesh%np, ppb, -beta, vvb, a_full = .false.)

    do ii = 1, xb%nst
      oldrho(ii) = rho(ii)
    end do

    ! \rho_{i+1} = ||v^{i+1}||_2
    call mesh_batch_nrm2(mesh, vvb, rho)

    ! z^{i+1} = P v^{i+1}
    call X(preconditioner_apply_batch)(this%pre, namespace, mesh, hm, vvb, zzb, 1, omega = shift)
    ! z^{i+1} = P v^{i+1}/ \alpha^{i+1}
    call batch_scal(mesh%np, M_ONE/alpha, zzb, a_full = .false.)

    ! \xsi_{i+1} = ||z^{i+1}||_2
    call mesh_batch_nrm2(mesh, zzb, xsi)

    do ii = 1, xb%nst

      oldtheta(ii) = theta(ii)
      ! \theta_i = \rho_{i+1}/(\gamma_{i-1} |\beta_i|)
      theta(ii) = rho(ii)/(gamma(ii)*abs(beta(ii)))

      oldgamma(ii) = gamma(ii)
      ! \gamma_i = 1/sqrt(1+\theta_i^2)
      gamma(ii) = M_ONE/sqrt(M_ONE + theta(ii)**2)

      ! If \gamma_i == 0, method fails
      if (status(ii) == QMR_NOT_CONVERGED .and. abs(gamma(ii)) < M_EPSILON) then
        call batch_get_state(xb, ii, mesh%np, exception_saved(:, :, ii))
        status(ii) = QMR_BREAKDOWN_GAMMA
        saved_iter(ii) = iter
        saved_res(ii) = residue(ii)
      end if

      ! \eta_i = -\eta_{i-1}\rho_i \gamma_i^2/ (\beta_i \gamma_{i-1}^2)
      eta(ii) = -eta(ii)*oldrho(ii)*gamma(ii)**2/(beta(ii)*oldgamma(ii)**2)
    end do

    if (iter == 1) then

      ! \delta_x^(1) = \eta_1 \alpha_2 q^{(1)}
      call qqb%copy_data_to(mesh%np, deltax)
      call batch_scal(mesh%np, eta*alpha, deltax, a_full = .false.)

      ! \delta_r^(1) = \eta_1 p^1
      call ppb%copy_data_to(mesh%np, deltar)
      call batch_scal(mesh%np, eta, deltar, a_full = .false.)

    else

      ! \delta_x^{i} = (\theta_{i-1}\gamma_i)^2 \delta_x^{i-1} + \eta_i\alpha_{i+1} q^i
      call batch_scal(mesh%np, (oldtheta*gamma)**2, deltax, a_full = .false.)
      call batch_axpy(mesh%np, eta*alpha, qqb, deltax, a_full = .false.)

      ! \delta_r^{i} = (\theta_{i-1}\gamma_i)^2 \delta_r^{i-1} + \eta_i p^i
      call batch_scal(mesh%np, (oldtheta*gamma)**2, deltar, a_full = .false.)
      call batch_axpy(mesh%np, eta, ppb, deltar, a_full = .false.)

    end if

    ! FIXME: if the states are converged, why changing them here
    ! x^{i} = x^{i-1} + \delta_x^{i}
    call batch_axpy(mesh%np, M_ONE, deltax, xb)
    ! r^{i} = r^{i-1} - \delta_r^i
    ! This is given by r^{i} = b - Hx^{i}
    call batch_axpy(mesh%np, CNST(-1.0), deltar, res)

    ! We compute the norm of the residual
    call mesh_batch_nrm2(mesh, res, residue)
    do ii = 1, xb%nst
      residue(ii) = residue(ii)/norm_b(ii)
    end do

    ! Convergence is reached once the residues are smaller than the threshold
    do ii = 1, xb%nst
      if (status(ii) == QMR_NOT_CONVERGED .and. residue(ii) < threshold) then
        status(ii) = QMR_CONVERGED
        if (debug%info) then
          write(message(1),*) 'Debug: State ', xb%ist(ii), ' converged in ', iter, ' iterations.'
          call messages_info(1, namespace=namespace)
        end if
      end if
    end do

  end do

  do ii = 1, xb%nst
    if (status(ii) == QMR_NOT_CONVERGED .or. status(ii) == QMR_CONVERGED) then
      ! We stop at the entrance of the next iteraction, so we substract one here
      iter_used(ii) = iter -1
    else
      call batch_set_state(xb, ii, mesh%np, exception_saved(:, :, ii))
      iter_used(ii) = saved_iter(ii)
      residue(ii) = saved_res(ii)
    end if

    select case (status(ii))
    case (QMR_NOT_CONVERGED)
      write(message(1), '(a)') "QMR solver not converged!"
      write(message(2), '(a)') "Try increasing the maximum number of iterations or the tolerance."
      call messages_warning(2, namespace=namespace)
    case (QMR_BREAKDOWN_PB)
      write(message(1), '(a)') "QMR breakdown, cannot continue: b or P*b is the zero vector!"
      call messages_warning(1, namespace=namespace)
    case (QMR_BREAKDOWN_VZ)
      write(message(1), '(a)') "QMR breakdown, cannot continue: v^T*z is zero!"
      call messages_warning(1, namespace=namespace)
    case (QMR_BREAKDOWN_QP)
      write(message(1), '(a)') "QMR breakdown, cannot continue: q^T*p is zero!"
      call messages_warning(1, namespace=namespace)
    case (QMR_BREAKDOWN_GAMMA)
      write(message(1), '(a)') "QMR breakdown, cannot continue: gamma is zero!"
      call messages_warning(1, namespace=namespace)
    end select

  end do

  call vvb%end()
  call res%end()
  call zzb%end()
  call qqb%end()
  call ppb%end()
  call deltax%end()
  call deltar%end()

  SAFE_DEALLOCATE_A(exception_saved)
  SAFE_DEALLOCATE_A(rho)
  SAFE_DEALLOCATE_A(oldrho)
  SAFE_DEALLOCATE_A(norm_b)
  SAFE_DEALLOCATE_A(xsi)
  SAFE_DEALLOCATE_A(gamma)
  SAFE_DEALLOCATE_A(oldgamma)
  SAFE_DEALLOCATE_A(alpha)
  SAFE_DEALLOCATE_A(eta)
  SAFE_DEALLOCATE_A(theta)
  SAFE_DEALLOCATE_A(oldtheta)
  SAFE_DEALLOCATE_A(beta)
  SAFE_DEALLOCATE_A(delta)
  SAFE_DEALLOCATE_A(eps)

  SAFE_DEALLOCATE_A(status)
  SAFE_DEALLOCATE_A(saved_res)
  SAFE_DEALLOCATE_A(saved_iter)

  POP_SUB(X(linear_solver_qmr_dotp))
end subroutine X(linear_solver_qmr_dotp)


!> Packs a two-dimensional array into a one-dimensional vector of length n.
!!
!! The first args%mesh%np rows of each of the args%st%d%dim columns of a are
!! stored one column after the other in the result.
!! The callers visible in this file pass n = np*dim; elements of the result
!! beyond np*dim (if any) are not assigned here.
!! NOTE(review): args is not declared in this function; presumably it is a
!! module-level variable filled in by the solver entry point - confirm.
function X(singledimarray)(n, a)
  integer, intent(in) :: n
  R_TYPE, intent(in) :: a(:, :)
  R_TYPE :: X(singledimarray)(n)

  integer :: icomp, ncomp, npts, offset

  npts  = args%mesh%np
  ncomp = args%st%d%dim

  ! offset is the number of entries already written to the packed vector
  offset = 0
  do icomp = 1, ncomp
    X(singledimarray)(offset + 1:offset + npts) = a(1:npts, icomp)
    offset = offset + npts
  end do
end function X(singledimarray)

!> Unpacks a one-dimensional vector into a two-dimensional (np, dim) array.
!!
!! Column idim of the result receives the contiguous slice
!! a((idim-1)*np+1 : idim*np), i.e. a is read as dim consecutive chunks of
!! np elements each.
function X(doubledimarray)(np, dim, a)
  integer, intent(in) :: np, dim
  R_TYPE, intent(in) :: a(:)
  R_TYPE :: X(doubledimarray)(np, dim)

  integer :: icomp, first, last

  ! [first, last] is the range of a that maps onto the current column
  last = 0
  do icomp = 1, dim
    first = last + 1
    last  = last + np
    X(doubledimarray)(1:np, icomp) = a(first:last)
  end do
end function X(doubledimarray)

!> Dot product of two packed vectors.
!!
!! Both a and b are unpacked to (args%mesh%np, args%st%d%dim) arrays with
!! X(doubledimarray) and handed to X(mf_dotp) together with args%mesh and the
!! number of components.
!! NOTE(review): X(mf_dotp) is defined elsewhere; presumably it is the
!! mesh-function dot product - confirm.
R_TYPE function X(dotproduct)(a, b)
  R_TYPE, intent(in) :: a(:), b(:)

  integer :: npts, ncomp

  npts  = args%mesh%np
  ncomp = args%st%d%dim

  X(dotproduct) = X(mf_dotp)(args%mesh, ncomp, &
    X(doubledimarray)(npts, ncomp, a), &
    X(doubledimarray)(npts, ncomp, b))
end function X(dotproduct)

!> Applies X(linear_solver_operator) to a packed vector.
!!
!! Only the first column v(:, 1) is read. It is unpacked into a work array
!! with args%mesh%np_part rows (the rows beyond args%mesh%np are zero), the
!! operator is applied using the context stored in args, and the outcome is
!! packed back into the first np*dim entries of the first column of the result.
!! Any other element of the result is left unassigned.
!! NOTE(review): args is not declared in this function; presumably it is a
!! module-level variable filled in by the solver entry point - confirm.
function X(matrixvector)(v)
  R_TYPE, intent(in)       :: v(:, :)
  R_TYPE                   :: X(matrixvector)(size(v, 1), size(v, 2))

  integer :: np, np_part, dim, ntot
  R_TYPE, allocatable :: psi(:, :)   ! unpacked input, np_part rows
  R_TYPE, allocatable :: hpsi(:, :)  ! operator applied to psi, np rows

  dim = args%st%d%dim
  np = args%mesh%np
  np_part = args%mesh%np_part
  ! Length of the packed representation
  ntot = np*dim

  SAFE_ALLOCATE(psi(1:np_part, 1:dim))
  SAFE_ALLOCATE(hpsi(1:np, 1:dim))

  ! Zero everything first so that the rows np+1:np_part are defined
  psi = R_TOTYPE(M_ZERO)
  psi(1:np, 1:dim) = X(doubledimarray)(np, dim, v(:, 1))

  call X(linear_solver_operator) (args%hm, args%namespace, args%mesh, args%st, &
    args%ist, args%ik, args%X(shift), psi, hpsi)

  X(matrixvector)(1:ntot, 1) = X(singledimarray)(ntot, hpsi)

  SAFE_DEALLOCATE_A(hpsi)
  SAFE_DEALLOCATE_A(psi)
end function X(matrixvector)

!> Applies the preconditioner args%ls%pre to a packed vector.
!!
!! Only the first column v(:, 1) is read. It is unpacked into a work array
!! with args%mesh%np_part rows (the rows beyond args%mesh%np are zero),
!! X(preconditioner_apply) is called on it, and the outcome is packed back
!! into the first np*dim entries of the first column of the result.
!! Any other element of the result is left unassigned.
!! NOTE(review): args is not declared in this function; presumably it is a
!! module-level variable filled in by the solver entry point - confirm.
function X(preconditioner)(v)
  R_TYPE, dimension(:, :), intent(in)     :: v
  R_TYPE, dimension(size(v, 1), size(v, 2)) :: X(preconditioner)

  integer :: np, np_part, dim, ntot
  R_TYPE, allocatable :: psi(:, :)   ! unpacked input, np_part rows
  R_TYPE, allocatable :: ppsi(:, :)  ! preconditioned psi, np rows

  dim = args%st%d%dim
  np = args%mesh%np
  np_part = args%mesh%np_part
  ! Length of the packed representation
  ntot = np*dim

  SAFE_ALLOCATE(psi(1:np_part, 1:dim))
  SAFE_ALLOCATE(ppsi(1:np, 1:dim))

  ! Zero everything first so that the rows np+1:np_part are defined
  psi = R_TOTYPE(M_ZERO)
  psi(1:np, 1:dim) = X(doubledimarray)(np, dim, v(:, 1))

  call X(preconditioner_apply)(args%ls%pre, args%namespace, args%mesh, args%hm, &
    psi, ppsi, 1, args%X(shift))

  X(preconditioner)(1:ntot, 1) = X(singledimarray)(ntot, ppsi)

  SAFE_DEALLOCATE_A(ppsi)
  SAFE_DEALLOCATE_A(psi)
end function X(preconditioner)


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
