Ejemplo n.º 1
0
/*
 * Unit test for StringSafeEqual(), a NULL-tolerant string equality check:
 * two NULLs compare equal, NULL never equals a non-NULL string (not even ""),
 * and two non-NULL strings compare by content.
 */
static void test_safe_equal(void)
{
    /* NULL handling */
    assert_true(StringSafeEqual(NULL, NULL));
    assert_false(StringSafeEqual("a", NULL));
    assert_false(StringSafeEqual(NULL, "a"));
    /* An empty string is still a non-NULL string. */
    assert_false(StringSafeEqual("", NULL));
    assert_false(StringSafeEqual(NULL, ""));

    /* Content comparison */
    assert_false(StringSafeEqual("a", "b"));
    assert_true(StringSafeEqual("a", "a"));
    assert_true(StringSafeEqual("", ""));
    /* A prefix must not compare equal to the longer string, in either order. */
    assert_false(StringSafeEqual("a", "ab"));
    assert_false(StringSafeEqual("ab", "a"));

    /* Distinct buffers with identical contents compare equal. */
    char buf[] = "a";
    assert_true(StringSafeEqual(buf, "a"));
}
Ejemplo n.º 2
0
/*
 * Check the constraints of a methods promise for bundle-call arity errors.
 *
 * For every "usebundle" constraint whose value is a function call, the called
 * bundle is looked up by name with a NULL namespace: first among "agent"
 * bundles, then among "common" bundles. If it is found and the number of call
 * arguments differs from the number of bundle parameters, a
 * POLICY_ERROR_METHODS_BUNDLE_ARITY error is appended to `errors`.
 * Callees that cannot be resolved are not reported by this check.
 *
 * @return true if no arity mismatch was found, false otherwise
 */
static bool MethodsParseTreeCheck(const Promise *pp, Seq *errors)
{
    bool all_match = true;

    for (size_t idx = 0; idx < SeqLength(pp->conlist); idx++)
    {
        const Constraint *constraint = SeqAt(pp->conlist, idx);

        // Only "usebundle => bundle(args)" style calls are checked.
        if (!StringSafeEqual(constraint->lval, "usebundle"))
        {
            continue;
        }
        if (constraint->rval.type != RVAL_TYPE_FNCALL)
        {
            continue;
        }

        const FnCall *invocation = (const FnCall *) constraint->rval.item;

        const Bundle *target = PolicyGetBundle(PolicyFromPromise(pp), NULL, "agent", invocation->name);
        if (target == NULL)
        {
            target = PolicyGetBundle(PolicyFromPromise(pp), NULL, "common", invocation->name);
        }
        if (target == NULL)
        {
            continue; // unresolved callee: nothing to compare against
        }

        if (RlistLen(invocation->args) == RlistLen(target->args))
        {
            continue;
        }

        SeqAppend(errors, PolicyErrorNew(POLICY_ELEMENT_TYPE_CONSTRAINT, constraint,
                                         POLICY_ERROR_METHODS_BUNDLE_ARITY,
                                         invocation->name, RlistLen(target->args), RlistLen(invocation->args)));
        all_match = false;
    }

    return all_match;
}
Ejemplo n.º 3
0
/**
 * Resolve "__main__" bundles against the evaluation entry point.
 *
 * The entry point (EvalContextGetEntryPoint) and the source_path of each
 * "__main__" bundle are canonicalized with GetRealPath() and compared:
 * - a "__main__" bundle defined in the entry-point file is renamed in place
 *   to "main";
 * - every other "__main__" bundle is removed from policy->bundles (set to
 *   NULL with SeqSet(), which runs the sequence's destroy function, then
 *   compacted with SeqRemoveNulls()).
 * Nothing is changed if the entry point does not resolve to a non-empty path.
 */
