/*!
 * \brief Constructs an NDArray of the given shape on `context`.
 *
 * A fresh handle is obtained from MXNDArrayCreate (aborting via CHECK_EQ on
 * a non-zero status) and then wrapped in an NDBlob stored in blob_ptr_.
 *
 * \param shape       dimensions of the new array.
 * \param context     supplies the device type and device id for creation.
 * \param delay_alloc forwarded unchanged to MXNDArrayCreate (presumably
 *                    defers storage allocation until first use -- confirm
 *                    against the C API documentation).
 */
inline NDArray::NDArray(const Shape &shape, const Context &context,
                        bool delay_alloc) {
  NDArrayHandle handle = nullptr;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(),
                           context.GetDeviceType(), context.GetDeviceId(),
                           delay_alloc, &handle),
           0);
  blob_ptr_ = std::make_shared<NDBlob>(handle);
}
/*!
 * \brief Constructs an NDArray of the given shape on `context` and fills it
 *        from host memory.
 *
 * The array is created eagerly (delay_alloc is passed as false), then the
 * first shape.Size() elements of `data` are copied in through
 * MXNDArraySyncCopyFromCPU.
 *
 * \param data    source values; must contain at least shape.Size() elements
 *                (enforced with CHECK_GE). Extra trailing elements are ignored.
 * \param shape   dimensions of the new array.
 * \param context supplies the device type and device id for creation.
 */
inline NDArray::NDArray(const std::vector<mx_float> &data, const Shape &shape,
                        const Context &context) {
  // The copy below reads exactly shape.Size() values from data.data(); a
  // shorter vector would be an out-of-bounds read. Check before allocating.
  CHECK_GE(data.size(), shape.Size());

  NDArrayHandle handle = nullptr;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
                           context.GetDeviceId(), false, &handle),
           0);

  // Hand the handle to NDBlob before the copy so that, if the copy check
  // fails and unwinds, the handle is owned by blob_ptr_ rather than leaked
  // (assumes NDBlob releases its handle on destruction -- confirm).
  blob_ptr_ = std::make_shared<NDBlob>(handle);

  // Previously the status of this call was discarded, so a failed copy went
  // unnoticed and the array kept whatever contents it was created with.
  CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data.data(), shape.Size()), 0);
}
/*
 * JNI bridge for LibInfo.mxNDArrayCreate: creates an NDArray through
 * MXNDArrayCreate and stores the resulting handle into the `value` (long, "J")
 * field of `ndArrayHandle`.
 *
 * shape      - Java int[] of dimensions, reinterpreted as mx_uint values.
 * ndim       - number of dimensions passed to MXNDArrayCreate.
 * devType    - device type forwarded to MXNDArrayCreate.
 * devId      - device id forwarded to MXNDArrayCreate.
 * delayAlloc - forwarded to MXNDArrayCreate.
 * Returns the status code from MXNDArrayCreate, or -1 if the shape array
 * could not be pinned/copied (a Java exception is then already pending).
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreate(
    JNIEnv *env, jobject obj, jintArray shape, jint ndim, jint devType,
    jint devId, jint delayAlloc, jobject ndArrayHandle) {
  jint *shapeArr = env->GetIntArrayElements(shape, nullptr);
  if (shapeArr == nullptr) {
    // GetIntArrayElements failed (e.g. OutOfMemoryError is pending);
    // nothing was created, so report failure without touching the handle.
    return -1;
  }

  // Initialized so that a failed create stores 0 into the Java field instead
  // of an indeterminate value.
  NDArrayHandle out = nullptr;
  const int ret = MXNDArrayCreate(reinterpret_cast<mx_uint *>(shapeArr),
                                  static_cast<mx_uint>(ndim), devType, devId,
                                  delayAlloc, &out);
  // The shape buffer is only read, so skip the copy-back that mode 0 performs.
  env->ReleaseIntArrayElements(shape, shapeArr, JNI_ABORT);

  jclass ndClass = env->GetObjectClass(ndArrayHandle);
  jfieldID ptr = env->GetFieldID(ndClass, "value", "J");
  // Cast to jlong (always 64-bit), not `long`, which is 32-bit on LLP64
  // platforms such as Windows and would truncate the pointer.
  env->SetLongField(ndArrayHandle, ptr, reinterpret_cast<jlong>(out));
  return ret;
}