Example #1
0
/*
 * Scheme constructor for a TensorFlow graph object.
 *
 * The tf_graph_t wrapper is allocated zero-filled from Guile's GC heap and
 * boxed in a SMOB tagged tf_graph_tag first; only after the SMOB exists is a
 * fresh TF graph created and stored in the wrapper's `graph` field (the same
 * allocate / create-smob / finish-init order shown in the Guile manual).
 * Returns the new SMOB.
 */
SCM make_graph(void)
{
  SCM smob;
  struct tf_graph_t *wrapper =
    (struct tf_graph_t *)scm_gc_calloc(sizeof *wrapper, "make-graph");

  SCM_NEWSMOB(smob, tf_graph_tag, wrapper);
  wrapper->graph = TF_NewGraph();

  return smob;
}
DNNModel *ff_dnn_load_model_tf(const char *model_filename)
{
    DNNModel *model = NULL;
    TFModel *tf_model = NULL;
    TF_Buffer *graph_def;
    TF_ImportGraphDefOptions *graph_opts;

    model = av_malloc(sizeof(DNNModel));
    if (!model){
        return NULL;
    }

    tf_model = av_malloc(sizeof(TFModel));
    if (!tf_model){
        av_freep(&model);
        return NULL;
    }
    tf_model->session = NULL;
    tf_model->input_tensor = NULL;
    tf_model->output_data = NULL;

    graph_def = read_graph(model_filename);
    if (!graph_def){
        av_freep(&tf_model);
        av_freep(&model);
        return NULL;
    }
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();
    graph_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
    TF_DeleteImportGraphDefOptions(graph_opts);
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(tf_model->status) != TF_OK){
        TF_DeleteGraph(tf_model->graph);
        TF_DeleteStatus(tf_model->status);
        av_freep(&tf_model);
        av_freep(&model);
        return NULL;
    }

    model->model = (void *)tf_model;
    model->set_input_output = &set_input_output_tf;

    return model;
}
Example #3
0
/**
 * Read a serialized TensorFlow GraphDef from model_filename and import it
 * into a new graph stored in tf_model.
 *
 * On success tf_model->graph and tf_model->status hold a newly created graph
 * (containing the imported GraphDef) and status object.
 *
 * On failure:
 *  - if read_graph() fails, tf_model is not modified;
 *  - if the import fails, the graph and status created here are released and
 *    both fields are reset to NULL, so no dangling handles remain in the
 *    caller-owned struct for later cleanup code to free a second time.
 *
 * @return DNN_SUCCESS on success, DNN_ERROR otherwise.
 */
static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{
    TF_Buffer *graph_def;
    TF_ImportGraphDefOptions *graph_opts;

    graph_def = read_graph(model_filename);
    if (!graph_def){
        return DNN_ERROR;
    }
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();
    graph_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
    /* Options and serialized buffer are only needed for the import call. */
    TF_DeleteImportGraphDefOptions(graph_opts);
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(tf_model->status) != TF_OK){
        TF_DeleteGraph(tf_model->graph);
        TF_DeleteStatus(tf_model->status);
        /* Do not leave freed pointers behind in the caller's struct. */
        tf_model->graph = NULL;
        tf_model->status = NULL;
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}
/**
 * Build one of the built-in default models in a new TensorFlow graph and
 * wrap it in a DNNModel backed by the TensorFlow backend.
 *
 * Graph layout built here:
 *   Placeholder "x" (float, shape {1, -1, -1, 1})
 *     -> add_pad_op (6 for SRCNN, 4 for ESPCN)
 *     -> add_conv_layers (3 layers, constants/activations from the tables below)
 *     -> DepthToSpace "depth_to_space" with block_size 2   [ESPCN only]
 *     -> Identity "y"
 *
 * @param model_type DNN_SRCNN or DNN_ESPCN; any other value fails.
 * @return the new model, or NULL on allocation failure, an unknown
 *         model_type, or any failure while building the graph. On every
 *         failure path the graph, status and both allocations are released.
 */
