Пример #1
0
// Blocked distributed triangular solve for the Right / Lower /
// (conjugate-)Transposed case.
//
// X is first scaled by alpha and then processed one column panel at a time:
// the current panel X1 is solved against the diagonal block L11 (via
// LocalTrsm with RIGHT, LOWER, orientation), and the remaining columns are
// updated as X2 -= X1 L21^(T/H). Presumably this leaves
// X = alpha X inv(L)^(T/H) on exit -- confirm against the public Trsm
// dispatcher, which is not visible in this excerpt.
//
// NOTE(review): `F` is used as the scalar type but no template header is
// visible here; it was presumably lost when this excerpt was extracted.
//
// Parameters:
//   orientation     - forwarded to LocalTrsm; ADJOINT makes the L21 panel be
//                     formed with AdjointFrom, any other value with
//                     TransposeFrom. NORMAL is rejected with LogicError
//                     unless RELEASE is defined.
//   diag            - forwarded unchanged to LocalTrsm.
//   alpha           - scalar applied to all of X before the sweep starts.
//   L               - triangular operand; accessed only through locked views,
//                     and only its L11 and L21 blocks are read.
//   X               - right-hand sides; overwritten in place.
//   checkIfSingular - forwarded unchanged to LocalTrsm.
inline void
TrsmRLT
( Orientation orientation, UnitOrNonUnit diag,
  F alpha, const DistMatrix<F>& L, DistMatrix<F>& X,
  bool checkIfSingular )
{
#ifndef RELEASE
    // Debug-only bookkeeping and argument validation.
    CallStackEntry entry("internal::TrsmRLT");
    if( orientation == NORMAL )
        LogicError("TrsmRLT expects a (Conjugate)Transpose option");
#endif
    const Grid& g = L.Grid();

    // Matrix views
    // (2x2 and 3x3 partitions of L, all created on L's grid)
    DistMatrix<F> 
        LTL(g), LTR(g),  L00(g), L01(g), L02(g),
        LBL(g), LBR(g),  L10(g), L11(g), L12(g),
                         L20(g), L21(g), L22(g);

    // Column partitions of X: XL|XR and X0|X1|X2.
    DistMatrix<F> XL(g), XR(g),
                  X0(g), X1(g), X2(g);

    // Temporary distributions
    // (the suffix of each name mirrors the distribution pair in its type)
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<F,VR,  STAR> L21_VR_STAR(g);
    DistMatrix<F,STAR,MR  > L21AdjOrTrans_STAR_MR(g);
    DistMatrix<F,VC,  STAR> X1_VC_STAR(g);
    DistMatrix<F,STAR,MC  > X1Trans_STAR_MC(g);

    // Start the algorithm
    // alpha is applied once up front, so every later kernel uses F(1)/F(-1).
    Scale( alpha, X );
    // Both partitions start with an empty top-left / left part (size 0), so
    // the sweep runs from the top-left of L and the left edge of X.
    LockedPartitionDownDiagonal
    ( L, LTL, LTR,
         LBL, LBR, 0 );
    PartitionRight( X, XL, XR, 0 );
    // Iterate until no unprocessed columns of X remain.
    while( XR.Width() > 0 )
    {
        // Expose the current diagonal block L11 and the panel L21 below it.
        LockedRepartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, /**/ L01, L02,
         /*************/ /******************/
               /**/       L10, /**/ L11, L12,
          LBL, /**/ LBR,  L20, /**/ L21, L22 );

        // Expose the current column panel X1 and the trailing columns X2.
        RepartitionRight
        ( XL, /**/     XR,
          X0, /**/ X1, X2 );

        // Align every temporary with X2, the target of the local update
        // below (presumably so LocalGemm needs no further communication).
        X1_VC_STAR.AlignWith( X2 );
        X1Trans_STAR_MC.AlignWith( X2 );
        L21_VR_STAR.AlignWith( X2 );
        L21AdjOrTrans_STAR_MR.AlignWith( X2 );
        //--------------------------------------------------------------------//
        // Redistribute the diagonal block and the current panel into the
        // layouts consumed by LocalTrsm.
        L11_STAR_STAR = L11; 
        X1_VC_STAR = X1;  
        
        // Solve the panel against L11 from the right; the result stays in
        // X1_VC_STAR.
        LocalTrsm
        ( RIGHT, LOWER, orientation, diag, 
          F(1), L11_STAR_STAR, X1_VC_STAR, checkIfSingular );

        // Keep the transposed solution in [STAR,MC] for the update below,
        // and transpose it once more to store the solved panel back into X1.
        X1Trans_STAR_MC.TransposeFrom( X1_VC_STAR );
        X1.TransposeFrom( X1Trans_STAR_MC );
        // Form L21^T or L21^H in [STAR,MR], depending on orientation.
        L21_VR_STAR = L21;
        if( orientation == ADJOINT )
            L21AdjOrTrans_STAR_MR.AdjointFrom( L21_VR_STAR );
        else
            L21AdjOrTrans_STAR_MR.TransposeFrom( L21_VR_STAR );

        // X2[MC,MR] -= X1[MC,*] (L21[MR,*])^(T/H)
        //            = X1^T[* ,MC] (L21^(T/H))[*,MR]
        // The TRANSPOSE flag undoes the stored transpose of X1.
        LocalGemm
        ( TRANSPOSE, NORMAL, 
          F(-1), X1Trans_STAR_MC, L21AdjOrTrans_STAR_MR, F(1), X2 );
        //--------------------------------------------------------------------//

        // Advance both partitions past the block that was just processed.
        SlideLockedPartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, L01, /**/ L02,
               /**/       L10, L11, /**/ L12,
         /*************/ /******************/
          LBL, /**/ LBR,  L20, L21, /**/ L22 );

        SlidePartitionRight
        ( XL,     /**/ XR,
          X0, X1, /**/ X2 );
    }
}
Пример #2
0
// Overwrites the distributed matrix A with its inverse.
//
// Steps, as laid out in the body:
//   1. LU( A, p ) factors A in place and fills the pivot vector p.
//   2. TriangularInverse( UPPER, NON_UNIT, A ) inverts the upper triangle
//      in place.
//   3. A blocked sweep from the bottom-right corner towards the top-left
//      solves inv(A) L = inv(U) for inv(A), one column panel at a time.
//   4. ApplyInverseColumnPivots( A, p ) applies the permutation
//      (inv(A) := inv(A) P).
//
// NOTE(review): `F` is used as the scalar type but no template header is
// visible here; it was presumably lost when this excerpt was extracted.
//
// Parameters:
//   A - matrix to invert; overwritten in place. Unless RELEASE is defined,
//       a non-square A raises std::logic_error.
inline void
Inverse( DistMatrix<F>& A )
{
#ifndef RELEASE
    // Debug-only call-stack tracking; paired with PopCallStack() at the end.
    PushCallStack("Inverse");
    if( A.Height() != A.Width() )
        throw std::logic_error("Cannot invert non-square matrices");
#endif
    const Grid& g = A.Grid();
    // Pivot vector produced by LU and consumed by ApplyInverseColumnPivots.
    DistMatrix<int,VC,STAR> p( g );
    LU( A, p );
    TriangularInverse( UPPER, NON_UNIT, A );

    // Solve inv(A) L = inv(U) for inv(A)
    // Views: 2x2 and 3x3 partitions of A, plus full-height column panels
    // A1 (current) and A2 (already processed, to the right of A1).
    DistMatrix<F> ATL(g), ATR(g),
               ABL(g), ABR(g);
    DistMatrix<F> A00(g), A01(g), A02(g),
               A10(g), A11(g), A12(g),
               A20(g), A21(g), A22(g);
    DistMatrix<F> A1(g), A2(g);
    // Temporaries; the suffix of each name mirrors its distribution pair.
    DistMatrix<F,VC,  STAR> A1_VC_STAR(g);
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<F,VR,  STAR> L21_VR_STAR(g);
    DistMatrix<F,STAR,MR  > L21Trans_STAR_MR(g);
    DistMatrix<F,MC,  STAR> Z1(g);
    // Start with an empty bottom-right block and grow it until it covers A.
    PartitionUpDiagonal
    ( A, ATL, ATR,
      ABL, ABR, 0 );
    while( ABR.Height() < A.Height() )
    {
        RepartitionUpDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
          /**/       A10, A11, /**/ A12,
          /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        // A1: all rows of the current block column (the columns of A01);
        // A2: all rows of the columns to its right (the columns of A02).
        // The numeric arguments are presumably (row offset, column offset,
        // height, width) -- TODO confirm against View's declaration.
        View( A1, A, 0, A00.Width(),             A.Height(), A01.Width() );
        View( A2, A, 0, A00.Width()+A01.Width(), A.Height(), A02.Width() );

        L21_VR_STAR.AlignWith( A2 );
        L21Trans_STAR_MR.AlignWith( A2 );
        Z1.AlignWith( A01 );
        // Z1 receives the local contributions to the update of A1, so it
        // has A1's dimensions.
        Z1.ResizeTo( A.Height(), A01.Width() );
        //--------------------------------------------------------------------//
        // Copy out L1
        // (must happen before A11's lower part and A21 are zeroed below)
        L11_STAR_STAR = A11;
        L21_VR_STAR = A21;
        L21Trans_STAR_MR.TransposeFrom( L21_VR_STAR );

        // Zero the strictly lower triangular portion of A1
        MakeTrapezoidal( LEFT, UPPER, 0, A11 );
        Zero( A21 );

        // Perform the lazy update of A1
        // Z1 := -A2 (L21^T)^T = -A2 L21 from local data (beta = 0), then the
        // contributions are combined into A1 by SumScatterUpdate.
        internal::LocalGemm
        ( NORMAL, TRANSPOSE,
          F(-1), A2, L21Trans_STAR_MR, F(0), Z1 );
        A1.SumScatterUpdate( F(1), Z1 );

        // Solve against this diagonal block of L11
        // (from the right, treating L11 as unit-diagonal lower triangular)
        A1_VC_STAR = A1;
        internal::LocalTrsm
        ( RIGHT, LOWER, NORMAL, UNIT, F(1), L11_STAR_STAR, A1_VC_STAR );
        A1 = A1_VC_STAR;
        //--------------------------------------------------------------------//
        // Release the alignments set above before the next iteration calls
        // AlignWith again.
        Z1.FreeAlignments();
        L21Trans_STAR_MR.FreeAlignments();
        L21_VR_STAR.FreeAlignments();

        // Move the partition one block up and to the left.
        SlidePartitionUpDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
          /*************/ /*******************/
          /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );
    }

    // inv(A) := inv(A) P
    ApplyInverseColumnPivots( A, p );
