Ejemplo n.º 1
0
/* static */
void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
                            TRand* randomizer, NetworkIO* input) {
  bool color = shape.depth() == 3;
  Pix* var_pix = const_cast<Pix*>(pix);
  int depth = pixGetDepth(var_pix);
  Pix* normed_pix = nullptr;
  // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
  // colormap, so we just have to deal with depth conversion here.
  if (color) {
    // Force RGB.
    if (depth == 32)
      normed_pix = pixClone(var_pix);
    else
      normed_pix = pixConvertTo32(var_pix);
  } else {
    // Convert non-8-bit images to 8 bit.
    if (depth == 8)
      normed_pix = pixClone(var_pix);
    else
      normed_pix = pixConvertTo8(var_pix, false);
  }
  int height = pixGetHeight(normed_pix);
  int target_height = shape.height();
  if (target_height == 1) target_height = shape.depth();
  if (target_height != 0 && target_height != height) {
    // Get the scaled image.
    float im_factor = static_cast<float>(target_height) / height;
    Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
    pixDestroy(&normed_pix);
    normed_pix = scaled_pix;
  }
  input->FromPix(shape, normed_pix, randomizer);
  pixDestroy(&normed_pix);
}
Ejemplo n.º 2
0
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
// Every network in stack_ is queried with the same input_shape. The result
// takes its dimensions from the output of stack_[0], except that its depth is
// the sum of the output depths of all networks in stack_.
// NOTE(review): stack_[0] is read unconditionally, so stack_ is assumed to be
// non-empty here.
StaticShape Parallel::OutputShape(const StaticShape& input_shape) const {
  StaticShape combined = stack_[0]->OutputShape(input_shape);
  int total_depth = combined.depth();
  const int num_networks = stack_.size();
  for (int n = 1; n < num_networks; ++n) {
    total_depth += stack_[n]->OutputShape(input_shape).depth();
  }
  combined.set_depth(total_depth);
  return combined;
}
Ejemplo n.º 3
0
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
// Height is divided by y_scale_ and width by x_scale_ (integer division). For
// every type_ other than NT_MAXPOOL, the depth is also multiplied by
// y_scale_ * x_scale_. For NT_MAXPOOL the depth is left unchanged.
StaticShape Reconfig::OutputShape(const StaticShape& input_shape) const {
  StaticShape output = input_shape;
  output.set_height(input_shape.height() / y_scale_);
  output.set_width(input_shape.width() / x_scale_);
  if (type_ == NT_MAXPOOL) return output;
  output.set_depth(input_shape.depth() * y_scale_ * x_scale_);
  return output;
}
Ejemplo n.º 4
0
// Parses an input specification and returns the result, which may include a
// series.
// The specification at *str must have the form "batch,height,width,depth".
// On success *str is advanced past it and a new Input layer with that shape
// is created. If the next non-whitespace character is '[', the result of
// ParseSeries(shape, input, str) is returned. Otherwise the Input layer
// itself is returned.
// If the four integers cannot be read, an error is printed, *str is left
// unchanged, and nullptr is returned.
Network* NetworkBuilder::ParseInput(char** str) {
  // There must be an input at this point.
  int length = 0;
  // Zero-initialized so that no indeterminate value can ever be read, even
  // if sscanf stops early.
  int batch = 0;
  int height = 0;
  int width = 0;
  int depth = 0;
  int num_converted =
      sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
  // The C standard says %n does not count towards the return value, but
  // accept 5 as well in case an implementation counts it.
  if (num_converted != 4 && num_converted != 5) {
    tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
    return nullptr;
  }
  // Build the shape only once all four fields are known to be valid.
  StaticShape shape;
  shape.SetShape(batch, height, width, depth);
  *str += length;
  Input* input = new Input("Input", shape);
  // We want to allow [<input>rest of net... or <input>[rest of net... so we
  // have to check explicitly for '[' here.
  SkipWhitespace(str);
  if (**str == '[') return ParseSeries(shape, input, str);
  return input;
}
Ejemplo n.º 5
0
// Constructs an input layer called name with the given shape.
// The base Network is created with shape.height() inputs and shape.depth()
// outputs. When shape.height() is 1, the number of inputs (ni_) is then
// replaced by shape.depth(). This matches PreparePixInput, which uses the
// depth as the target height in that case.
Input::Input(const STRING& name, const StaticShape& shape)
    : Network(NT_INPUT, name, shape.height(), shape.depth()),
      shape_(shape),
      cached_x_scale_(1) {
  const bool unit_height = shape.height() == 1;
  if (unit_height) {
    ni_ = shape.depth();
  }
}