/*
 * _unary_func: apply a scalar function to the value at stack index 1.
 *
 *   - number        -> pushes f(x) as a Lua number;
 *   - complex       -> pushes g(z) as a complex (Lua error if g == NULL);
 *   - array (cast 0)-> copies the array with its own dtype, applies the
 *                      EXPR_EVALF / EXPR_EVALG macros to the copy's buffer
 *                      (presumably elementwise f / g -- TODO confirm against
 *                      the macro definitions) and pushes the copy.
 *
 * NOTE(review): this definition is truncated in this view -- the `cast != 0`
 * path, the closing of the array branch, any fallback branch and the final
 * closing brace are not present, so the function does not compile as shown.
 * NOTE(review): the function is declared `void` but contains
 * `return luaL_error(...)`; returning a value from a void function is a
 * constraint violation in C. A Lua C function normally returns `int`
 * (number of results) -- confirm the intended signature.
 */
void _unary_func(lua_State *L, double(*f)(double), Complex(*g)(Complex), int cast) { if (lua_isnumber(L, 1)) { const double x = lua_tonumber(L, 1); lua_pushnumber(L, f(x)); } else if (lunum_hasmetatable(L, 1, "complex")) { if (g == NULL) { return luaL_error(L, "complex operation not supported"); } const Complex z = lunum_checkcomplex(L, 1); lunum_pushcomplex(L, g(z)); } else if (lunum_hasmetatable(L, 1, "array")) { Array *A = (Array*) lunum_checkarray1(L, 1); if (cast == 0) { Array B = array_new_copy(A, A->dtype); switch (B.dtype) { case ARRAY_TYPE_BOOL : EXPR_EVALF(Bool , B.size, B.data); break; case ARRAY_TYPE_CHAR : EXPR_EVALF(char , B.size, B.data); break; case ARRAY_TYPE_SHORT : EXPR_EVALF(short , B.size, B.data); break; case ARRAY_TYPE_INT : EXPR_EVALF(int , B.size, B.data); break; case ARRAY_TYPE_LONG : EXPR_EVALF(long , B.size, B.data); break; case ARRAY_TYPE_SIZE_T : EXPR_EVALF(size_t , B.size, B.data); break; case ARRAY_TYPE_FLOAT : EXPR_EVALF(float , B.size, B.data); break; case ARRAY_TYPE_DOUBLE : EXPR_EVALF(double , B.size, B.data); break; case ARRAY_TYPE_COMPLEX : EXPR_EVALG(Complex, B.size, B.data); break; } lunum_pusharray1(L, &B); }
/*
 * Convert the Lua value on top of the stack (index -1) into a freshly
 * malloc'ed scalar of dtype T and return it. The caller owns the buffer and
 * must free() it.
 *
 * Accepted values: numbers (anything lua_isnumber accepts, which includes
 * numeric strings), booleans (stored as 0/1) and lunum "complex" userdata.
 * The value is first held as a Complex and then stored into the requested
 * C type by ordinary assignment.
 *
 * Raises a Lua error (does not return) when the value has an unsupported
 * type, when allocation fails, or when T is not one of the dtypes handled
 * below -- the last case previously returned uninitialised memory.
 */
void *lunum_tovalue(lua_State *L, enum ArrayType T)
{
  Complex x = 0.0;

  if (lua_isnumber(L, -1)) {
    x = lua_tonumber(L, -1);
  }
  else if (lua_isboolean(L, -1)) {
    x = lua_toboolean(L, -1);
  }
  else if (lunum_hasmetatable(L, -1, "complex")) {
    x = *((Complex*) lua_touserdata(L, -1));
  }
  else {
    luaL_error(L, "unkown data type");
  }

  void *y = malloc(array_sizeof(T));
  if (y == NULL) {
    luaL_error(L, "not enough memory");
  }

  switch (T) {
  case ARRAY_TYPE_BOOL    : *((Bool   *)y) = x; break;
  case ARRAY_TYPE_CHAR    : *((char   *)y) = x; break;
  case ARRAY_TYPE_SHORT   : *((short  *)y) = x; break;
  case ARRAY_TYPE_INT     : *((int    *)y) = x; break;
  case ARRAY_TYPE_LONG    : *((long   *)y) = x; break;
  case ARRAY_TYPE_FLOAT   : *((float  *)y) = x; break;
  case ARRAY_TYPE_DOUBLE  : *((double *)y) = x; break;
  case ARRAY_TYPE_COMPLEX : *((Complex*)y) = x; break;
  default:
    /* Unhandled dtype: never hand back an uninitialised buffer. luaL_error
       does not return, so release the allocation first. */
    free(y);
    luaL_error(L, "invalid array type");
  }
  return y;
}
/*
 * lunum.zeros(n [, dtype])  or  lunum.zeros(shape [, dtype])
 *
 * Pushes a zero-filled array of dtype 'dtype' (default ARRAY_TYPE_DOUBLE).
 *   - number n : an array of n elements; n must be positive.
 *   - table / array 'shape' : its entries (read as size_t) give the extent of
 *     each dimension. Every extent must be non-zero (matching the n > 0 rule
 *     above), and the total element count must not overflow size_t.
 * Any other first argument raises a Lua error. Returns 1 (the new array).
 */
static int luaC_lunum_zeros(lua_State *L)
{
  if (lua_isnumber(L, 1)) {
    const lua_Integer N = luaL_checkinteger(L, 1);
    if (N <= 0) {
      /* %I is lua_pushfstring's lua_Integer conversion; %d expects an int,
         so passing a lua_Integer to %d is a varargs type mismatch. */
      return luaL_error(L, "Invalid size %I", N);
    }
    const ArrayType T = (ArrayType) luaL_optinteger(L, 2, ARRAY_TYPE_DOUBLE);
    Array A = array_new_zeros(N, T);
    lunum_pusharray1(L, &A);
    return 1;
  }
  else if (lua_istable(L, 1) || lunum_hasmetatable(L, 1, "array")) {
    size_t Nd_t;
    size_t *N = (size_t*) lunum_checkarray2(L, 1, ARRAY_TYPE_SIZE_T, &Nd_t);
    const int Nd = (int)Nd_t;
    const ArrayType T = (ArrayType) luaL_optinteger(L, 2, ARRAY_TYPE_DOUBLE);

    size_t ntot = 1;
    for (int d = 0; d < Nd; ++d) {
      if (N[d] == 0) {
        return luaL_error(L, "Invalid size in dimension %d", d + 1);
      }
      /* (size_t)-1 is SIZE_MAX; refuse shapes whose product would wrap. */
      if (ntot > ((size_t)-1) / N[d]) {
        return luaL_error(L, "array shape is too large");
      }
      ntot *= N[d];
    }

    Array A = array_new_zeros(ntot, T);
    array_resize_t(&A, N, Nd);
    lunum_pusharray1(L, &A);
    return 1;
  }
  else {
    return luaL_error(L, "argument must be either number, table, or array");
  }
}
/*
 * Shared entry point for binary operators involving a complex operand.
 *
 * If either operand carries the "array" metatable, the operation is
 * delegated to _array_binary_op. Otherwise, each operand that is not already
 * a complex must be a number (or numeric string). It is wrapped in a complex
 * and written back into its stack slot, and _complex_binary_op2 performs the
 * operation.
 *
 * A non-numeric operand (table, nil, boolean, non-numeric string) raises a
 * Lua "number expected" argument error instead of being silently treated
 * as 0.
 */
static int _complex_binary_op1(lua_State *L, ArrayBinaryOperation op)
{
  if (lunum_hasmetatable(L, 1, "array") || lunum_hasmetatable(L, 2, "array")) {
    return _array_binary_op(L, op);
  }
  if (!lunum_hasmetatable(L, 1, "complex")) {
    lunum_pushcomplex(L, luaL_checknumber(L, 1));
    lua_replace(L, 1);
  }
  if (!lunum_hasmetatable(L, 2, "complex")) {
    lunum_pushcomplex(L, luaL_checknumber(L, 2));
    lua_replace(L, 2);
  }
  return _complex_binary_op2(L, op);
}
/*
 * lunum.zeros(n [, dtype])  or  lunum.zeros(shape [, dtype])
 *
 * Pushes a zero-filled array of dtype 'dtype' (default ARRAY_TYPE_DOUBLE).
 *   - number n : an array of n elements; n must be positive.
 *   - table / array 'shape' : its entries (read as int) give the extent of
 *     each dimension. Every extent must be positive, and the total element
 *     count must fit in an int (the previous unchecked product could
 *     signed-overflow, which is undefined behaviour).
 * Any other first argument raises a Lua error. Returns 1 (the new array).
 */
int luaC_lunum_zeros(lua_State *L)
{
  if (lua_isnumber(L, 1)) {
    const int N = luaL_checkinteger(L, 1);
    if (N <= 0) {
      return luaL_error(L, "Invalid size %d", N);
    }
    const enum ArrayType T = (enum ArrayType) luaL_optinteger(L, 2, ARRAY_TYPE_DOUBLE);
    struct Array A = array_new_zeros(N, T);
    lunum_pusharray1(L, &A);
    return 1;
  }
  else if (lua_istable(L, 1) || lunum_hasmetatable(L, 1, "array")) {
    int Nd;
    int *N = (int*) lunum_checkarray2(L, 1, ARRAY_TYPE_INT, &Nd);
    const enum ArrayType T = (enum ArrayType) luaL_optinteger(L, 2, ARRAY_TYPE_DOUBLE);

    int ntot = 1;
    for (int d = 0; d < Nd; ++d) {
      if (N[d] <= 0) {
        return luaL_error(L, "Invalid size %d in dimension %d", N[d], d + 1);
      }
      /* check before multiplying: signed int overflow is undefined */
      if (ntot > INT_MAX / N[d]) {
        return luaL_error(L, "array shape is too large");
      }
      ntot *= N[d];
    }

    struct Array A = array_new_zeros(ntot, T);
    array_resize(&A, N, Nd);
    lunum_pusharray1(L, &A);
    return 1;
  }
  else {
    return luaL_error(L, "argument must be either number, table, or array");
  }
}
/*
 * Replace the non-array operand at stack slot 'dst' with an array built by
 * lunum_upcast using the dtype and element count of the array at slot 'ref',
 * then reshape the new array to the reference array's shape.
 */
static void _adopt_array_layout(lua_State *L, int dst, int ref)
{
  struct Array *model = lunum_checkarray1(L, ref);
  lunum_upcast(L, dst, model->dtype, model->size);
  lua_replace(L, dst);
  struct Array *made = lunum_checkarray1(L, dst);
  array_resize(made, model->shape, model->ndims);
}

/*
 * Binary-operator driver (older array API).
 *
 * Whichever operand lacks the "array" metatable is converted to match the
 * other one (slot 1 first, then slot 2); the operation itself is performed
 * by _array_binary_op2. If neither operand is an array, the first
 * conversion raises a Lua error from lunum_checkarray1.
 */
int _array_binary_op1(lua_State *L, enum ArrayOperation op)
{
  if (!lunum_hasmetatable(L, 1, "array")) {
    _adopt_array_layout(L, 1, 2);
  }
  if (!lunum_hasmetatable(L, 2, "array")) {
    _adopt_array_layout(L, 2, 1);
  }
  return _array_binary_op2(L, op);
}
/*
 * Turn the plain table at stack slot 'dst' into an array with the dtype and
 * element count of the array at slot 'ref' (via lunum_upcast), store it back
 * in 'dst', and reshape it to the reference array's shape.
 */
static void _coerce_table_like(lua_State *L, int dst, int ref)
{
  Array *model = lunum_checkarray1(L, ref);
  lunum_upcast(L, dst, model->dtype, model->size);
  lua_replace(L, dst);
  Array *made = lunum_checkarray1(L, dst);
  array_resize_t(made, model->shape, model->ndims);
}

/*
 * Binary-operator dispatcher for arrays.
 *
 * When both operands are tables or arrays, any plain table is first
 * converted to an array matching the other operand, and the work is done by
 * _array_array_binary_op. Otherwise at least one operand is neither a table
 * nor an array, and _array_number_binary_op handles it; its flag says
 * whether slot 1 holds the array.
 */
static int _array_binary_op(lua_State *L, ArrayBinaryOperation op)
{
  const int lhs_is_array = lunum_hasmetatable(L, 1, "array");
  const int rhs_is_array = lunum_hasmetatable(L, 2, "array");
  const int lhs_tabular = lhs_is_array || lua_istable(L, 1);
  const int rhs_tabular = rhs_is_array || lua_istable(L, 2);

  if (!(lhs_tabular && rhs_tabular)) {
    return _array_number_binary_op(L, op, lhs_is_array);
  }

  if (!lhs_is_array) {
    _coerce_table_like(L, 1, 2);
  }
  if (!rhs_is_array) {
    _coerce_table_like(L, 2, 1);
  }
  return _array_array_binary_op(L, op);
}
/*
 * Fetch the C-side struct Array behind the Lua value at index 'pos'.
 *
 * The value must carry the "array" metatable; otherwise a Lua error of the
 * form "bad argument #pos (array expected, got <type>)" is raised. The
 * struct is read from the value's raw "__cstruct" field (no metamethods are
 * consulted). The result is NULL if that field is not a userdata. The stack
 * is left as it was found.
 */
struct Array *lunum_checkarray1(lua_State *L, int pos)
{
  lua_pushvalue(L, pos);   /* work on a copy so 'pos' may be relative */

  if (!lunum_hasmetatable(L, -1, "array")) {
    const char *actual = lua_typename(L, lua_type(L, -1));
    luaL_error(L, "bad argument #%d (array expected, got %s)", pos, actual);
  }

  lua_pushstring(L, "__cstruct");
  lua_rawget(L, -2);
  struct Array *cstruct = (struct Array*) lua_touserdata(L, -1);

  lua_pop(L, 2);           /* drop the field value and the copied array */
  return cstruct;
}
/*
 * __index metamethod for arrays. The key at slot 2 is handled as follows:
 *   1. an array key must have dtype bool and acts as a mask: the array built
 *      by array_new_from_mask is returned;
 *   2. otherwise _get_index is asked to resolve the key (e.g. a number or a
 *      table of numbers); on success the element at that offset is returned;
 *   3. failing that, the key is looked up in the array's metatable, and the
 *      result is returned unless it is nil (then nothing is returned).
 */
static int luaC_array__index(lua_State *L)
{
  Array *A = lunum_checkarray1(L, 1);

  if (lunum_hasmetatable(L, 2, "array")) {
    Array *mask = lunum_checkarray1(L, 2);
    if (mask->dtype != ARRAY_TYPE_BOOL) {
      return luaL_error(L, "index array must be of type bool");
    }
    Array picked = array_new_from_mask(A, mask);
    lunum_pusharray1(L, &picked);
    return 1;
  }

  int found = 0;
  const size_t offset = _get_index(L, A, &found);
  if (found) {
    const size_t item_bytes = array_sizeof(A->dtype);
    _push_value(L, A->dtype, (char*)A->data + offset * item_bytes);
    return 1;
  }

  /* Not an element index: fall back to the metatable (e.g. methods). */
  lua_getmetatable(L, 1);
  lua_pushvalue(L, 2);
  return (lua_gettable(L, -2) != LUA_TNIL) ? 1 : 0;
}
/*
 * __index metamethod for arrays (older variant). The key at slot 2 selects
 * the behaviour:
 *   - array key     : must have dtype bool; the masked copy produced by
 *                     array_new_from_mask is returned;
 *   - table/string  : delegated to the Lua function
 *                     lunum.__build_slice(array, key), whose single result
 *                     is returned;
 *   - anything else : _get_index resolves it to an element offset, and that
 *                     element is returned.
 */
int luaC_array__index(lua_State *L)
{
  struct Array *A = lunum_checkarray1(L, 1);
  const int key_type = lua_type(L, 2);

  if (lunum_hasmetatable(L, 2, "array")) {
    struct Array *mask = lunum_checkarray1(L, 2);
    if (mask->dtype != ARRAY_TYPE_BOOL) {
      return luaL_error(L, "index array must be of type bool");
    }
    struct Array picked = array_new_from_mask(A, mask);
    lunum_pusharray1(L, &picked);
    return 1;
  }

  if (key_type == LUA_TTABLE || key_type == LUA_TSTRING) {
    lua_getglobal(L, "lunum");
    lua_getfield(L, -1, "__build_slice");
    lua_remove(L, -2);          /* keep only the function */
    lua_pushvalue(L, 1);
    lua_pushvalue(L, 2);
    lua_call(L, 2, 1);
    return 1;
  }

  const int offset = _get_index(L, A);
  char *base = (char*)A->data;
  _push_value(L, A->dtype, base + array_sizeof(A->dtype) * offset);
  return 1;
}
/*
 * Make sure the Lua value at index 'pos' is available as a lunum array.
 *
 *   - array of dtype T        : nothing is pushed; returns 0.
 *   - array of another dtype  : pushes a copy converted to T; returns 1.
 *   - Lua table               : pushes an array of dtype T whose length is the
 *                               table's raw length, each element converted by
 *                               lunum_tovalue; returns 1.
 *   - boolean                 : pushes a length-N array of dtype BOOL filled
 *                               with the value (T is not applied); returns 1.
 *   - number                  : pushes a length-N array of dtype T filled with
 *                               the value (built as double, then copied to T);
 *                               returns 1.
 *   - lunum complex           : pushes a length-N array of dtype COMPLEX filled
 *                               with the value (T is not applied); returns 1.
 * Any other value, or an invalid T, raises a Lua error.
 *
 * 'pos' may be a relative (negative) index: it is converted to an absolute
 * one up front, because the table branch pushes keys/values while reading.
 */
int lunum_upcast(lua_State *L, int pos, enum ArrayType T, int N)
{
  if (array_typename(T) == NULL) {
    luaL_error(L, "invalid array type");
  }

  pos = lua_absindex(L, pos);

  /* Deal with lunum.array */
  if (lunum_hasmetatable(L, pos, "array")) {
    struct Array *A = lunum_checkarray1(L, pos);
    if (A->dtype == T) {
      return 0;
    }
    else {
      struct Array A_ = array_new_copy(A, T);
      lunum_pusharray1(L, &A_);
      return 1;
    }
  }

  /* Deal with Lua table */
  else if (lua_istable(L, pos)) {
    /* NOTE(review): if lunum_tovalue raises part-way through, A has not been
       pushed yet, so its buffer (presumably heap-allocated by
       array_new_zeros) is not released. */
    struct Array A = array_new_zeros(lua_rawlen(L, pos), T);
    for (int i = 0; i < A.size; ++i) {
      lua_pushnumber(L, i + 1);
      lua_gettable(L, pos);   /* safe: pos is absolute */
      void *val = lunum_tovalue(L, T);
      memcpy((char*)A.data + array_sizeof(T)*i, val, array_sizeof(T));
      free(val);
      lua_pop(L, 1);
    }
    lunum_pusharray1(L, &A);
    return 1;
  }

  /* Deal with Lua bool */
  else if (lua_isboolean(L, pos)) {
    const Bool x = lua_toboolean(L, pos);
    struct Array A = array_new_zeros(N, ARRAY_TYPE_BOOL);
    array_assign_from_scalar(&A, &x);
    lunum_pusharray1(L, &A);
    return 1;
  }

  /* Deal with Lua numbers */
  else if (lua_isnumber(L, pos)) {
    const double x = lua_tonumber(L, pos);
    struct Array A = array_new_zeros(N, ARRAY_TYPE_DOUBLE);
    array_assign_from_scalar(&A, &x);
    struct Array B = array_new_copy(&A, T);
    array_del(&A);
    lunum_pusharray1(L, &B);
    return 1;
  }

  /* Deal with lunum.complex */
  else if (lunum_hasmetatable(L, pos, "complex")) {
    const Complex z = *((Complex*) lua_touserdata(L, pos));
    struct Array A = array_new_zeros(N, ARRAY_TYPE_COMPLEX);
    array_assign_from_scalar(&A, &z);
    lunum_pusharray1(L, &A);
    return 1;
  }

  /* Throw an error */
  else {
    luaL_error(L, "cannot cast to array from object of dtype %s\n",
               lua_typename(L, lua_type(L, pos)));
    return 0;
  }
}
/*
 * Array (op) scalar, in either order.
 *
 * 'array_first' tells which slot holds the array (1 if true, else 2); the
 * other slot holds the scalar. The scalar decides the result dtype T:
 *   - boolean : T stays the array's dtype;
 *   - integer : T is raised to at least ARRAY_TYPE_LONG;
 *   - float   : T is raised to at least ARRAY_TYPE_DOUBLE;
 *   - complex : T becomes ARRAY_TYPE_COMPLEX.
 * The raising relies on the numeric ordering of the ArrayType constants.
 * The scalar is stored, cast to T, in a union that is handed to
 * array_number_binary_op together with a fresh zeroed result array (same
 * shape as the input), which is pushed and returned. Any other scalar type
 * raises a Lua error.
 */
static int _array_number_binary_op(lua_State *L, ArrayBinaryOperation op, Bool array_first)
{
  /* Store 'src', cast to the C type of dtype 'type', in union 'dst'. */
#define STORE_SCALAR_AS(dst, type, src)                      \
  do {                                                       \
    switch (type) {                                          \
    case ARRAY_TYPE_BOOL    : (dst).b = (Bool)   (src); break; \
    case ARRAY_TYPE_CHAR    : (dst).c = (char)   (src); break; \
    case ARRAY_TYPE_SHORT   : (dst).s = (short)  (src); break; \
    case ARRAY_TYPE_INT     : (dst).i = (int)    (src); break; \
    case ARRAY_TYPE_LONG    : (dst).l = (long)   (src); break; \
    case ARRAY_TYPE_SIZE_T  : (dst).t = (size_t) (src); break; \
    case ARRAY_TYPE_FLOAT   : (dst).f = (float)  (src); break; \
    case ARRAY_TYPE_DOUBLE  : (dst).d = (double) (src); break; \
    case ARRAY_TYPE_COMPLEX : (dst).z = (Complex)(src); break; \
    }                                                        \
  } while (0)

  const int arr_pos = array_first ? 1 : 2;
  const int num_pos = array_first ? 2 : 1;
  const Array *A = lunum_checkarray1(L, arr_pos);
  ArrayType T = A->dtype;

  union {
    Bool b; char c; short s; int i; long l; size_t t;
    float f; double d; Complex z;
  } num;

  if (lua_isboolean(L, num_pos)) {
    /* a boolean never widens the result: convert it to the array's dtype */
    const int flag = lua_toboolean(L, num_pos);
    STORE_SCALAR_AS(num, T, flag);
  }
  else {
    int isnum = 0;
    const lua_Integer ival = lua_tointegerx(L, num_pos, &isnum);
    if (isnum) {
      if (T < ARRAY_TYPE_LONG) {
        T = ARRAY_TYPE_LONG;
      }
      STORE_SCALAR_AS(num, T, ival);
    }
    else {
      const lua_Number fval = lua_tonumberx(L, num_pos, &isnum);
      if (isnum) {
        if (T < ARRAY_TYPE_DOUBLE) {
          T = ARRAY_TYPE_DOUBLE;
        }
        STORE_SCALAR_AS(num, T, fval);
      }
      else if (lunum_hasmetatable(L, num_pos, "complex")) {
        num.z = *((Complex*) lua_touserdata(L, num_pos));
        T = ARRAY_TYPE_COMPLEX;
      }
      else {
        return luaL_error(L, "Invalid argument in Array binary op");
      }
    }
  }

  Array C = array_new_zeros(A->size, T);
  array_resize_t(&C, A->shape, A->ndims);
  lunum_pusharray1(L, &C);
  array_number_binary_op(L, A, (void *)&num, &C, op, array_first);
  return 1;

#undef STORE_SCALAR_AS
}