// JNI binding for MXNDArrayGetShape.
//
// Reads the native NDArray handle stored in the `value` (J) field of
// `ndArrayHandle` (a ml.dmlc.mxnet.Base$RefLong) and queries its shape through
// the MXNet C API. On success it:
//   - appends every dimension, boxed as java.lang.Integer, to `dataBuf`
//     (a scala.collection.mutable.ArrayBuffer) via `$plus$eq`;
//   - stores the number of dimensions into the `value` (I) field of `ndimRef`
//     (a ml.dmlc.mxnet.Base$RefInt).
// Returns the status code from MXNDArrayGetShape. On a non-zero status,
// neither `dataBuf` nor `ndimRef` is touched: the C API outputs cannot be relied
// on after a failed call, so the caller is left to report the error.
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env, jobject obj,
                                                                jobject ndArrayHandle,
                                                                jobject ndimRef,
                                                                jobject dataBuf) {
  jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
  jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
  jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid);
  env->DeleteLocalRef(refLongClass);

  // Initialized so nothing indeterminate is ever read, even on failure.
  mx_uint ndim = 0;
  const mx_uint *pdata = nullptr;
  int ret = MXNDArrayGetShape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &ndim, &pdata);
  if (ret != 0) {
    return ret;
  }

  // fill dataBuf with one boxed Integer per dimension
  jclass integerClass = env->FindClass("java/lang/Integer");
  jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");

  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (mx_uint i = 0; i < ndim; ++i) {
    // The constructor signature is (I)V, so pass an explicit jint through varargs.
    jobject data = env->NewObject(integerClass, newInteger, static_cast<jint>(pdata[i]));
    // `$plus$eq` returns an ArrayBuffer, which is another local reference.
    jobject appended = env->CallObjectMethod(dataBuf, arrayAppend, data);
    // Release the per-iteration local references so the local reference table
    // does not grow with the number of dimensions.
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(data);
  }
  env->DeleteLocalRef(arrayClass);
  env->DeleteLocalRef(integerClass);

  // set ndimRef
  jclass refIntClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
  env->SetIntField(ndimRef, valueInt, static_cast<jint>(ndim));
  env->DeleteLocalRef(refIntClass);

  return ret;
}
// Example #2
// 0
/*!
 * \brief Query the shape of this NDArray through MXNDArrayGetShape.
 * \return one entry per dimension, in the order reported by the C API.
 *         If MXNDArrayGetShape reports an error (non-zero status), an empty
 *         vector is returned instead of reading the unreliable outputs.
 *         NOTE(review): an empty result is also what a 0-dim shape yields;
 *         callers needing to tell these apart should check MXGetLastError.
 */
inline std::vector<mx_uint> NDArray::GetShape() const {
  // Initialized so a failed call never leads to reading indeterminate values.
  mx_uint out_dim = 0;
  const mx_uint *out_pdata = nullptr;
  if (MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata) != 0) {
    return {};
  }
  // Copy the C API's shape buffer; it is not owned by this object.
  return std::vector<mx_uint>(out_pdata, out_pdata + out_dim);
}