Beispiel #1
0
void Contract
( const BlockMatrix<T>& A,
        BlockMatrix<T>& B )
{
    DEBUG_ONLY(CSE cse("Contract"))
    AssertSameGrids( A, B );
    const Dist U = B.ColDist();
    const Dist V = B.RowDist();
    // TODO: Shorten this implementation?
    if( A.ColDist() == U && A.RowDist() == V )
    {
        Copy( A, B );
    }
    else if( A.ColDist() == U && A.RowDist() == Partial(V) )
    {
        B.AlignAndResize
        ( A.BlockHeight(), A.BlockWidth(),
          A.ColAlign(), A.RowAlign(), A.ColCut(), A.RowCut(),
          A.Height(), A.Width(), false, false );
        Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() );
        AxpyContract( T(1), A, B );
    }
    else if( A.ColDist() == Partial(U) && A.RowDist() == V )
    {
        B.AlignAndResize
        ( A.BlockHeight(), A.BlockWidth(),
          A.ColAlign(), A.RowAlign(), A.ColCut(), A.RowCut(),
          A.Height(), A.Width(), false, false );
        Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() );
        AxpyContract( T(1), A, B );
    }
    else if( A.ColDist() == U && A.RowDist() == Collect(V) )
    {
        B.AlignColsAndResize
        ( A.BlockHeight(), A.ColAlign(), A.ColCut(), A.Height(), A.Width(),
          false, false );
        Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() );
        AxpyContract( T(1), A, B );
    }
    else if( A.ColDist() == Collect(U) && A.RowDist() == V )
    {
        B.AlignRowsAndResize
        ( A.BlockWidth(), A.RowAlign(), A.RowCut(), A.Height(), A.Width(),
          false, false );
        Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() );
        AxpyContract( T(1), A, B );
    }
    else if( A.ColDist() == Collect(U) && A.RowDist() == Collect(V) )
    {
        Zeros( B, A.Height(), A.Width() );
        AxpyContract( T(1), A, B );
    }
    else
        LogicError("Incompatible distributions");
}
Beispiel #2
0
// Blocked update of the lower triangle of C with alpha * A^{T/H} A, processed
// in blockSize x blockSize tiles of C.
//
//   alpha      scale applied to every product
//   APre       input matrix; read through a [VC,STAR] proxy
//   CPre       matrix that is updated; accessed through an [MC,MR] read-write
//              proxy. Only tiles on or below the block diagonal are touched.
//   conjugate  true selects ADJOINT products, false selects TRANSPOSE
//   blockSize  edge length of the tiles of C (the trailing tile may be smaller)
//
// For each diagonal tile, Syrk forms the local product into Z and
// AxpyContract adds it into C11. For each tile C21 below it, LocalGemm forms
// the local product into Z and AxpyContract adds it into C21.
void LT_Dot
( T alpha,
  const AbstractDistMatrix<T>& APre,
  AbstractDistMatrix<T>& CPre,
  const bool conjugate,
  Int blockSize=2000 )
{
    EL_DEBUG_CSE
    const Int n = CPre.Height();
    const Grid& g = APre.Grid();

    const Orientation orient = ( conjugate ? ADJOINT : TRANSPOSE );

    DistMatrixReadProxy<T,T,VC,STAR> AProx( APre );
    auto& A = AProx.GetLocked();

    DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre );
    auto& C = CProx.Get();

    // Scratch [STAR,STAR] buffer that receives each local product before it
    // is added into the corresponding tile of C.
    DistMatrix<T,STAR,STAR> Z( blockSize, blockSize, g );
    Zero( Z );
    for( Int kOuter=0; kOuter<n; kOuter+=blockSize )
    {
        const Int nbOuter = Min(blockSize,n-kOuter);
        const Range<Int> indOuter( kOuter, kOuter+nbOuter );

        auto A1 = A( ALL, indOuter );
        auto C11 = C( indOuter, indOuter );

        // Diagonal tile: C11 += alpha A1^{T/H} A1 (lower triangle via Syrk).
        Z.Resize( nbOuter, nbOuter );
        Syrk( LOWER, TRANSPOSE, alpha, A1.Matrix(), Z.Matrix(), conjugate );
        AxpyContract( T(1), Z, C11 );

        for( Int kInner=kOuter+nbOuter; kInner<n; kInner+=blockSize )
        {
            const Int nbInner = Min(blockSize,n-kInner);
            const Range<Int> indInner( kInner, kInner+nbInner );

            auto A2 = A( ALL, indInner );
            auto C21 = C( indInner, indOuter );

            // Sub-diagonal tile: C21 is nbInner x nbOuter and equals
            // alpha A2^{T/H} A1, so A2 must be the (transposed) left operand.
            // Passing A1 first would instead form the nbOuter x nbInner tile
            // belonging above the diagonal, which has the wrong values and,
            // for a ragged trailing tile, the wrong shape.
            LocalGemm( orient, NORMAL, alpha, A2, A1, Z );
            AxpyContract( T(1), Z, C21 );
        }
    }
}
Beispiel #3
0
// Tile-by-tile update of C with alpha * A * op(B), where op is selected by
// orientB. C is swept in blockSize x blockSize tiles (edge tiles may be
// smaller). For every tile, the matching row panels of A and B are
// multiplied locally into a [STAR,STAR] scratch matrix, which AxpyContract
// then adds into the tile with unit scale.
//
//   orientB    orientation forwarded to LocalGemm for the B panel
//   alpha      scale applied to every product
//   APre       left operand; read through a [STAR,VC] proxy
//   BPre       right operand; read through a [STAR,VC] proxy whose row
//              alignment is constrained to match A's
//   CPre       matrix that is updated; accessed through an [MC,MR] proxy
//   blockSize  tile edge length
void SUMMA_NTDot
( Orientation orientB,
  T alpha,
  const AbstractDistMatrix<T>& APre,
  const AbstractDistMatrix<T>& BPre,
        AbstractDistMatrix<T>& CPre,
  Int blockSize=2000 )
{
    EL_DEBUG_CSE
    const Int numRows = CPre.Height();
    const Int numCols = CPre.Width();
    const Grid& grid = APre.Grid();

    DistMatrixReadProxy<T,T,STAR,VC> AProx( APre );
    auto& A = AProx.GetLocked();

    // Force B's proxy to share A's row alignment.
    ElementalProxyCtrl ctrlB;
    ctrlB.rowConstrain = true;
    ctrlB.rowAlign = A.RowAlign();
    DistMatrixReadProxy<T,T,STAR,VC> BProx( BPre, ctrlB );
    auto& B = BProx.GetLocked();

    DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre );
    auto& C = CProx.Get();

    // Scratch matrix reused for every tile's local product.
    DistMatrix<T,STAR,STAR> tileProduct(grid);

    for( Int rowBeg=0; rowBeg<numRows; rowBeg+=blockSize )
    {
        const Int rowEnd = rowBeg + Min(blockSize,numRows-rowBeg);
        const Range<Int> rowInd( rowBeg, rowEnd );

        // Row panel of A that feeds every tile in this block row of C.
        auto APanel = A( rowInd, ALL );

        for( Int colBeg=0; colBeg<numCols; colBeg+=blockSize )
        {
            const Int colEnd = colBeg + Min(blockSize,numCols-colBeg);
            const Range<Int> colInd( colBeg, colEnd );

            auto BPanel = B( colInd, ALL );
            auto CTile  = C( rowInd, colInd );

            // Local product first, then add it into the tile of C.
            LocalGemm( NORMAL, orientB, alpha, APanel, BPanel, tileProduct );
            AxpyContract( T(1), tileProduct, CTile );
        }
    }
}
Beispiel #4
0
// Column-panel update of C with alpha * A * B^{T/H}; the adjoint is used
// exactly when orientB == ADJOINT. C is processed Blocksize() columns at a
// time: the matching rows of B are transposed into an [MR,STAR] matrix,
// multiplied locally by A into an [MC,STAR] matrix, and that result is
// added into the column panel of C by AxpyContract with unit scale.
//
//   orientB  ADJOINT requests a conjugated transpose of B's row panel;
//            any other value requests an unconjugated one
//   alpha    scale applied to every product
//   APre     left operand; read through an [MC,MR] proxy
//   BPre     right operand; read through an [MC,MR] proxy
//   CPre     matrix that is updated; accessed through an [MC,MR] proxy
void SUMMA_NTA
( Orientation orientB,
  T alpha,
  const AbstractDistMatrix<T>& APre,
  const AbstractDistMatrix<T>& BPre,
        AbstractDistMatrix<T>& CPre )
{
    EL_DEBUG_CSE
    const Int numCols = CPre.Width();
    const Int panelWidth = Blocksize();
    const Grid& grid = APre.Grid();
    const bool useAdjoint = ( orientB == ADJOINT );

    DistMatrixReadProxy<T,T,MC,MR> AProx( APre );
    auto& A = AProx.GetLocked();

    DistMatrixReadProxy<T,T,MC,MR> BProx( BPre );
    auto& B = BProx.GetLocked();

    DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre );
    auto& C = CProx.Get();

    // Work matrices, both aligned with A before the sweep starts.
    DistMatrix<T,MR,STAR> BPanelTrans(grid);
    DistMatrix<T,MC,STAR> panelProduct(grid);
    BPanelTrans.AlignWith( A );
    panelProduct.AlignWith( A );

    for( Int colBeg=0; colBeg<numCols; colBeg+=panelWidth )
    {
        const Int width = Min(panelWidth,numCols-colBeg);
        const auto ind = IR(colBeg,colBeg+width);

        // The same index range picks rows of B and columns of C.
        auto BPanel = B( ind, ALL );
        auto CPanel = C( ALL, ind );

        // panelProduct[MC,*] := alpha A[MC,MR] (BPanel^[T/H])[MR,*]
        Transpose( BPanel, BPanelTrans, useAdjoint );
        LocalGemm( NORMAL, NORMAL, alpha, A, BPanelTrans, panelProduct );

        // CPanel[MC,MR] += contribution of panelProduct[MC,*]
        AxpyContract( T(1), panelProduct, CPanel );
    }
}