program lsqmt
use printmat2
implicit none
!
! As lsqr but with support for multiple traits
!
! Outline of what this program does:
!   1. reads the ntrait x ntrait residual covariance matrix R from file
!      'r_matrix' (one row per line),
!   2. reads records from file 'data_pr3'; each record holds neff*ntrait
!      integer effect codes (all effects for trait 1, then all effects for
!      trait 2, ...; see function datum) followed by ntrait observations,
!   3. accumulates the left hand side xx and right hand side xy, weighting
!      each record by the inverse of R computed in find_rinv,
!   4. solves the system by Gauss-Seidel and writes the solutions to file
!      'solutions' and to standard output.
!
integer,parameter::effcross=0,& !effects can be cross-classified
effcov=1 !or covariables
real, allocatable:: xx(:,:),xy(:),sol(:) !storage for the equations
integer, allocatable:: indata(:) !storage for one line of effects

integer,parameter:: neff=2,& !number of effects
nlev(2)=(/2,3/),& !number of levels
ntrait=2,& !number of traits
miss=0 !value of missing trait/effect
real :: r(ntrait,ntrait),& !residual covariance matrix
rinv(ntrait,ntrait) ! and its inverse
integer :: effecttype(neff)=(/effcross, effcross/)
integer :: nestedcov(neff) =(/0,0/)
real :: weight_cov(neff,ntrait) ! weight of each effect/trait in the current record

real :: y(ntrait) ! observation value
integer :: neq,io,i,j,k,l ! number of equations and io-status
integer :: nrec ! number of data records processed
integer,allocatable:: address(:,:) ! equation number of each effect/trait (0 = missing)
open(98,file='solutions')
! status='old' so that a missing input file is reported instead of being
! silently created empty
open(99,file='r_matrix',status='old',iostat=io)
if (io.ne.0) then
   print*,'cannot open file r_matrix'
   stop
endif
!
!r= reshape((/1,2,2,5/),(/2,2/)) !r={1 2;2 5}
do i=1,ntrait
   read(99,*,iostat=io) r(i,:)
   if (io.ne.0) then
      print*,'cannot read row ',i,' of r_matrix'
      stop
   endif
enddo
close(99)
!r= reshape((/1,0,0,2/),(/2,2/)) !r={1 0;0 2}
print*,'R matrix'
call printnice(r,'(10f12.6)')

neq=ntrait*sum(nlev)
print*,'# effects =',neff
print*,'# equations =',neq
allocate (xx(neq,neq), xy(neq), sol(neq),indata(neff*ntrait),&
address(neff,ntrait))
xx=0; xy=0; sol=0
!
open(1,file='data_pr3',status='old',iostat=io)
if (io.ne.0) then
   print*,'cannot open file data_pr3'
   stop
endif
!
nrec=0
do
   read(1,*,iostat=io)indata,y
   if (io.ne.0) exit
   nrec=nrec+1
   call find_addresses
   call find_rinv
   do i=1,neff
      do k=1,ntrait
         ! address 0 marks a missing effect; xx and xy start at index 1,
         ! so such combinations must be skipped rather than indexed
         if (address(i,k) == 0) cycle
         do j=1,neff
            do l=1,ntrait
               if (address(j,l) == 0) cycle
               xx(address(i,k),address(j,l))=xx(address(i,k),address(j,l))+&
               weight_cov(i,k)*weight_cov(j,l)*rinv(k,l)
            enddo
         enddo
         do l=1,ntrait
            xy(address(i,k))=xy(address(i,k))+rinv(k,l)*y(l)*weight_cov(i,k)
         enddo
      enddo
   enddo
enddo
! io<0 is a normal end of file; io>0 means the record could not be read
if (io.gt.0) print*,'warning: read error in data_pr3 after record ',nrec
close(1)
!
! rinv holds the matrix used for the last record; it is undefined when no
! record was read
if (nrec.gt.0) call printnice(rinv,'(10f12.4)')

print*,'left hand side'
do i=1,neq
   print '(100f5.1)',xx(i,:)
enddo
!
print '( '' right hand side:'' ,100f6.1)',xy
!
call solve_dense_gs(neq,xx,xy,sol) !solution by Gauss-Seidel
do i=1,neq
   write(98,'(100f7.3)') sol(i)
enddo
close(98)
print '( '' solution:'' ,100f7.3)',sol
contains

