Ejemplo n.º 1
0
/* Look up the operation called scm_name in the graph wrapped by scm_graph and
   return its outputs as tf_output smobs. When the operation has exactly one
   output, that smob is returned directly; otherwise the result is a list
   ordered by output index (possibly empty). Raises a Guile misc-error if no
   operation with that name exists. */
SCM tf_graph_operation_by_name_(SCM scm_graph, SCM scm_name)
{
  struct tf_graph_t *graph = get_tf_graph(scm_graph);
  char *c_name = scm_to_locale_string(scm_name);
  TF_Operation *oper = TF_GraphOperationByName(graph->graph, c_name);
  /* Release the C string before scm_misc_error can exit non-locally. */
  free(c_name);
  if (oper == NULL)
    scm_misc_error("tf-graph-operation-by-name_", "Operation '~a' not found", scm_list_1(scm_name));

  int count = TF_OperationNumOutputs(oper);
  SCM result = SCM_EOL;
  /* Walk indices from high to low so consing yields a list starting at 0. */
  int idx = count;
  while (idx-- > 0) {
    struct tf_output_t *out = (struct tf_output_t *)scm_gc_calloc(sizeof *out, "tf-graph-operation-by-name_");
    out->output.oper = oper;
    out->output.index = idx;
    SCM smob;
    SCM_NEWSMOB(smob, tf_output_tag, out);
    result = scm_cons(smob, result);
  }

  return count == 1 ? scm_car(result) : result;
}
Ejemplo n.º 2
0
/**
 * Bind the model's input ("x") and output ("y") operations, (re)create the
 * TensorFlow session, run the "init" operation if the graph has one, and
 * execute the network once to discover the output dimensions.
 *
 * @param model  TFModel instance, passed as void *.
 * @param input  Its height/width/channels size a {1, H, W, C} float input
 *               tensor; input->data is set to point into that tensor's buffer.
 * @param output Filled with height/width/channels read from the trial run's
 *               output tensor; output->data is set to a new av_malloc'ed
 *               float buffer of that size.
 * @return DNN_SUCCESS on success, DNN_ERROR on any failure.
 */
static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output)
{
    TFModel *tf_model = (TFModel *)model;
    int64_t input_dims[] = {1, input->height, input->width, input->channels};
    TF_SessionOptions *sess_opts;
    const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
    TF_Tensor *output_tensor;
    size_t output_size;

    // Input operation should be named 'x'
    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
    if (!tf_model->input.oper){
        return DNN_ERROR;
    }
    tf_model->input.index = 0;
    // Replace any tensor left over from a previous call.
    if (tf_model->input_tensor){
        TF_DeleteTensor(tf_model->input_tensor);
    }
    tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
                                               input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
    if (!tf_model->input_tensor){
        return DNN_ERROR;
    }
    input->data = (float *)TF_TensorData(tf_model->input_tensor);

    // Output operation should be named 'y'
    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
    if (!tf_model->output.oper){
        return DNN_ERROR;
    }
    tf_model->output.index = 0;

    // Tear down any session from a previous call before creating a new one.
    if (tf_model->session){
        TF_CloseSession(tf_model->session, tf_model->status);
        TF_DeleteSession(tf_model->session, tf_model->status);
    }

    sess_opts = TF_NewSessionOptions();
    tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
    TF_DeleteSessionOptions(sess_opts);
    if (TF_GetCode(tf_model->status) != TF_OK){
        return DNN_ERROR;
    }

    // Run initialization operation with name "init" if it is present in graph
    if (init_op){
        TF_SessionRun(tf_model->session, NULL,
                      NULL, NULL, 0,
                      NULL, NULL, 0,
                      &init_op, 1, NULL, tf_model->status);
        if (TF_GetCode(tf_model->status) != TF_OK){
            return DNN_ERROR;
        }
    }

    // Execute network to get output height, width and number of channels
    TF_SessionRun(tf_model->session, NULL,
                  &tf_model->input, &tf_model->input_tensor, 1,
                  &tf_model->output, &output_tensor, 1,
                  NULL, 0, NULL, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        return DNN_ERROR;
    }

    // Dimensions 1..3 are read as height, width and channels. This presumably
    // assumes the same {N, H, W, C} layout used for input_dims; confirm
    // against the model.
    output->height = TF_Dim(output_tensor, 1);
    output->width = TF_Dim(output_tensor, 2);
    output->channels = TF_Dim(output_tensor, 3);
    // Only the dimensions were needed. Free the tensor now so no early return
    // below can leak it.
    TF_DeleteTensor(output_tensor);

    // Multiply in size_t rather than int so large dimensions cannot trigger
    // signed-integer overflow.
    output_size = (size_t)output->height * output->width * output->channels * sizeof(float);
    output->data = av_malloc(output_size);
    if (!output->data){
        return DNN_ERROR;
    }
    tf_model->output_data = output;

    return DNN_SUCCESS;
}