Example #1
static int etherflow_(Api_send_tensor_byte_lua)(lua_State *L) {
  // get params
  THByteTensor *tensor = luaT_toudata(L, 1, luaT_checktypename2id(L, "torch.ByteTensor"));
  int size = THByteTensor_nElement(tensor);
  unsigned char *data = THByteTensor_data(tensor);
  etherflow_send_ByteTensor_C(data, size);
  return 0;
}
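For context, a binding like this is usually exposed to Lua by registering the C function on the package's module table. A minimal sketch, assuming the module table is on top of the Lua stack and that the exported name is "sendByteTensor" (both the stack layout and the name are illustrative, not taken from the etherflow sources):

  /* Illustrative registration sketch (exported name is hypothetical). */
  lua_pushcfunction(L, etherflow_(Api_send_tensor_byte_lua));
  lua_setfield(L, -2, "sendByteTensor");   /* module.sendByteTensor = C function */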
Example #2
static int libjpeg_(Main_load)(lua_State *L)
{
  /* This struct contains the JPEG decompression parameters and pointers to
   * working space (which is allocated as needed by the JPEG library).
   */
  struct jpeg_decompress_struct cinfo;
  /* We use our private extension JPEG error handler.
   * Note that this struct must live as long as the main JPEG parameter
   * struct, to avoid dangling-pointer problems.
   */
  struct my_error_mgr jerr;
  /* More stuff */
  FILE * infile;		    /* source file (if loading from file) */
  unsigned char * inmem;    /* source memory (if loading from memory) */
  unsigned long inmem_size; /* source memory size (bytes) */
  JSAMPARRAY buffer;		/* Output row buffer */
  /* int row_stride;       physical row width in output buffer (unused here) */
  int i, k;

  THTensor *tensor = NULL;
  const int load_from_file = luaL_checkint(L, 1);
  
  if (load_from_file == 1) {
    const char *filename = luaL_checkstring(L, 2);
    
    /* In this example we want to open the input file before doing anything else,
     * so that the setjmp() error recovery below can assume the file is open.
     * VERY IMPORTANT: use "b" option to fopen() if you are on a machine that
     * requires it in order to read binary files.
     */

    if ((infile = fopen(filename, "rb")) == NULL)
    {
      luaL_error(L, "cannot open file <%s> for reading", filename);
    }
  } else {
    /* We're loading from a ByteTensor */
    THByteTensor *src = luaT_checkudata(L, 2, "torch.ByteTensor");
    inmem = THByteTensor_data(src);
    inmem_size = src->size[0];
    infile = NULL;
  }
  
  /* Step 1: allocate and initialize JPEG decompression object */

  /* We set up the normal JPEG error routines, then override error_exit. */
  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = libjpeg_(Main_error);
  /* Establish the setjmp return context for my_error_exit to use. */
  if (setjmp(jerr.setjmp_buffer)) {
    /* If we get here, the JPEG code has signaled an error.
     * We need to clean up the JPEG object, close the input file, and return.
     */
    jpeg_destroy_decompress(&cinfo);
    if (infile) {
      fclose(infile);
    }
    return 0;
  }
  /* Now we can initialize the JPEG decompression object. */
  jpeg_create_decompress(&cinfo);

  /* Step 2: specify data source (eg, a file) */
  if (load_from_file == 1) {
    jpeg_stdio_src(&cinfo, infile);
  } else {
    jpeg_mem_src(&cinfo, inmem, inmem_size);
  }

  /* Step 3: read file parameters with jpeg_read_header() */

  (void) jpeg_read_header(&cinfo, TRUE);
  /* We can ignore the return value from jpeg_read_header since
   *   (a) suspension is not possible with the stdio data source, and
   *   (b) we passed TRUE to reject a tables-only JPEG file as an error.
   * See libjpeg.doc for more info.
   */

  /* Step 4: set parameters for decompression */

  /* In this example, we don't need to change any of the defaults set by
   * jpeg_read_header(), so we do nothing here.
   */

  /* Step 5: Start decompressor */

  (void) jpeg_start_decompress(&cinfo);
  /* We can ignore the return value since suspension is not possible
   * with the stdio data source.
   */

  /* We may need to do some setup of our own at this point before reading
   * the data.  After jpeg_start_decompress() we have the correct scaled
   * output image dimensions available, as well as the output colormap
   * if we asked for color quantization.
   * In this example, we need to make an output work buffer of the right size.
   */ 

  /* Make a one-row-high sample array that will go away when done with image */

  tensor = THTensor_(newWithSize3d)(cinfo.output_components, cinfo.output_height, cinfo.output_width);
  buffer = (*cinfo.mem->alloc_sarray)
		((j_common_ptr) &cinfo, JPOOL_IMAGE, cinfo.output_width * cinfo.output_components, 1);

  /* Step 6: while (scan lines remain to be read) */
  /*           jpeg_read_scanlines(...); */

  /* Here we use the library's state variable cinfo.output_scanline as the
   * loop counter, so that we don't have to keep track ourselves.
   */
  while (cinfo.output_scanline < cinfo.output_height) {
    /* jpeg_read_scanlines expects an array of pointers to scanlines.
     * Here the array is only one element long, but you could ask for
     * more than one scanline at a time if that's more convenient.
     */
    (void) jpeg_read_scanlines(&cinfo, buffer, 1);
    
    for(k = 0; k < cinfo.output_components; k++)
    {
      for(i = 0; i < cinfo.output_width; i++)
        THTensor_(set3d)(tensor, k, cinfo.output_scanline-1, i, 
                         (real)buffer[0][cinfo.output_components*i+k]);
    }
  }

  /* Step 7: Finish decompression */

  (void) jpeg_finish_decompress(&cinfo);
  /* We can ignore the return value since suspension is not possible
   * with the stdio data source.
   */

  /* Step 8: Release JPEG decompression object */

  /* This is an important step since it will release a good deal of memory. */
  jpeg_destroy_decompress(&cinfo);

  /* After finish_decompress, we can close the input file.
   * Here we postpone it until after no more JPEG errors are possible,
   * so as to simplify the setjmp error logic above.  (Actually, I don't
   * think that jpeg_destroy can do an error exit, but why assume anything...)
   */
  if (infile) {
    fclose(infile);
  }

  /* At this point you may want to check to see whether any corrupt-data
   * warnings occurred (test whether jerr.pub.num_warnings is nonzero).
   */

  /* And we're done! */
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}
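The loader above relies on a project-specific struct my_error_mgr and the libjpeg_(Main_error) handler, neither of which is shown here. A minimal sketch of what they typically look like, following the standard libjpeg example.c pattern (the project's actual definitions may differ):

#include <setjmp.h>

struct my_error_mgr {
  struct jpeg_error_mgr pub;   /* "public" libjpeg fields; must come first */
  jmp_buf setjmp_buffer;       /* for returning control to the caller on error */
};

static void libjpeg_(Main_error)(j_common_ptr cinfo)
{
  /* cinfo->err really points to a my_error_mgr struct, so coerce the pointer */
  struct my_error_mgr *myerr = (struct my_error_mgr *) cinfo->err;
  /* Optionally display the message, then jump back to the setjmp point */
  (*cinfo->err->output_message)(cinfo);
  longjmp(myerr->setjmp_buffer, 1);
}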
Example #3
File: jpeg.c  Project: omry/image
/*
 * save function
 *
 */
int libjpeg_(Main_save)(lua_State *L) {
  unsigned char *inmem = NULL;  /* destination memory (if saving to memory) */
  unsigned long inmem_size = 0;  /* destination memory size (bytes) */

  /* get args */
  const char *filename = luaL_checkstring(L, 1);
  THTensor *tensor = luaT_checkudata(L, 2, torch_Tensor);  
  THTensor *tensorc = THTensor_(newContiguous)(tensor);
  real *tensor_data = THTensor_(data)(tensorc);

  const int save_to_file = luaL_checkint(L, 3);

  THByteTensor* tensor_dest = NULL;
  if (save_to_file == 0) {
    tensor_dest = luaT_checkudata(L, 5, "torch.ByteTensor");
  }

  int quality = luaL_checkint(L, 4);
  if (quality < 0 || quality > 100) {
    luaL_error(L, "quality should be between 0 and 100");
  }

  /* jpeg struct */
  struct jpeg_compress_struct cinfo;
  struct jpeg_error_mgr jerr;

  /* pointer to raw image */
  unsigned char *raw_image = NULL;

  /* dimensions of the image we want to write */
  int width=0, height=0, bytes_per_pixel=0;
  int color_space=0;
  if (tensorc->nDimension == 3) {
    bytes_per_pixel = tensorc->size[0];
    height = tensorc->size[1];
    width = tensorc->size[2];
    if (bytes_per_pixel == 3) {
      color_space = JCS_RGB;
    } else if (bytes_per_pixel == 1) {
      color_space = JCS_GRAYSCALE;
    } else {
      luaL_error(L, "tensor should have 1 or 3 channels (gray or RGB)");
    }
  } else if (tensorc->nDimension == 2) {
    bytes_per_pixel = 1;
    height = tensorc->size[0];
    width = tensorc->size[1];
    color_space = JCS_GRAYSCALE;
  } else {
    luaL_error(L, "supports only 1 or 3 dimension tensors");
  }

  /* alloc raw image data */
  raw_image = (unsigned char *)malloc((sizeof (unsigned char))*width*height*bytes_per_pixel);

  /* convert tensor to raw bytes */
  int x,y,k;
  for (k=0; k<bytes_per_pixel; k++) {
    for (y=0; y<height; y++) {
      for (x=0; x<width; x++) {
        raw_image[(y*width+x)*bytes_per_pixel+k] = *tensor_data++;
      }
    }
  }

  /* this is a pointer to one row of image data */
  JSAMPROW row_pointer[1];
  FILE *outfile = NULL;
  if (save_to_file == 1) {
    outfile = fopen( filename, "wb" );
    if ( !outfile ) {
      luaL_error(L, "Error opening output jpeg file %s\n!", filename );
    }
  }

  cinfo.err = jpeg_std_error( &jerr );
  jpeg_create_compress(&cinfo);

  /* specify data source (eg, a file) */
  if (save_to_file == 1) {
    jpeg_stdio_dest(&cinfo, outfile);
  } else {
    jpeg_mem_dest(&cinfo, &inmem, &inmem_size);
  }

  /* Setting the parameters of the output file here */
  cinfo.image_width = width;	
  cinfo.image_height = height;
  cinfo.input_components = bytes_per_pixel;
  cinfo.in_color_space = color_space;

  /* default compression parameters, we shouldn't be worried about these */
  jpeg_set_defaults( &cinfo );
  jpeg_set_quality(&cinfo, quality, (boolean)0);

  /* Now do the compression .. */
  jpeg_start_compress( &cinfo, TRUE );

  /* like reading a file, this time write one row at a time */
  while( cinfo.next_scanline < cinfo.image_height ) {
    row_pointer[0] = &raw_image[ cinfo.next_scanline * cinfo.image_width *  cinfo.input_components];
    jpeg_write_scanlines( &cinfo, row_pointer, 1 );
  }

  /* similar to read file, clean up after we're done compressing */
  jpeg_finish_compress( &cinfo );
  jpeg_destroy_compress( &cinfo );
  
  if (outfile != NULL) {
    fclose( outfile );
  }

  if (save_to_file == 0) {
    
    THByteTensor_resize1d(tensor_dest, inmem_size);  /* will fail if it's not a Byte Tensor */ 
    unsigned char* tensor_dest_data = THByteTensor_data(tensor_dest); 
    memcpy(tensor_dest_data, inmem, inmem_size);
    free(inmem);
  }

  /* some cleanup */
  free(raw_image);
  THTensor_(free)(tensorc);

  /* success code is 1! */
  return 1;
}
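Note that, unlike the loader, this save path installs only the stock jpeg_std_error manager, whose default error_exit terminates the process. A sketch of how the setjmp-based recovery from Example #2 could be wired in here as well (illustrative only; my_error_mgr and libjpeg_(Main_error) are assumed to be the same helpers referenced in that example):

  /* Illustrative only: reuse the loader's my_error_mgr/longjmp pattern. */
  struct my_error_mgr jerr_jmp;
  cinfo.err = jpeg_std_error(&jerr_jmp.pub);
  jerr_jmp.pub.error_exit = libjpeg_(Main_error);   /* longjmps on error */
  if (setjmp(jerr_jmp.setjmp_buffer)) {
    jpeg_destroy_compress(&cinfo);
    if (outfile) fclose(outfile);
    free(raw_image);
    luaL_error(L, "libjpeg error while saving <%s>", filename);
  }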
Example #4
static int torchzfp_(Main_compress)(lua_State *L) {
  THTensor* in = reinterpret_cast<THTensor*>(
      luaT_checkudata(L, 1, torch_Tensor));
  real* in_data = THTensor_(data)(in);
  const uint32_t dim = in->nDimension;
  if (dim == 0) {
    luaL_error(L, "Input tensor must not be empty");
  }
  THByteTensor* out = reinterpret_cast<THByteTensor*>(
      luaT_checkudata(L, 2, "torch.ByteTensor"));
  const double accuracy = static_cast<double>(lua_tonumber(L, 3));

  // Hacky code to figure out what type 'real' is at runtime. This really
  // should be template specialization (so it is resolved at compile time).
  real dummy;
  static_cast<void>(dummy);  // Silence compiler warnings.
  zfp_type type;
  if (typeid(dummy) == typeid(float)) {
    type = zfp_type_float;
  } else if (typeid(dummy) == typeid(double)) {
    type = zfp_type_double;
  } else {
    luaL_error(L, "Input type must be double or float.");
  }

  // Allocate meta data for the array.
  zfp_field* field;
  uint32_t dim_zfp;
  if (dim == 1) {
    field = zfp_field_1d(in_data, type, in->size[0]);
    dim_zfp = 1;
  } else if (dim == 2) {
    field = zfp_field_2d(in_data, type, in->size[1], in->size[0]);
    dim_zfp = 2;
  } else if (dim == 3) {
    field = zfp_field_3d(in_data, type, in->size[2], in->size[1], in->size[0]);
    dim_zfp = 3;
  } else {
    // ZFP only allows up to 3D tensors, so we'll have to treat the input
    // tensor as a concatenated 3D tensor. This will affect compression ratios
    // but there's not much we can do about this.
    uint32_t sizez = 1;
    for (uint32_t i = 0; i < dim - 2; i++) {
      sizez *= in->size[i];
    }
    uint32_t sizey = in->size[dim - 2];
    uint32_t sizex = in->size[dim - 1];
    field = zfp_field_3d(in_data, type, sizex, sizey, sizez);
    dim_zfp = 4;
  }

  // Allocate meta data for the compressed stream.
  zfp_stream* zfp = zfp_stream_open(NULL);

  // Set stream compression mode and parameters.
  zfp_stream_set_accuracy(zfp, accuracy, type);

  // Allocate buffer for compressed data.
  size_t bufsize = zfp_stream_maximum_size(zfp, field);
  std::unique_ptr<uint8_t[]> buffer(new uint8_t[bufsize]);

  // Associate bit stream with allocated buffer.
  bitstream* stream = stream_open(buffer.get(), bufsize);
  zfp_stream_set_bit_stream(zfp, stream);
  zfp_stream_rewind(zfp);

  // Compress entire array.
  const size_t zfpsize = zfp_compress(zfp, field);

  // Clean up.
  zfp_field_free(field);
  zfp_stream_close(zfp);
  stream_close(stream);

  if (!zfpsize) {
    luaL_error(L, "ZFP compression failed!");
  }

  // Copy the compressed array into the return tensor. NOTE: Torch does not
  // support in-place resize with shrink. If you resize smaller you ALWAYS
  // keep around the memory, so unfortunately this copy is necessary (i.e.
  // we will always need to perform the compression in a temporary buffer
  // first).
  THByteTensor_resize1d(out, zfpsize);
  unsigned char* out_data = THByteTensor_data(out);
  memcpy(out_data, buffer.get(), zfpsize);

  return 0;  // Recall: number of lua return items.
}
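The stream written above carries no zfp header, so the decompressor (Example #5) must be configured with exactly the same scalar type, dimensions, and accuracy. A small sketch of how a self-describing stream could be produced instead, using zfp's optional header; this is illustrative and not part of torchzfp:

  // Illustrative only: write a full zfp header before the payload so the
  // reader can recover the scalar type, field size, and compression mode.
  zfp_stream_rewind(zfp);
  if (zfp_write_header(zfp, field, ZFP_HEADER_FULL) == 0) {
    luaL_error(L, "could not write zfp header");
  }
  size_t zfpsize = zfp_compress(zfp, field);   // then compress as above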
Example #5
static int torchzfp_(Main_decompress)(lua_State *L) {
  THTensor* out = reinterpret_cast<THTensor*>(
      luaT_checkudata(L, 1, torch_Tensor));
  real* out_data = THTensor_(data)(out);
  const uint32_t dim = out->nDimension;
  THByteTensor* in = reinterpret_cast<THByteTensor*>(
      luaT_checkudata(L, 2, "torch.ByteTensor"));
  unsigned char* in_data = THByteTensor_data(in);
  const double accuracy = static_cast<double>(lua_tonumber(L, 3));

  // Hacky code to figure out what type 'real' is at runtime. This really
  // should be template specialization (so it is resolved at compile time).
  real dummy;
  static_cast<void>(dummy);  // Silence compiler warnings.
  zfp_type type;
  if (typeid(dummy) == typeid(float)) {
    type = zfp_type_float;
  } else if (typeid(dummy) == typeid(double)) {
    type = zfp_type_double;
  } else {
    luaL_error(L, "Output type must be double or float.");
  }

  // Allocate meta data for the array.
  zfp_field* field;
  uint32_t dim_zfp;
  if (dim == 1) {
    field = zfp_field_1d(out_data, type, out->size[0]);
    dim_zfp = 1;
  } else if (dim == 2) {
    field = zfp_field_2d(out_data, type, out->size[1], out->size[0]);
    dim_zfp = 2;
  } else if (dim == 3) {
    field = zfp_field_3d(out_data, type, out->size[2], out->size[1],
                         out->size[0]);
    dim_zfp = 3;
  } else {
    // ZFP only allows up to 3D tensors, so we'll have to treat the input
    // tensor as a concatenated 3D tensor. This will affect compression ratios
    // but there's not much we can do about this.
    uint32_t sizez = 1;
    for (uint32_t i = 0; i < dim - 2; i++) {
      sizez *= out->size[i];
    }
    uint32_t sizey = out->size[dim - 2];
    uint32_t sizex = out->size[dim - 1];
    field = zfp_field_3d(out_data, type, sizex, sizey, sizez);
    dim_zfp = 3;
  }
  
  // Allocate meta data for the compressed stream.
  zfp_stream* zfp = zfp_stream_open(NULL);

  // Set stream compression mode and parameters.
  zfp_stream_set_accuracy(zfp, accuracy, type);

  // Get buffer for compressed data.
  void* buffer = reinterpret_cast<void*>(in_data);

  // Associate bit stream with allocated buffer.
  const uint32_t bufsize = in->size[0];
  bitstream* stream = stream_open(buffer, bufsize);
  zfp_stream_set_bit_stream(zfp, stream);
  zfp_stream_rewind(zfp);
  
  // Decompress the entire array.
  const int ret = zfp_decompress(zfp, field);

  // Clean up.
  zfp_field_free(field);
  zfp_stream_close(zfp);
  stream_close(stream);

  if (!ret) { 
    luaL_error(L, "ZFP decompression failed!");
  }

  return 0;  // Recall: number of lua return items.
}
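Conversely, if the compressor had written a header (see the sketch after Example #4), the decompressor could recover the parameters from the stream instead of trusting the caller-supplied tensor shape and accuracy. Illustrative counterpart, again not part of torchzfp:

  // Illustrative only: read type/size/parameters back from an embedded header.
  zfp_stream_rewind(zfp);
  if (zfp_read_header(zfp, field, ZFP_HEADER_FULL) == 0) {
    luaL_error(L, "could not read zfp header");
  }
  int ret = zfp_decompress(zfp, field);   // then decompress as above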
Example #6
static void load_array_to_lua(lua_State *L, cnpy::NpyArray& arr){
	int ndims = arr.shape.size();

	//based on code from mattorch with stride fix
	int k;
	THLongStorage *size = THLongStorage_newWithSize(ndims);
	THLongStorage *stride = THLongStorage_newWithSize(ndims);
	for (k=0; k<ndims; k++) {
		THLongStorage_set(size, k, arr.shape[k]);
		if (k > 0)
			THLongStorage_set(stride, ndims-k-1, arr.shape[ndims-k]*THLongStorage_get(stride,ndims-k));
		else
			THLongStorage_set(stride, ndims-k-1, 1);
	}

	void * tensorDataPtr = NULL;
	size_t numBytes = 0;

	if ( arr.arrayType == 'f' ){ // float32/64
		if ( arr.word_size == 4 ){ //float32
			THFloatTensor *tensor = THFloatTensor_newWithSize(size, stride);
		    tensorDataPtr = (void *)(THFloatTensor_data(tensor));
		    numBytes = THFloatTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.FloatTensor"));
    
		}else if ( arr.word_size ==  8){ //float 64
			THDoubleTensor *tensor = THDoubleTensor_newWithSize(size, stride);
			tensorDataPtr = (void *)(THDoubleTensor_data(tensor));
		    numBytes = THDoubleTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.DoubleTensor"));
		}
	}else if ( arr.arrayType == 'i' || arr.arrayType == 'u' ){ // does torch have unsigned types .. need to look
		if ( arr.word_size == 1 ){ //int8
			THByteTensor *tensor = THByteTensor_newWithSize(size, stride);
			tensorDataPtr = (void *)(THByteTensor_data(tensor));
		    numBytes = THByteTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.ByteTensor"));
    
		}else if ( arr.word_size == 2 ){ //int16
			THShortTensor *tensor = THShortTensor_newWithSize(size, stride);
			tensorDataPtr = (void *)(THShortTensor_data(tensor));
		    numBytes = THShortTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.ShortTensor"));
    
		}else if ( arr.word_size == 4 ){ //int32
			THIntTensor *tensor = THIntTensor_newWithSize(size, stride);
			tensorDataPtr = (void *)(THIntTensor_data(tensor));
		    numBytes = THIntTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.IntTensor"));
    
		}else if ( arr.word_size ==  8){ //long 64
			THLongTensor *tensor = THLongTensor_newWithSize(size, stride);
			tensorDataPtr = (void *)(THLongTensor_data(tensor));
		    numBytes = THLongTensor_nElement(tensor) * arr.word_size;
		    luaT_pushudata(L, tensor, luaT_checktypename2id(L, "torch.LongTensor"));
		}
	}else{
		printf("array type unsupported");
		throw std::runtime_error("unsupported data type");
	}

	// now copy the data
	assert(tensorDataPtr);
	memcpy(tensorDataPtr, (void *)(arr.data<void>()), numBytes);
}
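The stride loop above is a compact way of building row-major (C-order) strides: the innermost dimension gets stride 1 and each outer dimension's stride is the product of all inner sizes, so a shape of (2, 3, 4) yields strides (12, 4, 1). A more direct equivalent, for illustration only:

/* Illustrative only: row-major strides, innermost dimension last. */
static void row_major_strides(const long *shape, long *stride, int ndims) {
  long acc = 1;
  for (int k = ndims - 1; k >= 0; k--) {
    stride[k] = acc;      /* stride of dimension k */
    acc *= shape[k];      /* running product of inner sizes */
  }
}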
Example #7
void THNN_ByteSpatialMaxPooling_updateOutput(
          THNNState *state,
          THByteTensor *input,
          THByteTensor *output,
          THByteTensor *indices,
          int kW,
          int kH,
          int dW,
          int dH,
          int padW,
          int padH,
          bool ceil_mode)
{
  int dimw = 2;
  int dimh = 1;
  long nbatch = 1;
  long nslices;
  long iheight;
  long iwidth;
  long oheight;
  long owidth;
  uint8_t *input_data;
  uint8_t *output_data;
  uint8_t *indices_data;


  THArgCheck(input->nDimension == 3 || input->nDimension == 4 , 2, "3D or 4D (batch mode) tensor expected");

  if (input->nDimension == 4)
  {
    nbatch = input->size[0];
    dimw++;
    dimh++;
  }
  THArgCheck(input->size[dimw] >= kW - padW && input->size[dimh] >= kH - padH, 2, "input image smaller than kernel size");

  THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");

  /* sizes */
  nslices = input->size[dimh-1];
  iheight = input->size[dimh];
  iwidth = input->size[dimw];
  if (ceil_mode)
  {
    oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1;
    owidth  = (long)(ceil((float)(iwidth  - kW + 2*padW) / dW)) + 1;
  }
  else
  {
    oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1;
    owidth  = (long)(floor((float)(iwidth  - kW + 2*padW) / dW)) + 1;
  }

  if (padW || padH)
  {
    // ensure that the last pooling starts inside the image
    if ((oheight - 1)*dH >= iheight + padH)
      --oheight;
    if ((owidth  - 1)*dW >= iwidth  + padW)
      --owidth;
  }

  /* get contiguous input */
  input = THByteTensor_newContiguous(input);

  /* resize output */
  if (input->nDimension == 3)
  {
    THByteTensor_resize3d(output, nslices, oheight, owidth);
    /* indices will contain the locations for each output point */
    THByteTensor_resize3d(indices,  nslices, oheight, owidth);

    input_data = THByteTensor_data(input);
    output_data = THByteTensor_data(output);
    indices_data = THByteTensor_data(indices);

    THNN_ByteSpatialMaxPooling_updateOutput_frame(input_data, output_data,
                                                 indices_data,
                                                 nslices,
                                                 iwidth, iheight,
                                                 owidth, oheight,
                                                 kW, kH, dW, dH,
                                                 padW, padH);
  }
  else
  {
    long p;

    THByteTensor_resize4d(output, nbatch, nslices, oheight, owidth);
    /* indices will contain the locations for each output point */
    THByteTensor_resize4d(indices, nbatch, nslices, oheight, owidth);

    input_data = THByteTensor_data(input);
    output_data = THByteTensor_data(output);
    indices_data = THByteTensor_data(indices);

#pragma omp parallel for private(p)
    for (p = 0; p < nbatch; p++)
    {
      THNN_ByteSpatialMaxPooling_updateOutput_frame(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight,
                                                   indices_data+p*nslices*owidth*oheight,
                                                   nslices,
                                                   iwidth, iheight,
                                                   owidth, oheight,
                                                   kW, kH, dW, dH,
                                                   padW, padH);
    }
  }

  /* cleanup */
  THByteTensor_free(input);
}
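The output-size arithmetic above is the usual pooling formula, with ceil_mode choosing between floor and ceiling before the final + 1. A small standalone sketch (not part of THNN) showing the difference on a concrete case:

#include <math.h>
#include <stdio.h>

/* Illustrative only: one spatial dimension of the formula used above. */
static long pooled_size(long isize, int k, int d, int pad, int ceil_mode) {
  double t = (double)(isize - k + 2 * pad) / d;
  return (long)(ceil_mode ? ceil(t) : floor(t)) + 1;
}

int main(void) {
  /* 6-pixel-wide input, 3x3 kernel, stride 2, no padding:
     floor mode -> 2 outputs, ceil mode -> 3 outputs. */
  printf("%ld %ld\n", pooled_size(6, 3, 2, 0, 0), pooled_size(6, 3, 2, 0, 1));
  return 0;
}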
Example #8
static int libpng_(Main_load)(lua_State *L)
{

  png_byte header[8];    // 8 is the maximum size that can be checked

  int width, height, bit_depth;
  png_byte color_type;
  
  png_structp png_ptr;
  png_infop info_ptr;
  png_bytep * row_pointers;
  size_t fread_ret;
  FILE* fp;
  libpng_inmem_buffer inmem = {0};    /* source memory (if loading from memory) */
  libpng_errmsg errmsg;

  const int load_from_file = luaL_checkint(L, 1);

  if (load_from_file == 1){
    const char *file_name = luaL_checkstring(L, 2);
   /* open file and test for it being a png */
    fp = fopen(file_name, "rb");
    if (!fp)
      luaL_error(L, "[read_png_file] File %s could not be opened for reading", file_name);
    fread_ret = fread(header, 1, 8, fp);
    if (fread_ret != 8)
      luaL_error(L, "[read_png_file] File %s error reading header", file_name);
    if (png_sig_cmp(header, 0, 8))
      luaL_error(L, "[read_png_file] File %s is not recognized as a PNG file", file_name);
  } else {
    /* We're loading from a ByteTensor */
    THByteTensor *src = luaT_checkudata(L, 2, "torch.ByteTensor");
    inmem.buffer = THByteTensor_data(src);
    inmem.length = src->size[0];
    inmem.offset = 8;
    fp = NULL;
    if (png_sig_cmp(inmem.buffer, 0, 8))
      luaL_error(L, "[read_png_byte_tensor] ByteTensor is not recognized as a PNG file");
  }
  /* initialize stuff */
  png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);

  if (!png_ptr)
    luaL_error(L, "[read_png] png_create_read_struct failed");

  png_set_error_fn(png_ptr, &errmsg, libpng_error_fn, NULL);

  info_ptr = png_create_info_struct(png_ptr);
  if (!info_ptr)
    luaL_error(L, "[read_png] png_create_info_struct failed");

  if (setjmp(png_jmpbuf(png_ptr)))
    luaL_error(L, "[read_png] Error during init_io: %s", errmsg.str);

  if (load_from_file == 1){
    png_init_io(png_ptr, fp);
  } else {
    /* set the read callback */
    png_set_read_fn(png_ptr,(png_voidp)&inmem, libpng_userReadData);
  }
  png_set_sig_bytes(png_ptr, 8);
  png_read_info(png_ptr, info_ptr);

  width      = png_get_image_width(png_ptr, info_ptr);
  height     = png_get_image_height(png_ptr, info_ptr);
  color_type = png_get_color_type(png_ptr, info_ptr);
  bit_depth  = png_get_bit_depth(png_ptr, info_ptr);
  

  /* get depth */
  int depth = 0;
  if (color_type == PNG_COLOR_TYPE_RGBA)
    depth = 4;
  else if (color_type == PNG_COLOR_TYPE_RGB)
    depth = 3;
  else if (color_type == PNG_COLOR_TYPE_GRAY)
  {
    if(bit_depth < 8)
    {
      png_set_expand_gray_1_2_4_to_8(png_ptr);
    }
    depth = 1;
  }
  else if (color_type == PNG_COLOR_TYPE_GA)
    depth = 2;
  else if (color_type == PNG_COLOR_TYPE_PALETTE)
    {
      depth = 3;
      png_set_expand(png_ptr);
    }
  else
    luaL_error(L, "[read_png_file] Unknown color space");

  if(bit_depth < 8)
  {
    png_set_strip_16(png_ptr);
  }
  
  png_read_update_info(png_ptr, info_ptr);

  /* read file */
  if (setjmp(png_jmpbuf(png_ptr)))
    luaL_error(L, "[read_png_file] Error during read_image: %s", errmsg.str);

  /* alloc tensor */
  THTensor *tensor = THTensor_(newWithSize3d)(depth, height, width);
  real *tensor_data = THTensor_(data)(tensor);

  /* alloc data in lib format */
  row_pointers = (png_bytep*) malloc(sizeof(png_bytep) * height);
  int y;
  for (y=0; y<height; y++)
    row_pointers[y] = (png_byte*) malloc(png_get_rowbytes(png_ptr,info_ptr));

  /* read image in */
  png_read_image(png_ptr, row_pointers);

  /* convert image to dest tensor */
  int x,k;
  if ((bit_depth == 16) && (sizeof(real) > 1)) {
    for (k=0; k<depth; k++) {
      for (y=0; y<height; y++) {
	png_byte* row = row_pointers[y];
	for (x=0; x<width; x++) {
	  // PNG is big-endian
	  int val = ((int)row[(x*depth+k)*2] << 8) + row[(x*depth+k)*2+1];
	  *tensor_data++ = (real)val;
	}
      }
    }
  } else {
    int stride = 1;
    if (bit_depth == 16) {
      /* PNG has 16 bit color depth, but the tensor type is byte. */
      stride = 2;
    }
    for (k=0; k<depth; k++) {
      for (y=0; y<height; y++) {
	png_byte* row = row_pointers[y];
	for (x=0; x<width; x++) {
	  *tensor_data++ = (real)row[(x*depth+k)*stride];
	  //png_byte val = row[x*depth+k];
	  //THTensor_(set3d)(tensor, k, y, x, (real)val);
	}
      }
    }
  }


  /* cleanup heap allocation */
  for (y=0; y<height; y++)
    free(row_pointers[y]);
  free(row_pointers);

  /* cleanup png structs */
  png_read_end(png_ptr, NULL);
  png_destroy_read_struct(&png_ptr, &info_ptr, NULL);

  /* done with file */
  if (fp) {
    fclose(fp);
  }

  /* return tensor */
  luaT_pushudata(L, tensor, torch_Tensor);

  if (bit_depth < 8) {
    bit_depth = 8;
  }
  lua_pushnumber(L, bit_depth);

  return 2;
}
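Loading from a ByteTensor relies on libpng_inmem_buffer and the libpng_userReadData callback, which are defined elsewhere in the package. A minimal sketch of what they look like when using libpng's custom read-function mechanism (the project's actual definitions may differ):

#include <string.h>

typedef struct {
  unsigned char *buffer;   /* PNG bytes held by the source ByteTensor */
  size_t length;           /* total number of bytes */
  size_t offset;           /* read position (starts at 8, past the signature) */
} libpng_inmem_buffer;

static void libpng_userReadData(png_structp png_ptr, png_bytep dest, png_size_t len)
{
  libpng_inmem_buffer *src = (libpng_inmem_buffer *) png_get_io_ptr(png_ptr);
  if (src->offset + len > src->length)
    png_error(png_ptr, "read past the end of the in-memory PNG buffer");
  memcpy(dest, src->buffer + src->offset, len);
  src->offset += len;
}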