Ejemplo n.º 1
0
/*!
 * \brief Construct an NDArray of the given shape on `context` via MXNDArrayCreate.
 * \param shape dimensions of the new array
 * \param context device (type and id) on which the array is created
 * \param delay_alloc forwarded unchanged to MXNDArrayCreate
 */
inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) {
  const auto dev_type = context.GetDeviceType();
  const auto dev_id = context.GetDeviceId();
  NDArrayHandle created = nullptr;
  const int status = MXNDArrayCreate(shape.data(), shape.ndim(), dev_type, dev_id,
                                     delay_alloc, &created);
  // Any non-zero status from the C API fails the check.
  CHECK_EQ(status, 0);
  // The raw handle is stored in a shared NDBlob owned by this NDArray.
  blob_ptr_ = std::make_shared<NDBlob>(created);
}
Ejemplo n.º 2
0
/*!
 * \brief Construct an NDArray of the given shape on `context` and fill it
 *        with the first shape.Size() values of `data`.
 * \param data host-side source values; must hold at least shape.Size() elements
 * \param shape dimensions of the new array
 * \param context device (type and id) on which the array is created
 */
inline NDArray::NDArray(const std::vector<mx_float> &data, const Shape &shape,
                        const Context &context) {
  // The copy below reads shape.Size() elements from data.data(); a shorter
  // vector would be read out of bounds, so reject it before creating anything.
  CHECK_GE(data.size(), shape.Size());
  NDArrayHandle handle = nullptr;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
                           context.GetDeviceId(), false, &handle),
           0);
  // Wrap the handle before the copy so that, if the copy check below fails,
  // the handle is already held by blob_ptr_ (NDBlob presumably releases it on
  // destruction -- confirm against NDBlob) instead of being left dangling.
  blob_ptr_ = std::make_shared<NDBlob>(handle);
  // Previously this status was ignored; surface copy failures.
  CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data.data(), shape.Size()), 0);
}
/*!
 * \brief JNI binding for MXNDArrayCreate.
 * \param shape Java int[] holding the dimensions (only read here)
 * \param ndim number of dimensions to read from `shape`
 * \param devType target device type, forwarded to MXNDArrayCreate
 * \param devId target device id, forwarded to MXNDArrayCreate
 * \param delayAlloc forwarded to MXNDArrayCreate
 * \param ndArrayHandle Java object whose long field "value" receives the native handle
 * \return the MXNDArrayCreate status, or -1 if a JNI call failed before creation
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreate(JNIEnv *env, jobject obj,
                                                              jintArray shape,
                                                              jint ndim,
                                                              jint devType,
                                                              jint devId,
                                                              jint delayAlloc,
                                                              jobject ndArrayHandle) {
  // Resolve the output field first so a lookup failure cannot leak a freshly
  // created native array, and so SetLongField is never called with a null ID.
  jclass ndClass = env->GetObjectClass(ndArrayHandle);
  jfieldID ptr = env->GetFieldID(ndClass, "value", "J");
  if (ptr == nullptr) {
    return -1;  // GetFieldID failed; the JVM has a pending exception
  }
  jint *shapeArr = env->GetIntArrayElements(shape, nullptr);
  if (shapeArr == nullptr) {
    return -1;  // could not access the shape elements
  }
  NDArrayHandle out = nullptr;  // written to Java as 0 if creation fails
  int ret = MXNDArrayCreate(reinterpret_cast<mx_uint *>(shapeArr), static_cast<mx_uint>(ndim),
                            devType, devId, delayAlloc, &out);
  // The shape buffer is only read, so JNI_ABORT releases it without copying back.
  env->ReleaseIntArrayElements(shape, shapeArr, JNI_ABORT);
  // Store as jlong (always 64-bit): `long` is 32-bit on LLP64 platforms such
  // as Windows x64 and would truncate the pointer.
  env->SetLongField(ndArrayHandle, ptr, reinterpret_cast<jlong>(out));
  return ret;
}