コード例 #1
0
ファイル: Apply.cpp プロジェクト: restrin/Elemental
void Apply
( const DistMultiVec<Real>& x, 
  const DistMultiVec<Real>& y,
        DistMultiVec<Real>& z,
  const DistMultiVec<Int>& orders, 
  const DistMultiVec<Int>& firstInds,
  Int cutoff )
{
    DEBUG_ONLY(CSE cse("soc::Apply"))
    soc::Dots( x, y, z, orders, firstInds );
    auto xRoots = x;
    auto yRoots = y;
    cone::Broadcast( xRoots, orders, firstInds );
    cone::Broadcast( yRoots, orders, firstInds );

    const Int firstLocalRow = x.FirstLocalRow();
    const Int localHeight = x.LocalHeight();
    const Real* xBuf     = x.LockedMatrix().LockedBuffer();
    const Real* xRootBuf = xRoots.LockedMatrix().LockedBuffer();
    const Real* yBuf     = y.LockedMatrix().LockedBuffer();
    const Real* yRootBuf = yRoots.LockedMatrix().LockedBuffer();
          Real* zBuf     = z.Matrix().Buffer();
    const Int* firstIndBuf = firstInds.LockedMatrix().LockedBuffer();

    for( Int iLoc=0; iLoc<localHeight; ++iLoc )
    {
        const Int i = iLoc + firstLocalRow;
        const Int firstInd = firstIndBuf[iLoc];
        if( i != firstInd )
            zBuf[iLoc] += xRootBuf[iLoc]*yBuf[iLoc] + yRootBuf[iLoc]*xBuf[iLoc];
    }
}
コード例 #2
0
ファイル: Local.hpp プロジェクト: timwee/Elemental
// Local kernel that accumulates
//     alpha op(A) op(B) + beta op(C) op(D)
// into E, where the op's are selected by orientA..orientD.
//
// E is split at half its height into a 2x2 grid of blocks and A, B, C, D are
// split conformally (along rows or columns depending on their orientation).
//   * The off-diagonal block selected by 'uplo' (EBL for LOWER, ETR
//     otherwise) is updated in place by two Gemm calls.
//   * Each diagonal block is first formed in a temporary aligned with it and
//     then merged via AxpyTrapezoid( uplo, ... ) (presumably so that only the
//     'uplo' trapezoid of E is modified -- confirm against AxpyTrapezoid).
void LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientA, Orientation orientB,
  Orientation orientC, Orientation orientD,
  T alpha, const ElementalMatrix<T>& A, const ElementalMatrix<T>& B,
  T beta,  const ElementalMatrix<T>& C, const ElementalMatrix<T>& D,
                 ElementalMatrix<T>& E )
{
    DEBUG_CSE
    // TODO: Stringent distribution and alignment checks

    using DM = ElementalMatrix<T>;

    // Builds a fresh matrix through proto.Construct on proto's grid and root
    // and takes ownership of the returned pointer.
    auto makeLike = []( const DM& proto )
    { return unique_ptr<DM>( proto.Construct(proto.Grid(),proto.Root()) ); };

    // Holders for the two halves of each input operand.
    auto A0 = makeLike( A );
    auto A1 = makeLike( A );
    auto B0 = makeLike( B );
    auto B1 = makeLike( B );
    auto C0 = makeLike( C );
    auto C1 = makeLike( C );
    auto D0 = makeLike( D );
    auto D1 = makeLike( D );
    // Holders for the four quadrants of E and for the two diagonal
    // temporaries.
    auto ETL = makeLike( E );
    auto ETR = makeLike( E );
    auto EBL = makeLike( E );
    auto EBR = makeLike( E );
    auto FTL = makeLike( E );
    auto FBR = makeLike( E );

    // Left factors (A, C) are split by rows when untransposed and by columns
    // otherwise; right factors (B, D) are split the opposite way.
    const Int split = E.Height() / 2;
    if( orientA == NORMAL )
        LockedPartitionDown( A, *A0, *A1, split );
    else
        LockedPartitionRight( A, *A0, *A1, split );
    if( orientB == NORMAL )
        LockedPartitionRight( B, *B0, *B1, split );
    else
        LockedPartitionDown( B, *B0, *B1, split );
    if( orientC == NORMAL )
        LockedPartitionDown( C, *C0, *C1, split );
    else
        LockedPartitionRight( C, *C0, *C1, split );
    if( orientD == NORMAL )
        LockedPartitionRight( D, *D0, *D1, split );
    else
        LockedPartitionDown( D, *D0, *D1, split );
    PartitionDownDiagonal( E, *ETL, *ETR, *EBL, *EBR, split );

    // target := alpha op(a) op(b) + scale*target, followed by
    // target += beta op(c) op(d).
    auto twoProducts =
      [&]( const DM& a, const DM& b, const DM& c, const DM& d,
           T scale, DM& target )
      {
          Gemm
          ( orientA, orientB,
            alpha, a.LockedMatrix(), b.LockedMatrix(),
            scale, target.Matrix() );
          Gemm
          ( orientC, orientD,
            beta, c.LockedMatrix(), d.LockedMatrix(),
            T(1), target.Matrix() );
      };

    // Off-diagonal block: accumulate straight into E.
    if( uplo == LOWER )
        twoProducts( *A1, *B0, *C1, *D0, T(1), *EBL );
    else
        twoProducts( *A0, *B1, *C0, *D1, T(1), *ETR );

    // Top-left diagonal block: overwrite the temporary, then merge into ETL.
    FTL->AlignWith( *ETL );
    FTL->Resize( ETL->Height(), ETL->Width() );
    twoProducts( *A0, *B0, *C0, *D0, T(0), *FTL );
    AxpyTrapezoid( uplo, T(1), *FTL, *ETL );

    // Bottom-right diagonal block: same scheme with the second halves.
    FBR->AlignWith( *EBR );
    FBR->Resize( EBR->Height(), EBR->Width() );
    twoProducts( *A1, *B1, *C1, *D1, T(0), *FBR );
    AxpyTrapezoid( uplo, T(1), *FBR, *EBR );
}