static void RenameMainBundle(EvalContext *ctx, Policy *policy)
{
    assert(policy != NULL);
    assert(ctx != NULL);
    assert(policy->bundles != NULL);
    char *const entry_point = GetRealPath(EvalContextGetEntryPoint(ctx));
    if (NULL_OR_EMPTY(entry_point))
    {
        free(entry_point);
        return;
    }
    Seq *bundles = policy->bundles;
    const size_t length = SeqLength(bundles);
    bool removed = false;
    for (size_t i = 0; i < length; ++i)
    {
        Bundle *const bundle = SeqAt(bundles, i);
        if (StringSafeEqual(bundle->name, "__main__"))
        {
            // GetRealPath() may return NULL (see the entry_point check above);
            // a NULL abspath never equals entry_point, so it is dropped below.
            char *abspath = GetRealPath(bundle->source_path);
            if (StringSafeEqual(abspath, entry_point))
            {
                Log(LOG_LEVEL_VERBOSE,
                    "Redefining __main__ bundle from file %s to be main",
                    abspath);
                // "__main__" is always big enough for "main" plus its NUL
                memcpy(bundle->name, "main", sizeof("main"));
            }
            else
            {
                // Never pass NULL for %s: fall back to the unresolved path.
                const char *shown_path =
                    (abspath != NULL) ? abspath : bundle->source_path;
                Log(LOG_LEVEL_VERBOSE,
                    "Dropping __main__ bundle from file %s (entry point: %s)",
                    (shown_path != NULL) ? shown_path : "(unknown)",
                    entry_point);
                removed = true;
                SeqSet(bundles, i, NULL); // SeqSet calls destroy function
            }
            free(abspath);
        }
    }
    if (removed)
    {
        SeqRemoveNulls(bundles);
    }
    free(entry_point);
}
Ejemplo n.º 4
0
/*
 * Check the constraints of a methods promise for bundle-call arity errors.
 *
 * For every "usebundle" constraint whose value is a function call, the call
 * name is parsed into a namespace/name pair (qualified with the promise's
 * namespace when it has none) and the bundle is looked up, first among
 * "agent" bundles, then among "common" bundles. If it is found and the number
 * of call arguments differs from the number of bundle parameters, a
 * POLICY_ERROR_METHODS_BUNDLE_ARITY error is appended to `errors`.
 * Callees that cannot be resolved are not reported by this check.
 *
 * @return true if no arity mismatch was found, false otherwise
 */
static bool MethodsParseTreeCheck(const Promise *pp, Seq *errors)
{
    bool all_match = true;

    for (size_t idx = 0; idx < SeqLength(pp->conlist); idx++)
    {
        const Constraint *constraint = SeqAt(pp->conlist, idx);

        // Only "usebundle => bundle(args)" style calls are checked.
        if (!StringSafeEqual(constraint->lval, "usebundle") ||
            constraint->rval.type != RVAL_TYPE_FNCALL)
        {
            continue;
        }

        FnCall *invocation = RvalFnCallValue(constraint->rval);

        // HACK: call references look like class references, so the ClassRef
        // parser is reused to obtain the (ns, name) pair of the callee.
        ClassRef callee_ref = ClassRefParse(invocation->name);
        if (!ClassRefIsQualified(callee_ref))
        {
            ClassRefQualify(&callee_ref, PromiseGetNamespace(pp));
        }

        const Bundle *target = PolicyGetBundle(PolicyFromPromise(pp), callee_ref.ns, "agent", callee_ref.name);
        if (target == NULL)
        {
            target = PolicyGetBundle(PolicyFromPromise(pp), callee_ref.ns, "common", callee_ref.name);
        }

        // The reference is only needed for the lookup above.
        ClassRefDestroy(callee_ref);

        if (target != NULL && RlistLen(invocation->args) != RlistLen(target->args))
        {
            SeqAppend(errors, PolicyErrorNew(POLICY_ELEMENT_TYPE_CONSTRAINT, constraint,
                                             POLICY_ERROR_METHODS_BUNDLE_ARITY,
                                             invocation->name, RlistLen(target->args), RlistLen(invocation->args)));
            all_match = false;
        }
    }

    return all_match;
}
Ejemplo n.º 5
0
/**
 * Thread body that runs one cf-testd server instance.
 *
 * Steps, in order:
 *  1. Installs a logging context whose hook (LogAddPrefix) receives
 *     config->address as its parameter; the previous context is restored
 *     before returning.
 *  2. Loads the key pair. When config->key_file is set, the public key path
 *     is derived from it by replacing "priv" with "pub".
 *  3. If config->report_file is set, reads it into config->report_data and
 *     splits it in place: the first line (timestamp header) and every line
 *     found by CFTestD_GetReportLine() are NUL-terminated inside that buffer
 *     and appended to config->report. config->report_len is the summed length
 *     of the report lines, excluding the header. Any problem with the file
 *     terminates the process with EXIT_FAILURE.
 *  4. Runs CFTestD_StartServer(), stores its result in config->ret, then
 *     frees the report buffer and tears down the TLS state.
 *
 * @param config_arg the CFTestD_Config * for this instance (typed void * so
 *                   the function has the void *(*)(void *) start-routine shape)
 * @return always NULL; the server's result is reported via config->ret
 */
