Ejemplo n.º 1
0
/**
 * @brief Dot product of two single-precision Darrays using cuBLAS.
 *
 * - 1-D . 1-D : inner product via cublasSdot; result is a Darray of shape {1}.
 * - 2-D . 2-D : matrix product via cublasSgemm; result has shape
 *               {lhs.shape()[0], rhs.shape()[1]}.
 *
 * Both operands must belong to the same device and have the same rank, which
 * must be < 3 (enforced with CHECK_*). For any other rank combination that
 * passes those checks (e.g. rank 0), no branch runs and a default-constructed
 * Darray is returned.
 *
 * @param lhs left operand
 * @param rhs right operand
 * @return newly allocated result array on lhs's device
 */
Darray<float> cudot (const Darray<float>& lhs, const Darray<float>& rhs)
{
	// context check: both operands must live on the same device
	CHECK_EQ(lhs.getDeviceManager().getDeviceID(), rhs.getDeviceManager().getDeviceID());
	
	CHECK_EQ(lhs.ndim(), rhs.ndim());
	CHECK_LT(lhs.ndim(), 3);
	CHECK_LT(rhs.ndim(), 3);

	Darray<float> ret;

	if (lhs.ndim()==1 && rhs.ndim()==1)
	{
		// shape check
		CHECK_EQ(lhs.size(), rhs.size());
		ret = Darray<float>(lhs.getDeviceManager(), {1});
		
		// using cublas sdot.
		// x and y must be device vectors, so pass dev_data, as the sgemm
		// branch below does. The scalar result is written through ret.data.
		// NOTE(review): cuBLAS's default CUBLAS_POINTER_MODE_HOST requires the
		// result pointer to be host memory -- confirm ret.data is host-side
		// (or that DeviceManager switches to device pointer mode), and whether
		// ret's device copy needs refreshing afterwards.
		lhs.deviceSet();
		CUBLAS_SAFE_CALL(
		cublasSdot (DeviceManager::handle,
				    lhs.size(),
				    lhs.dev_data,
				    1,
				    rhs.dev_data,
				    1,
				    ret.data)
		);
	}
	// 2D matrix dot
	else if (lhs.ndim()==2 && rhs.ndim()==2)
	{
		// shape check: inner dimensions must agree
		CHECK_EQ(lhs.shape()[1], rhs.shape()[0]);
		ret = Darray<float>(lhs.getDeviceManager(), {lhs.shape()[0], rhs.shape()[1]});
		
		// using cublas sgemm: ret = 1 * lhs * rhs + 0 * ret.
		// cuBLAS is column-major; passing shape()[0] as each leading dimension
		// means the data is treated as column-major.
		// NOTE(review): confirm Darray stores matrices column-major.
		lhs.deviceSet();
		const float alpha = 1.;
		const float beta = 0.;
		CUBLAS_SAFE_CALL(
		cublasSgemm (DeviceManager::handle,
					CUBLAS_OP_N,
					CUBLAS_OP_N,
					lhs.shape()[0],
					rhs.shape()[1],
					lhs.shape()[1],
					&alpha,
					lhs.dev_data,
					lhs.shape()[0],
					rhs.dev_data,
					rhs.shape()[0],
					&beta,
					ret.dev_data,
					ret.shape()[0])
		);
	}
	return ret;
}
Ejemplo n.º 2
0
/**
 * @brief Euclidean (L2) norm of a single-precision Darray, via cublasSnrm2.
 *
 * The array is processed as a flat vector of ary.size() elements with unit
 * stride, read from ary.dev_data on the array's device.
 *
 * @param ary input array
 * @return the norm, written by cuBLAS into a host-side local and returned
 */
float cunorm2 (const Darray<float>& ary)
{
	ary.deviceSet();
	// Initialised so that a failed call that CUBLAS_SAFE_CALL does not abort
	// on cannot return an indeterminate value.
	float ret = 0.0f;
	CUBLAS_SAFE_CALL(
			cublasSnrm2 (DeviceManager::handle,
						 ary.size(),
						 ary.dev_data,
						 1,
						 &ret)
	);
	return ret;
}
Ejemplo n.º 3
0
/**
 * @brief Euclidean (L2) norm of a double-precision Darray, via cublasDnrm2.
 *
 * The array is processed as a flat vector of ary.size() elements with unit
 * stride, read from ary.dev_data on the array's device.
 *
 * @param ary input array
 * @return the norm, written by cuBLAS into a host-side local and returned
 */
double cunorm2 (const Darray<double>& ary)
{
	ary.deviceSet();
	// Initialised so that a failed call that CUBLAS_SAFE_CALL does not abort
	// on cannot return an indeterminate value.
	double ret = 0.0;
	CUBLAS_SAFE_CALL(
			cublasDnrm2 (DeviceManager::handle,
						 ary.size(),
						 ary.dev_data,
						 1,
						 &ret)
	);
	return ret;
}