Ejemplo n.º 1
0
// Updates one triangle of E (selected by `uplo`) from the products
// op(A) op(B) and op(C) op(D), scaled by `alpha` and `beta` respectively.
//
// Every operand is cut in two at split = E.Height()/2 and E is cut into four
// quadrants at the same offset. The off-diagonal quadrant on the `uplo` side
// is accumulated into directly; each diagonal quadrant is first formed in a
// scratch matrix and then merged into E through AxpyTrapezoid.
//
// NOTE(review): Gemm, AxpyTrapezoid and the partition routines are defined
// elsewhere; the descriptions here assume the usual BLAS-style meaning
// (C := alpha op(A) op(B) + beta C) - confirm against their declarations.
// NOTE(review): `T` is used without a visible template header in this excerpt.
void LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientA, Orientation orientB,
  Orientation orientC, Orientation orientD,
  T alpha, const ElementalMatrix<T>& A, const ElementalMatrix<T>& B,
  T beta,  const ElementalMatrix<T>& C, const ElementalMatrix<T>& D,
                 ElementalMatrix<T>& E )
{
    DEBUG_CSE
    // TODO: Stringent distribution and alignment checks

    using ADM = ElementalMatrix<T>;

    // Builds a fresh matrix object on the same grid and with the same root as
    // `M`; it is later turned into a view (or a scratch buffer).
    const auto makeLike = []( const ADM& M )
    { return unique_ptr<ADM>( M.Construct(M.Grid(),M.Root()) ); };

    // Two pieces per input operand.
    auto AHead = makeLike( A );
    auto ATail = makeLike( A );
    auto BHead = makeLike( B );
    auto BTail = makeLike( B );
    auto CHead = makeLike( C );
    auto CTail = makeLike( C );
    auto DHead = makeLike( D );
    auto DTail = makeLike( D );
    // Four quadrants of E.
    auto ETopLeft  = makeLike( E );
    auto ETopRight = makeLike( E );
    auto EBotLeft  = makeLike( E );
    auto EBotRight = makeLike( E );
    // Scratch space for the two diagonal quadrants.
    auto scratchTL = makeLike( E );
    auto scratchBR = makeLike( E );

    const Int split = E.Height() / 2;

    // A and C are cut downwards when used as-is and rightwards otherwise;
    // B and D are cut the opposite way.
    if( orientA == NORMAL )
        LockedPartitionDown( A, *AHead, *ATail, split );
    else
        LockedPartitionRight( A, *AHead, *ATail, split );

    if( orientB == NORMAL )
        LockedPartitionRight( B, *BHead, *BTail, split );
    else
        LockedPartitionDown( B, *BHead, *BTail, split );

    if( orientC == NORMAL )
        LockedPartitionDown( C, *CHead, *CTail, split );
    else
        LockedPartitionRight( C, *CHead, *CTail, split );

    if( orientD == NORMAL )
        LockedPartitionRight( D, *DHead, *DTail, split );
    else
        LockedPartitionDown( D, *DHead, *DTail, split );

    PartitionDownDiagonal
    ( E, *ETopLeft, *ETopRight, *EBotLeft, *EBotRight, split );

    // Off-diagonal quadrant: bottom-left for LOWER, top-right otherwise.
    // Both products are accumulated straight into E (T(1) keeps its contents).
    const bool lower = ( uplo == LOWER );
    ADM& AOff = lower ? *ATail : *AHead;
    ADM& BOff = lower ? *BHead : *BTail;
    ADM& COff = lower ? *CTail : *CHead;
    ADM& DOff = lower ? *DHead : *DTail;
    ADM& EOff = lower ? *EBotLeft : *ETopRight;
    Gemm
    ( orientA, orientB,
      alpha, AOff.LockedMatrix(), BOff.LockedMatrix(),
      T(1), EOff.Matrix() );
    Gemm
    ( orientC, orientD,
      beta, COff.LockedMatrix(), DOff.LockedMatrix(),
      T(1), EOff.Matrix() );

    // Diagonal quadrants: form both products in a scratch matrix aligned with
    // and sized like the quadrant (T(0) on the first Gemm overwrites the
    // scratch contents), then hand the result to AxpyTrapezoid.
    const auto updateDiagonal =
      [&]( ADM& APart, ADM& BPart, ADM& CPart, ADM& DPart,
           ADM& scratch, ADM& EDiag )
      {
          scratch.AlignWith( EDiag );
          scratch.Resize( EDiag.Height(), EDiag.Width() );
          Gemm
          ( orientA, orientB,
            alpha, APart.LockedMatrix(), BPart.LockedMatrix(),
            T(0), scratch.Matrix() );
          Gemm
          ( orientC, orientD,
            beta, CPart.LockedMatrix(), DPart.LockedMatrix(),
            T(1), scratch.Matrix() );
          AxpyTrapezoid( uplo, T(1), scratch, EDiag );
      };

    updateDiagonal( *AHead, *BHead, *CHead, *DHead, *scratchTL, *ETopLeft );
    updateDiagonal( *ATail, *BTail, *CTail, *DTail, *scratchBR, *EBotRight );
}
Ejemplo n.º 2
0
// Forwarding overload of AlignColsWith for an argument of type
// DistMatrix<S,STAR,MD,N>: it does nothing but delegate to AlignWith( A ).
// NOTE(review): presumably this ties the MD column distribution of this
// matrix to the MD row distribution of A - confirm against AlignWith.
// NOTE(review): the template parameter lists introducing T, Int, S and N are
// not visible in this excerpt.
inline void
DistMatrix<T,MD,STAR,Int>::AlignColsWith( const DistMatrix<S,STAR,MD,N>& A )
{ AlignWith( A ); }