/* static */ void Input::PreparePixInput(const StaticShape& shape, const Pix* pix, TRand* randomizer, NetworkIO* input) { bool color = shape.depth() == 3; Pix* var_pix = const_cast<Pix*>(pix); int depth = pixGetDepth(var_pix); Pix* normed_pix = nullptr; // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without // colormap, so we just have to deal with depth conversion here. if (color) { // Force RGB. if (depth == 32) normed_pix = pixClone(var_pix); else normed_pix = pixConvertTo32(var_pix); } else { // Convert non-8-bit images to 8 bit. if (depth == 8) normed_pix = pixClone(var_pix); else normed_pix = pixConvertTo8(var_pix, false); } int height = pixGetHeight(normed_pix); int target_height = shape.height(); if (target_height == 1) target_height = shape.depth(); if (target_height != 0 && target_height != height) { // Get the scaled image. float im_factor = static_cast<float>(target_height) / height; Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor); pixDestroy(&normed_pix); normed_pix = scaled_pix; } input->FromPix(shape, normed_pix, randomizer); pixDestroy(&normed_pix); }
// Returns the shape output from the network given an input shape (which may // be partially unknown ie zero). StaticShape Parallel::OutputShape(const StaticShape& input_shape) const { StaticShape result = stack_[0]->OutputShape(input_shape); int stack_size = stack_.size(); for (int i = 1; i < stack_size; ++i) { StaticShape shape = stack_[i]->OutputShape(input_shape); result.set_depth(result.depth() + shape.depth()); } return result; }
// Returns the shape output from the network given an input shape (which may // be partially unknown ie zero). StaticShape Reconfig::OutputShape(const StaticShape& input_shape) const { StaticShape result = input_shape; result.set_height(result.height() / y_scale_); result.set_width(result.width() / x_scale_); if (type_ != NT_MAXPOOL) result.set_depth(result.depth() * y_scale_ * x_scale_); return result; }
// Parses an input specification and returns the result, which may include a // series. Network* NetworkBuilder::ParseInput(char** str) { // There must be an input at this point. int length = 0; int batch, height, width, depth; int num_converted = sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length); StaticShape shape; shape.SetShape(batch, height, width, depth); // num_converted may or may not include the length. if (num_converted != 4 && num_converted != 5) { tprintf("Must specify an input layer as the first layer, not %s!!\n", *str); return nullptr; } *str += length; Input* input = new Input("Input", shape); // We want to allow [<input>rest of net... or <input>[rest of net... so we // have to check explicitly for '[' here. SkipWhitespace(str); if (**str == '[') return ParseSeries(shape, input, str); return input; }
// Constructs an input layer named `name` that accepts data of `shape`.
// The base Network is given shape.height() inputs and shape.depth() outputs.
// When shape.height() is 1, the input count ni_ is then replaced by
// shape.depth().
Input::Input(const STRING& name, const StaticShape& shape)
    : Network(NT_INPUT, name, shape.height(), shape.depth()),
      shape_(shape),
      cached_x_scale_(1) {
  const bool single_row = shape.height() == 1;
  if (single_row) {
    ni_ = shape.depth();
  }
}