subroutine find_addresses
! Fills address(i,j) and weight_cov(i,j) for every effect i and trait j of
! the record currently held in indata (accessed through datum).
!
! Equations are numbered effect by effect, level by level within an effect
! and trait by trait within a level, so the equation of effect i, level m,
! trait j is sum(nlev(1:i-1))*ntrait + (m-1)*ntrait + j.
!
! A value equal to miss gives address 0 and weight 0. Level codes outside
! 1..nlev(i) stop the program, because they would otherwise point into the
! equations of another effect or beyond the last equation.
integer :: i,j,baseaddr,level
do i=1,neff
   do j=1,ntrait
      if (datum(i,j) == miss) then !missing effect
         address(i,j)=0 !dummy address
         weight_cov(i,j)=0.0
         cycle
      endif
      baseaddr=sum(nlev(1:i-1))*ntrait+j !base address (start)
      select case (effecttype(i))
      case (effcross)
         level=nint(datum(i,j))
         if (level < 1 .or. level > nlev(i)) then
            print*,'level ',level,' out of range for effect ',i
            stop
         endif
         address(i,j)=baseaddr+(level-1)*ntrait
         weight_cov(i,j)=1.0
      case (effcov)
         weight_cov(i,j)=datum(i,j)
         if (nestedcov(i) == 0) then
            address(i,j)=baseaddr
         elseif (nestedcov(i)>0 .and. nestedcov(i)<=neff &
                 .and. nestedcov(i)/=i) then
            ! any other effect, including the last one, may be the nesting
            ! effect; its level code selects the regression equation
            level=nint(datum(nestedcov(i),j))
            if (level < 1 .or. level > nlev(i)) then
               print*,'level ',level,' out of range for nested covariable ',i
               stop
            endif
            address(i,j)=baseaddr+(level-1)*ntrait
         else
            print*,'wrong description of nested covariable'
            stop
         endif
      case default
         print*,'unimplemented effect ',i
         stop
      end select
   enddo
enddo
end subroutine

function datum(ef,tr) result(val)
! Returns, as a real, the value stored in indata for effect ef and trait tr.
! indata holds the neff values of trait 1 first, then those of trait 2, etc.
integer :: ef,tr
real :: val
integer :: pos
pos=(tr-1)*neff+ef
val=real(indata(pos))
end function datum

subroutine find_rinv
! Computes rinv = inv(Q R Q) for the current record, where Q is an identity
! matrix with zeroed diagonal elements for the traits whose observation
! y(t) equals miss.
! NOTE(review): ginv is defined outside this file; it is presumably a
! generalized inverse that overwrites its first argument and returns the
! rank in its last argument - confirm against its source.
integer :: t,rnk
logical :: absent(ntrait)
absent=(y == miss)
rinv=r
do t=1,ntrait
   if (.not. absent(t)) cycle
   ! remove the row and column of a missing trait before inversion
   rinv(t,:)=0
   rinv(:,t)=0
enddo
call ginv(rinv,ntrait,1e-5,rnk)

end subroutine find_rinv

end program lsqmt

subroutine solve_dense_gs(n,lhs,rhs,sol)
! finds sol in the system of linear equations: lhs*sol=rhs
! the solution is iterative by Gauss-Seidel
!
! n   - number of equations
! lhs - dense n x n coefficient matrix; equations with a zero diagonal
!       element are skipped and keep their starting value
! rhs - right hand side
! sol - starting values on entry, solutions on exit
!
! After every round the round number and the sum of squared changes (eps)
! are printed. Iteration ends when eps drops below conv, when eps is no
! longer finite (divergence), or after maxround rounds, so the routine
! always returns.
implicit none
integer, intent(in) :: n
real, intent(in) :: lhs(n,n),rhs(n)
real, intent(inout) :: sol(n)
integer, parameter :: maxround=10000 ! upper limit on rounds of iteration
real, parameter :: conv=1e-10 ! convergence threshold for eps
real :: eps,solnew
integer :: round,i
logical :: converged
!
round=0
converged=.false.
do while (round.lt.maxround)
   eps=0; round=round+1
   do i=1,n
      if (lhs(i,i).eq.0) cycle
      solnew=sol(i)+(rhs(i)-sum(lhs(i,:)*sol))/lhs(i,i)
      eps=eps+ (sol(i)-solnew)**2
      sol(i)=solnew
   end do
   print*,round,eps
   if (eps.lt.conv) then
      converged=.true.
      exit
   endif
   ! the comparison is false for both Inf and NaN, i.e. the iteration diverged
   if (.not.(eps.le.huge(eps))) exit
end do
if (converged) then
   print*,'solutions computed in ',round,' rounds of iteration'
else
   print*,'no convergence after ',round,' rounds of iteration'
endif
end subroutine solve_dense_gs