static void *CFTestD_ServeReport(void *config_arg)
{
    CFTestD_Config *config = (CFTestD_Config *) config_arg;

    /* Set prefix for all Log()ging: */
    LoggingPrivContext *prior = LoggingPrivGetContext();
    LoggingPrivContext log_ctx = {
        .log_hook = LogAddPrefix,
        .param = config->address
    };
    LoggingPrivSetContext(&log_ctx);

    char *priv_key_path = NULL;
    char *pub_key_path = NULL;
    if (config->key_file != NULL)
    {
        priv_key_path = config->key_file;
        pub_key_path = xstrdup(priv_key_path);
        /* "priv" -> "pub" only shortens the string, so the duplicated buffer
         * (strlen + 1 bytes) is always large enough. */
        StringReplace(pub_key_path, strlen(pub_key_path) + 1,
                      "priv", "pub");
    }

    LoadSecretKeys(priv_key_path, pub_key_path, &(config->priv_key), &(config->pub_key));
    free(pub_key_path);

    char *report_file = config->report_file;

    if (report_file != NULL)
    {
        Log(LOG_LEVEL_NOTICE, "Got file argument: '%s'", report_file);
        if (!FileCanOpen(report_file, "r"))
        {
            Log(LOG_LEVEL_ERR,
                "Can't open file '%s' for reading",
                report_file);
            exit(EXIT_FAILURE);
        }

        Writer *contents = FileRead(report_file, SIZE_MAX, NULL);
        if (!contents)
        {
            Log(LOG_LEVEL_ERR, "Error reading report file '%s'", report_file);
            exit(EXIT_FAILURE);
        }

        size_t report_data_len = StringWriterLength(contents);
        config->report_data = StringWriterClose(contents);

        if (report_data_len == 0)
        {
            /* Nothing to tokenize; reject before touching the buffer. */
            Log(LOG_LEVEL_ERR, "Report file contained no bytes");
            exit(EXIT_FAILURE);
        }

        /* One past the last data byte, i.e. the buffer's terminating NUL.
         * Used to avoid stepping past the end of the buffer below. */
        char *const data_end = config->report_data + report_data_len;

        Seq *report = SeqNew(64, NULL);
        size_t report_len = 0;

        StringRef ts_ref = StringGetToken(config->report_data, report_data_len, 0, "\n");
        char *ts = (char *) ts_ref.data;
        char *ts_end = ts + ts_ref.len;
        *ts_end = '\0';
        SeqAppend(report, ts);

        /* Start right after the newline after the timestamp header; if the
         * header runs to the end of the data there is no such newline, so
         * start at the terminating NUL instead of one byte past the buffer. */
        char *position = (ts_end < data_end) ? ts_end + 1 : ts_end;
        char *report_line;
        size_t report_line_len;
        while (CFTestD_GetReportLine(position, &report_line, &report_line_len))
        {
            char *line_end = report_line + report_line_len;
            *line_end = '\0';
            SeqAppend(report, report_line);
            report_len += report_line_len;
            /* there's an extra newline after each report_line; skip it unless
             * this line ends exactly at the end of the data */
            position = (line_end < data_end) ? line_end + 1 : line_end;
        }

        config->report = report;
        config->report_len = report_len;

        /* report_len is a size_t, so it must be printed with %zu */
        Log(LOG_LEVEL_NOTICE,
            "Read %zu bytes for report contents",
            report_len);

        if (report_len == 0)
        {
            Log(LOG_LEVEL_ERR, "Report file contained no bytes");
            exit(EXIT_FAILURE);
        }
    }

    Log(LOG_LEVEL_INFO, "Starting server at %s...", config->address);
    fflush(stdout); // for debugging startup

    config->ret = CFTestD_StartServer(config);

    free(config->report_data);
    /* Don't leave a dangling pointer in the config. The strings stored in
     * config->report point into this buffer and are invalid from here on. */
    config->report_data = NULL;

    /* we don't really need to do this here because the process is about to
     * terminate, but it's a good way to check the cleanup actually works and
     * doesn't cause a segfault or something */
    ServerTLSDeInitialize(&(config->priv_key), &(config->pub_key), &(config->ssl_ctx));

    LoggingPrivSetContext(prior);

    return NULL;
}