DNNModel *ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
{
    DNNModel *model = NULL;
    TFModel *tf_model = NULL;
    TF_OperationDescription *op_desc;
    TF_Operation *op;
    TF_Output input;
    static const int64_t input_shape[] = {1, -1, -1, 1};
    /* TensorFlow op type names for the activations. Deliberately not named
     * tanh/sigmoid/relu: a local called "tanh" shadows the C library
     * function of the same name (-Wshadow). */
    static const char tanh_op[] = "Tanh";
    static const char sigmoid_op[] = "Sigmoid";
    static const char relu_op[] = "Relu";

    /* SRCNN: kernel/bias constants per layer, their dims and dim counts
     * (4 for kernels, 1 for biases), and one activation per layer. */
    static const float *srcnn_consts[] = {
        srcnn_conv1_kernel,
        srcnn_conv1_bias,
        srcnn_conv2_kernel,
        srcnn_conv2_bias,
        srcnn_conv3_kernel,
        srcnn_conv3_bias
    };
    static const long int *srcnn_consts_dims[] = {
        srcnn_conv1_kernel_dims,
        srcnn_conv1_bias_dims,
        srcnn_conv2_kernel_dims,
        srcnn_conv2_bias_dims,
        srcnn_conv3_kernel_dims,
        srcnn_conv3_bias_dims
    };
    static const int srcnn_consts_dims_len[] = {
        4,
        1,
        4,
        1,
        4,
        1
    };
    static const char *srcnn_activations[] = {
        relu_op,
        relu_op,
        relu_op
    };

    /* ESPCN: same layout as the SRCNN tables above. */
    static const float *espcn_consts[] = {
        espcn_conv1_kernel,
        espcn_conv1_bias,
        espcn_conv2_kernel,
        espcn_conv2_bias,
        espcn_conv3_kernel,
        espcn_conv3_bias
    };
    static const long int *espcn_consts_dims[] = {
        espcn_conv1_kernel_dims,
        espcn_conv1_bias_dims,
        espcn_conv2_kernel_dims,
        espcn_conv2_bias_dims,
        espcn_conv3_kernel_dims,
        espcn_conv3_bias_dims
    };
    static const int espcn_consts_dims_len[] = {
        4,
        1,
        4,
        1,
        4,
        1
    };
    static const char *espcn_activations[] = {
        tanh_op,
        tanh_op,
        sigmoid_op
    };

    input.index = 0;

    model = av_malloc(sizeof(DNNModel));
    if (!model){
        return NULL;
    }

    tf_model = av_malloc(sizeof(TFModel));
    if (!tf_model){
        av_freep(&model);
        return NULL;
    }
    tf_model->session = NULL;
    tf_model->input_tensor = NULL;
    tf_model->output_data = NULL;

    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();

    /* Graph input "x": float placeholder; -1 leaves a dimension unspecified. */
    op_desc = TF_NewOperation(tf_model->graph, "Placeholder", "x");
    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
    TF_SetAttrShape(op_desc, "shape", input_shape, 4);
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        goto fail;
    }

    switch (model_type){
    case DNN_SRCNN:
        op = add_pad_op(tf_model, op, 6);
        if (!op){
            goto fail;
        }
        op = add_conv_layers(tf_model, srcnn_consts,
                             srcnn_consts_dims, srcnn_consts_dims_len,
                             srcnn_activations, op, 3);
        if (!op){
            goto fail;
        }
        break;
    case DNN_ESPCN:
        op = add_pad_op(tf_model, op, 4);
        if (!op){
            goto fail;
        }
        op = add_conv_layers(tf_model, espcn_consts,
                             espcn_consts_dims, espcn_consts_dims_len,
                             espcn_activations, op, 3);
        if (!op){
            goto fail;
        }

        op_desc = TF_NewOperation(tf_model->graph, "DepthToSpace", "depth_to_space");
        input.oper = op;
        TF_AddInput(op_desc, input);
        TF_SetAttrType(op_desc, "T", TF_FLOAT);
        TF_SetAttrInt(op_desc, "block_size", 2);
        op = TF_FinishOperation(op_desc, tf_model->status);
        if (TF_GetCode(tf_model->status) != TF_OK){
            goto fail;
        }
        break;
    default:
        goto fail;
    }

    /* Graph output "y": an Identity over the last op built above. */
    op_desc = TF_NewOperation(tf_model->graph, "Identity", "y");
    input.oper = op;
    TF_AddInput(op_desc, input);
    TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        goto fail;
    }

    model->model = (void *)tf_model;
    model->set_input_output = &set_input_output_tf;

    return model;

fail:
    /* Single cleanup path for every failure after graph/status creation
     * (replaces a function-local macro that was never #undef'd). */
    TF_DeleteGraph(tf_model->graph);
    TF_DeleteStatus(tf_model->status);
    av_freep(&tf_model);
    av_freep(&model);
    return NULL;
}