コード例 #1
0
ファイル: tensorflow.c プロジェクト: wedesoft/aiscm
/*
 * Import a serialized graph definition from a file into an existing graph.
 *
 * scm_graph     - Scheme object from which get_tf_graph() obtains the graph wrapper.
 * scm_file_name - Scheme string naming the file to read (converted with the
 *                 current locale encoding).
 *
 * Returns SCM_UNDEFINED on success.  Signals a Scheme error through
 * scm_misc_error() when the file cannot be opened, examined or read
 * completely, or when the status is not TF_OK after the import.
 */
SCM tf_graph_import_(SCM scm_graph, SCM scm_file_name)
{
  struct tf_graph_t *graph = get_tf_graph(scm_graph);
  char *file_name = scm_to_locale_string(scm_file_name);
  // The contents are handed to the importer as raw bytes, so open in binary
  // mode to rule out newline translation on platforms that distinguish it.
  FILE *file = fopen(file_name, "rb");
  // Capture errno immediately: free() is allowed to overwrite it.
  int open_errno = errno;
  free(file_name);
  if (!file)
    scm_misc_error("tf-graph-import_", strerror(open_errno), SCM_EOL);
  struct stat st;
  if (fstat(fileno(file), &st) != 0) {
    int stat_errno = errno;
    fclose(file);
    scm_misc_error("tf-graph-import_", strerror(stat_errno), SCM_EOL);
  }
  size_t size = st.st_size;
  void *data = scm_gc_malloc(size, "tf-graph-import_");
  // fread() reports 0 items for a zero-sized request, so an empty file is
  // treated as a complete read rather than as a failure.
  int complete = size == 0 || fread(data, size, 1, file) == 1;
  fclose(file);
  if (!complete)
    scm_misc_error("tf-graph-import_", "Error reading graph definition file", SCM_EOL);
  // Create the buffer only after the read succeeded so the error exits above
  // have nothing TensorFlow-owned to release.
  TF_Buffer *buffer = TF_NewBuffer();
  buffer->data = data;
  buffer->length = size;
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph->graph, buffer, opts, status());
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buffer);
  // NOTE(review): strerror()/TF_Message() text is passed as the message
  // template of scm_misc_error(); confirm that a '~' in the text cannot be
  // misinterpreted as a format directive.
  if (TF_GetCode(_status) != TF_OK)
    scm_misc_error("tf-graph-import_", TF_Message(_status), SCM_EOL);
  return SCM_UNDEFINED;
}
コード例 #2
0
DNNModel *ff_dnn_load_model_tf(const char *model_filename)
{
    DNNModel *model = NULL;
    TFModel *tf_model = NULL;
    TF_Buffer *graph_def;
    TF_ImportGraphDefOptions *graph_opts;

    model = av_malloc(sizeof(DNNModel));
    if (!model){
        return NULL;
    }

    tf_model = av_malloc(sizeof(TFModel));
    if (!tf_model){
        av_freep(&model);
        return NULL;
    }
    tf_model->session = NULL;
    tf_model->input_tensor = NULL;
    tf_model->output_data = NULL;

    graph_def = read_graph(model_filename);
    if (!graph_def){
        av_freep(&tf_model);
        av_freep(&model);
        return NULL;
    }
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();
    graph_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
    TF_DeleteImportGraphDefOptions(graph_opts);
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(tf_model->status) != TF_OK){
        TF_DeleteGraph(tf_model->graph);
        TF_DeleteStatus(tf_model->status);
        av_freep(&tf_model);
        av_freep(&model);
        return NULL;
    }

    model->model = (void *)tf_model;
    model->set_input_output = &set_input_output_tf;

    return model;
}
コード例 #3
0
ファイル: dnn_backend_tf.c プロジェクト: DeHackEd/FFmpeg
/**
 * Read a serialized graph definition from a file and import it into tf_model.
 *
 * On success tf_model->graph and tf_model->status hold the newly created
 * graph and status objects.  If the import does not finish with TF_OK, both
 * objects are deleted and the corresponding fields are reset to NULL so the
 * caller-owned structure never keeps pointers to released objects.
 *
 * @param tf_model       model whose graph and status fields are filled in
 * @param model_filename path handed to read_graph()
 * @return DNN_SUCCESS on success; DNN_ERROR if read_graph() returns NULL
 *         (tf_model is left untouched) or the import fails
 */
static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{
    TF_Buffer *graph_def;
    TF_ImportGraphDefOptions *graph_opts;

    graph_def = read_graph(model_filename);
    if (!graph_def){
        return DNN_ERROR;
    }
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();
    graph_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
    /* Options and the serialized buffer are only needed for the import call. */
    TF_DeleteImportGraphDefOptions(graph_opts);
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(tf_model->status) != TF_OK){
        TF_DeleteGraph(tf_model->graph);
        TF_DeleteStatus(tf_model->status);
        /* tf_model outlives this call; clear the fields so later cleanup or
         * a retry cannot touch the objects that were just deleted. */
        tf_model->graph = NULL;
        tf_model->status = NULL;
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}