Beispiel #1
0
/*
 * make-tf-session: create a TensorFlow session for the graph wrapped by
 * SCM_GRAPH and return it as a smob tagged with tf_session_tag.
 *
 * The session record is zero-allocated with scm_gc_calloc and wrapped in the
 * smob before the TensorFlow call. It stores the unwrapped graph pointer
 * obtained from get_tf_graph().
 *
 * On failure (TF status code other than TF_OK) a Guile `misc-error` is raised
 * whose message is TensorFlow's status text.
 *
 * NOTE(review): assumes status() returns the same TF_Status object that is
 * read back through _status below -- confirm against their definitions.
 */
SCM make_tf_session(SCM scm_graph)
{
  SCM retval;
  struct tf_session_t *self = (struct tf_session_t *)scm_gc_calloc(sizeof(struct tf_session_t), "make-tf-session");
  SCM_NEWSMOB(retval, tf_session_tag, self);
  self->graph = get_tf_graph(scm_graph);
  TF_SessionOptions *options = TF_NewSessionOptions();
  self->session = TF_NewSession(self->graph->graph, options, status());
  TF_DeleteSessionOptions(options);
  if (TF_GetCode(_status) != TF_OK) {
    // scm_misc_error treats its message as a format template (~A, ~S ...), so
    // pass TensorFlow's text as an argument instead: a '~' inside it would
    // otherwise be parsed as a format directive.
    scm_misc_error("make-tf-session", "~A",
                   scm_list_1(scm_from_locale_string(TF_Message(_status))));
  }
  return retval;
}
Beispiel #2
0
/**
 * Bind the model's input/output operations, (re)create the TensorFlow session
 * and run the graph once to discover the output dimensions.
 *
 * The graph must contain an input operation named "x" and an output operation
 * named "y". If an operation named "init" exists, it is run once after the
 * session is created.
 *
 * @param model  opaque model handle, used as a TFModel *
 * @param input  input->height/width/channels size the input tensor
 *               {1, height, width, channels}; on success input->data points
 *               into that tensor's buffer (owned by tf_model->input_tensor)
 * @param output receives height/width/channels of the "y" output and a newly
 *               av_malloc'ed buffer of height*width*channels floats in
 *               output->data; tf_model->output_data is set to output
 * @return DNN_SUCCESS, or DNN_ERROR on any failure
 */
static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output)
{
    TFModel *tf_model = (TFModel *)model;
    int64_t input_dims[] = {1, input->height, input->width, input->channels};
    TF_SessionOptions *sess_opts;
    const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
    TF_Tensor *output_tensor = NULL;

    // Input operation should be named 'x'
    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
    if (!tf_model->input.oper){
        return DNN_ERROR;
    }
    tf_model->input.index = 0;
    // Replace any input tensor left over from a previous call.
    if (tf_model->input_tensor){
        TF_DeleteTensor(tf_model->input_tensor);
    }
    tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
                                               input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
    if (!tf_model->input_tensor){
        return DNN_ERROR;
    }
    input->data = (float *)TF_TensorData(tf_model->input_tensor);

    // Output operation should be named 'y'
    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
    if (!tf_model->output.oper){
        return DNN_ERROR;
    }
    tf_model->output.index = 0;

    // Best-effort teardown of a previous session; its status is overwritten below.
    if (tf_model->session){
        TF_CloseSession(tf_model->session, tf_model->status);
        TF_DeleteSession(tf_model->session, tf_model->status);
    }

    sess_opts = TF_NewSessionOptions();
    tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
    TF_DeleteSessionOptions(sess_opts);
    if (TF_GetCode(tf_model->status) != TF_OK)
    {
        return DNN_ERROR;
    }

    // Run initialization operation with name "init" if it is present in graph
    if (init_op){
        TF_SessionRun(tf_model->session, NULL,
                      NULL, NULL, 0,
                      NULL, NULL, 0,
                      &init_op, 1, NULL, tf_model->status);
        if (TF_GetCode(tf_model->status) != TF_OK)
        {
            return DNN_ERROR;
        }
    }

    // Execute network to get output height, width and number of channels
    TF_SessionRun(tf_model->session, NULL,
                  &tf_model->input, &tf_model->input_tensor, 1,
                  &tf_model->output, &output_tensor, 1,
                  NULL, 0, NULL, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        // Guarded: older TF releases do not accept NULL in TF_DeleteTensor.
        if (output_tensor){
            TF_DeleteTensor(output_tensor);
        }
        return DNN_ERROR;
    }

    // Dimensions 1..3 are read as height/width/channels below; TF_Dim() is
    // only valid for indices < TF_NumDims(), so reject any non-rank-4 output.
    if (TF_NumDims(output_tensor) != 4){
        TF_DeleteTensor(output_tensor);
        return DNN_ERROR;
    }
    output->height = TF_Dim(output_tensor, 1);
    output->width = TF_Dim(output_tensor, 2);
    output->channels = TF_Dim(output_tensor, 3);
    // Only the shape was needed; release the tensor before allocating so the
    // allocation-failure path cannot leak it.
    TF_DeleteTensor(output_tensor);

    // Widen to size_t first so the element count is not multiplied in int.
    // NOTE(review): a buffer already stored in output->data is overwritten,
    // not freed -- confirm the caller releases it between calls.
    output->data = av_malloc((size_t)output->height * output->width * output->channels * sizeof(float));
    if (!output->data){
        return DNN_ERROR;
    }
    tf_model->output_data = output;

    return DNN_SUCCESS;
}