#ifndef RELEASE
    PopCallStack();
#endif
}
Пример #3
0
// Blocked distributed two-sided triangular solve with a lower-triangular L
// (algorithmic variant 5).
//
// Each iteration applies, to the current block column of A:
//   A11 := inv(L11) A11 inv(L11)'
//   Y21 := L21 A11
//   A21 := A21 inv(L11)' - 1/2 Y21
//   A22 := A22 - (L21 A21' + A21 L21')
//   A21 := A21 - 1/2 Y21
//   A21 := inv(L22) A21
// Presumably this yields A := inv(L) A inv(L)' overall, and only the lower
// triangle of A is referenced/updated (Hemm and LocalTrr2k are called with
// LOWER) -- TODO confirm against the public TwoSidedTrsm documentation.
//
// NOTE(review): `F` is used as the scalar type but no template header is
// visible here; it was presumably lost when this excerpt was extracted.
//
// Parameters:
//   diag - forwarded to LocalTwoSidedTrsm, LocalTrsm and Trsm.
//   A    - matrix updated in place.
//   L    - triangular operand; accessed only through locked views.
// Unless RELEASE is defined, LogicError is raised when A or L is not square
// or when their heights differ.
inline void
TwoSidedTrsmLVar5
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& L )
{
#ifndef RELEASE
    // Debug-only bookkeeping and argument validation.
    CallStackEntry entry("internal::TwoSidedTrsmLVar5");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
    if( L.Height() != L.Width() )
        LogicError("Triangular matrices must be square");
    if( A.Height() != L.Height() )
        LogicError("A and L must be the same size");
#endif
    const Grid& g = A.Grid();

    // Matrix views
    // (2x2 and 3x3 partitions of A and of L, all created on A's grid)
    DistMatrix<F>
    ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
        A20(g), A21(g), A22(g);
    DistMatrix<F>
    LTL(g), LTR(g),  L00(g), L01(g), L02(g),
        LBL(g), LBR(g),  L10(g), L11(g), L12(g),
        L20(g), L21(g), L22(g);

    // Temporary distributions
    // (the suffix of each name mirrors the distribution pair in its type)
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> A21_MC_STAR(g);
    DistMatrix<F,VC,  STAR> A21_VC_STAR(g);
    DistMatrix<F,VR,  STAR> A21_VR_STAR(g);
    DistMatrix<F,STAR,MR  > A21Adj_STAR_MR(g);
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> L21_MC_STAR(g);
    DistMatrix<F,VC,  STAR> L21_VC_STAR(g);
    DistMatrix<F,VR,  STAR> L21_VR_STAR(g);
    DistMatrix<F,STAR,MR  > L21Adj_STAR_MR(g);
    DistMatrix<F,VC,  STAR> Y21_VC_STAR(g);
    DistMatrix<F> Y21(g);

    // Both sweeps start with an empty top-left block and move down the
    // diagonal in lockstep.
    PartitionDownDiagonal
    ( A, ATL, ATR,
      ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( L, LTL, LTR,
      LBL, LBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
          /*************/ /******************/
          /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, /**/ L01, L02,
          /*************/ /******************/
          /**/       L10, /**/ L11, L12,
          LBL, /**/ LBR,  L20, /**/ L21, L22 );

        // Align the panel temporaries with A22 (the target of the rank-2k
        // update) and Y21 with A21 (the target of the Axpy calls).
        A21_MC_STAR.AlignWith( A22 );
        A21_VC_STAR.AlignWith( A22 );
        A21_VR_STAR.AlignWith( A22 );
        A21Adj_STAR_MR.AlignWith( A22 );
        L21_MC_STAR.AlignWith( A22 );
        L21_VC_STAR.AlignWith( A22 );
        L21_VR_STAR.AlignWith( A22 );
        L21Adj_STAR_MR.AlignWith( A22 );
        Y21.AlignWith( A21 );
        Y21_VC_STAR.AlignWith( A22 );
        //--------------------------------------------------------------------//
        // A11 := inv(L11) A11 inv(L11)'
        L11_STAR_STAR = L11;
        A11_STAR_STAR = A11;
        LocalTwoSidedTrsm( LOWER, diag, A11_STAR_STAR, L11_STAR_STAR );
        A11 = A11_STAR_STAR;

        // Y21 := L21 A11
        // Computed on the local buffers via Hemm (A11 applied from the
        // right, LOWER storage); Y21_VC_STAR is zero-initialised to A21's
        // dimensions first.
        L21_VC_STAR = L21;
        Zeros( Y21_VC_STAR, A21.Height(), A21.Width() );
        Hemm
        ( RIGHT, LOWER,
          F(1), A11_STAR_STAR.Matrix(), L21_VC_STAR.Matrix(),
          F(0), Y21_VC_STAR.Matrix() );
        Y21 = Y21_VC_STAR;

        // A21 := A21 inv(L11)'
        A21_VC_STAR = A21;
        LocalTrsm
        ( RIGHT, LOWER, ADJOINT, diag, F(1), L11_STAR_STAR, A21_VC_STAR );
        A21 = A21_VC_STAR;

        // A21 := A21 - 1/2 Y21
        // (first half of the Y21 correction; the second half follows the
        // A22 update)
        Axpy( F(-1)/F(2), Y21, A21 );

        // A22 := A22 - (L21 A21' + A21 L21')
        // Redistribute both panels into [MC,STAR] and, via [VC,STAR] and
        // [VR,STAR], into adjoints in [STAR,MR]. L21_VC_STAR still holds the
        // copy of L21 made for the Y21 computation above.
        A21_MC_STAR = A21;
        L21_MC_STAR = L21;
        A21_VC_STAR = A21_MC_STAR;
        A21_VR_STAR = A21_VC_STAR;
        L21_VR_STAR = L21_VC_STAR;
        A21Adj_STAR_MR.AdjointFrom( A21_VR_STAR );
        L21Adj_STAR_MR.AdjointFrom( L21_VR_STAR );
        LocalTrr2k
        ( LOWER,
          F(-1), L21_MC_STAR, A21Adj_STAR_MR,
          A21_MC_STAR, L21Adj_STAR_MR,
          F(1), A22 );

        // A21 := A21 - 1/2 Y21
        // (second half of the Y21 correction)
        Axpy( F(-1)/F(2), Y21, A21 );

        // A21 := inv(L22) A21
        //
        // This is the bottleneck because A21 only has blocksize columns
        Trsm( LEFT, LOWER, NORMAL, diag, F(1), L22, A21 );
        //--------------------------------------------------------------------//

        // Advance both partitions past the block that was just processed.
        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
          /**/       A10, A11, /**/ A12,
          /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, L01, /**/ L02,
          /**/       L10, L11, /**/ L12,
          /**********************************/
          LBL, /**/ LBR,  L20, L21, /**/ L22 );
    }
}
Пример #4
0
// Blocked distributed two-sided triangular multiply with a lower-triangular
// L (algorithmic variant 2).
//
// Each iteration applies, to the current block row/column of A:
//   A10 := L11' A10 + L21' A20
//   Y21 := A22 L21
//   A21 := A21 L11 + 1/2 Y21
//   A11 := L11' A11 L11 + (A21' L21 + L21' A21)
//   A21 := A21 + 1/2 Y21
// Presumably this yields A := L' A L overall, with only the lower triangle
// of A referenced/updated (Her2k is called with LOWER) -- TODO confirm
// against the public TwoSidedTrmm documentation.
//
// NOTE(review): `F` is used as the scalar type but no template header is
// visible here; it was presumably lost when this excerpt was extracted.
//
// Parameters:
//   diag - forwarded to LocalTrmm and LocalTwoSidedTrmm.
//   A    - matrix updated in place.
//   L    - triangular operand; accessed only through locked views.
// Unless RELEASE is defined, std::logic_error is thrown when A or L is not
// square or when their heights differ.
inline void
TwoSidedTrmmLVar2
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& L )
{
#ifndef RELEASE
    // Debug-only call-stack tracking; paired with PopCallStack() at the end.
    PushCallStack("internal::TwoSidedTrmmLVar2");
    if( A.Height() != A.Width() )
        throw std::logic_error( "A must be square." );
    if( L.Height() != L.Width() )
        throw std::logic_error( "Triangular matrices must be square." );
    if( A.Height() != L.Height() )
        throw std::logic_error( "A and L must be the same size." );
#endif
    const Grid& g = A.Grid();

    // Matrix views
    // (2x2 and 3x3 partitions of A and of L, all created on A's grid)
    DistMatrix<F>
    ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
        A20(g), A21(g), A22(g);
    DistMatrix<F>
    LTL(g), LTR(g),  L00(g), L01(g), L02(g),
        LBL(g), LBR(g),  L10(g), L11(g), L12(g),
        L20(g), L21(g), L22(g);

    // Temporary distributions
    // (the suffix of each name mirrors the distribution pair in its type)
    DistMatrix<F,STAR,VR  > A10_STAR_VR(g);
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,VC,  STAR> A21_VC_STAR(g);
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> L21_MC_STAR(g);
    DistMatrix<F,STAR,MR  > L21Adj_STAR_MR(g);
    DistMatrix<F,VC,  STAR> L21_VC_STAR(g);
    DistMatrix<F,VR,  STAR> L21_VR_STAR(g);
    DistMatrix<F,STAR,MR  > X10_STAR_MR(g);
    DistMatrix<F,STAR,STAR> X11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> Z21_MC_STAR(g);
    DistMatrix<F,MR,  STAR> Z21_MR_STAR(g);
    DistMatrix<F,MR,  MC  > Z21_MR_MC(g);
    DistMatrix<F> Y21(g);

    // Both sweeps start with an empty top-left block and move down the
    // diagonal in lockstep.
    PartitionDownDiagonal
    ( A, ATL, ATR,
      ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( L, LTL, LTR,
      LBL, LBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
          /*************/ /******************/
          /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, /**/ L01, L02,
          /*************/ /******************/
          /**/       L10, /**/ L11, L12,
          LBL, /**/ LBR,  L20, /**/ L21, L22 );

        // Align each temporary with the block of A it is combined with in
        // the local kernels below.
        A21_VC_STAR.AlignWith( A22 );
        L21_MC_STAR.AlignWith( A20 );
        L21_VC_STAR.AlignWith( A22 );
        L21_VR_STAR.AlignWith( A22 );
        L21Adj_STAR_MR.AlignWith( A22 );
        X10_STAR_MR.AlignWith( A10 );
        Y21.AlignWith( A21 );
        Z21_MC_STAR.AlignWith( A22 );
        Z21_MR_STAR.AlignWith( A22 );
        //--------------------------------------------------------------------//
        // A10 := L11' A10
        L11_STAR_STAR = L11;
        A10_STAR_VR = A10;
        LocalTrmm
        ( LEFT, LOWER, ADJOINT, diag, F(1), L11_STAR_STAR, A10_STAR_VR );
        A10 = A10_STAR_VR;

        // A10 := A10 + L21' A20
        // The product is formed locally into X10_STAR_MR (beta = 0) and the
        // contributions are then combined into A10 by SumScatterUpdate.
        L21_MC_STAR = L21;
        X10_STAR_MR.ResizeTo( A10.Height(), A10.Width() );
        LocalGemm( ADJOINT, NORMAL, F(1), L21_MC_STAR, A20, F(0), X10_STAR_MR );
        A10.SumScatterUpdate( F(1), X10_STAR_MR );

        // Y21 := A22 L21
        // L21 is chained through [VC,STAR] and [VR,STAR] to obtain its
        // adjoint in [STAR,MR]. The two Z21 buffers are zeroed before being
        // passed to LocalSymmetricAccumulateLL (presumably because it
        // accumulates into them -- TODO confirm), and their contributions
        // are then combined into Y21.
        L21_VC_STAR = L21_MC_STAR;
        L21_VR_STAR = L21_VC_STAR;
        L21Adj_STAR_MR.AdjointFrom( L21_VR_STAR );
        Z21_MC_STAR.ResizeTo( A21.Height(), A21.Width() );
        Z21_MR_STAR.ResizeTo( A21.Height(), A21.Width() );
        Zero( Z21_MC_STAR );
        Zero( Z21_MR_STAR );
        LocalSymmetricAccumulateLL
        ( ADJOINT,
          F(1), A22, L21_MC_STAR, L21Adj_STAR_MR, Z21_MC_STAR, Z21_MR_STAR );
        Z21_MR_MC.SumScatterFrom( Z21_MR_STAR );
        Y21 = Z21_MR_MC;
        Y21.SumScatterUpdate( F(1), Z21_MC_STAR );

        // A21 := A21 L11
        A21_VC_STAR = A21;
        LocalTrmm
        ( RIGHT, LOWER, NORMAL, diag, F(1), L11_STAR_STAR, A21_VC_STAR );
        A21 = A21_VC_STAR;

        // A21 := A21 + 1/2 Y21
        // (first half of the Y21 correction; the second half follows the
        // A11 update)
        Axpy( F(1)/F(2), Y21, A21 );

        // A11 := L11' A11 L11
        A11_STAR_STAR = A11;
        LocalTwoSidedTrmm( LOWER, diag, A11_STAR_STAR, L11_STAR_STAR );
        A11 = A11_STAR_STAR;

        // A11 := A11 + (A21' L21 + L21' A21)
        // Her2k runs on the local buffers of the [VC,STAR] panels (beta = 0)
        // and the contributions are combined into A11 by SumScatterUpdate.
        // A21_VC_STAR is refreshed first because A21 changed in the Axpy
        // above.
        A21_VC_STAR = A21;
        X11_STAR_STAR.ResizeTo( A11.Height(), A11.Width() );
        Her2k
        ( LOWER, ADJOINT,
          F(1), A21_VC_STAR.LocalMatrix(), L21_VC_STAR.LocalMatrix(),
          F(0), X11_STAR_STAR.LocalMatrix() );
        A11.SumScatterUpdate( F(1), X11_STAR_STAR );

        // A21 := A21 + 1/2 Y21
        // (second half of the Y21 correction)
        Axpy( F(1)/F(2), Y21, A21 );
        //--------------------------------------------------------------------//
        // Release the alignments set above before the next iteration calls
        // AlignWith again.
        A21_VC_STAR.FreeAlignments();
        L21_MC_STAR.FreeAlignments();
        L21_VC_STAR.FreeAlignments();
        L21_VR_STAR.FreeAlignments();
        L21Adj_STAR_MR.FreeAlignments();
        X10_STAR_MR.FreeAlignments();
        Y21.FreeAlignments();
        Z21_MC_STAR.FreeAlignments();
        Z21_MR_STAR.FreeAlignments();

        // Advance both partitions past the block that was just processed.
        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
          /**/       A10, A11, /**/ A12,
          /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, L01, /**/ L02,
          /**/       L10, L11, /**/ L12,
          /*************/ /******************/
          LBL, /**/ LBR,  L20, L21, /**/ L22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}