/**
 * JNI bridge for LibInfo.mxNDArrayGetShape: queries the shape of an NDArray
 * through MXNDArrayGetShape and hands it back to the JVM caller.
 *
 * @param env           JNI environment.
 * @param obj           receiver of the native method (unused).
 * @param ndArrayHandle Base$RefLong whose `value` field is read as the native handle.
 * @param ndimRef       Base$RefInt whose `value` field receives the dimension count.
 * @param dataBuf       scala.collection.mutable.ArrayBuffer; one boxed
 *                      java.lang.Integer is appended per dimension.
 * @return the status code returned by MXNDArrayGetShape.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(
    JNIEnv *env, jobject obj, jobject ndArrayHandle, jobject ndimRef, jobject dataBuf) {
  // Unwrap the native handle stored in the RefLong wrapper.
  jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
  jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
  jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid);

  // Zero-initialised so nothing indeterminate is read below if the call
  // returns without writing its outputs.
  mx_uint ndim = 0;
  const mx_uint *pdata = nullptr;
  int ret = MXNDArrayGetShape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &ndim, &pdata);

  // Fill dataBuf with one boxed Integer per dimension.
  jclass integerClass = env->FindClass("java/lang/Integer");
  jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(
      arrayClass, "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  if (pdata != nullptr) {
    for (mx_uint i = 0; i < ndim; ++i) {
      // The constructor signature is (I)V, so pass an explicit jint through the varargs.
      jobject boxed = env->NewObject(integerClass, newInteger, static_cast<jint>(pdata[i]));
      jobject appended = env->CallObjectMethod(dataBuf, arrayAppend, boxed);
      // Release per-iteration local references instead of letting them pile
      // up in the local reference table until the native method returns.
      if (appended != nullptr) {
        env->DeleteLocalRef(appended);
      }
      if (boxed != nullptr) {
        env->DeleteLocalRef(boxed);
      }
    }
  }

  // Publish the dimension count through ndimRef.
  jclass refIntClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
  env->SetIntField(ndimRef, valueInt, static_cast<jint>(ndim));
  return ret;
}
/// Returns the shape of this array, one entry per dimension, as reported by
/// MXNDArrayGetShape for the handle held in blob_ptr_.
///
/// The status code of the C call is not inspected here; the outputs are
/// zero-initialised so that a call which leaves them untouched yields an
/// empty shape rather than a read of indeterminate values.
inline std::vector<mx_uint> NDArray::GetShape() const {
  const mx_uint *out_pdata = nullptr;
  mx_uint out_dim = 0;
  MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata);
  if (out_pdata == nullptr) {
    return std::vector<mx_uint>();
  }
  // Copy the dimensions out of the buffer in a single sized allocation.
  return std::vector<mx_uint>(out_pdata, out_pdata + out_dim);
}