/*
 * Signal handler: on SIGTERM or SIGINT, flushes stdout, prints
 * "Terminating..." to stderr and sets the TERMINATE flag. Every other
 * signal is ignored.
 *
 * NOTE(review): fflush() and fprintf() are not async-signal-safe; if the
 * signal arrives while the interrupted code holds a stdio lock this can
 * deadlock. Consider write(STDERR_FILENO, ...) and, if it is not already,
 * declaring TERMINATE as volatile sig_atomic_t.
 */
static void HandleSignal(int signum)
{
    const bool is_termination_request = (signum == SIGTERM) || (signum == SIGINT);
    if (!is_termination_request)
    {
        return;
    }

    // flush all logging before process ends.
    fflush(stdout);
    fprintf(stderr, "Terminating...\n");
    TERMINATE = true;
}

/**
 * Compute the IPv4 address that follows ip_str, skipping results whose last
 * octet would be 0 (network address) or 255 (broadcast address).
 *
 * @param ip_str string representation of an IPv4 address in any form accepted
 *               by inet_addr() (usually 4 decimal octets separated by dots)
 * @return a new string representing the incremented IP address (HAS TO BE
 *         FREED), or NULL if ip_str cannot be parsed
 * @note inet_addr() signals failure with INADDR_NONE, which is also the value
 *       of "255.255.255.255", so that address is rejected as unparsable.
 */
static char *IncrementIPaddress(const char *ip_str)
{
    const uint32_t ip = (uint32_t) inet_addr(ip_str);
    if (ip == INADDR_NONE)
    {
        Log(LOG_LEVEL_ERR, "Failed to parse address: '%s'", ip_str);
        return NULL;
    }

    uint32_t ip_num = ntohl(ip);

    /* Take the last octet from the parsed value rather than from the text
     * after the last '.': inet_addr() also accepts forms without dots (where
     * strrchr() would return NULL) and non-decimal spellings of an octet. */
    const uint32_t last_octet = ip_num & 0xFFu;
    uint32_t step = 1;
    if (last_octet == 255)
    {
        /* avoid the network address (ending with 0) */
        step = 2;
    }
    else if (last_octet == 254)
    {
        /* avoid the broadcast address and the network address */
        step = 3;
    }

    ip_num += step; /* unsigned arithmetic: wraps modulo 2^32 at the top */

    struct in_addr ip_struct;
    ip_struct.s_addr = htonl(ip_num);

    return xstrdup(inet_ntoa(ip_struct));
}
Ejemplo n.º 6
0
/**
 * StringSafeEqual() with untyped (const void *) arguments.
 *
 * Both pointers are interpreted as C strings (or NULL) and forwarded
 * unchanged, so the result is exactly StringSafeEqual(a, b). Presumably this
 * exists to serve as a generic equality callback — TODO confirm at call sites.
 */
bool StringSafeEqual_untyped(const void *a, const void *b)
{
    const char *const lhs = a;
    const char *const rhs = b;
    return StringSafeEqual(lhs, rhs);
}