 * load_shard_id_array returns the shard identifiers for a particular
 * distributed table as a bigint array. Uses pg_shard's shard interval
 * cache if the second parameter is true, otherwise eagerly loads the
 * shard intervals from the backing table.
	Oid distributedTableId = PG_GETARG_OID(0);
	bool useCache = PG_GETARG_BOOL(1);
	ArrayType *shardIdArrayType = NULL;
	ListCell *shardCell = NULL;
	int shardIdIndex = 0;
	Oid shardIdTypeId = INT8OID;

	List *shardList = NIL;
	int shardIdCount = -1;
	Datum *shardIdDatumArray = NULL;

	if (useCache)
		shardList = LookupShardIntervalList(distributedTableId);
		shardList = LoadShardIntervalList(distributedTableId);

	shardIdCount = list_length(shardList);
	shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum));

	foreach(shardCell, shardList)
		ShardInterval *shardId = (ShardInterval *) lfirst(shardCell);
		Datum shardIdDatum = Int64GetDatum(shardId->id);

		shardIdDatumArray[shardIdIndex] = shardIdDatum;
Exemple #2
 * PrunedShardIdsForTable loads the shard intervals for the specified table,
 * prunes them using the provided clauses. It returns an ArrayType containing
 * the shard identifiers, suitable for return from an SQL-facing function.
static ArrayType *
PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList)
	ArrayType *shardIdArrayType = NULL;
	ListCell *shardCell = NULL;
	int shardIdIndex = 0;
	Oid shardIdTypeId = INT8OID;

	List *shardList = LoadShardIntervalList(distributedTableId);
	int shardIdCount = -1;
	Datum *shardIdDatumArray = NULL;

	shardList = PruneShardList(distributedTableId, whereClauseList, shardList);

	shardIdCount = list_length(shardList);
	shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum));

	foreach(shardCell, shardList)
		ShardInterval *shardId = (ShardInterval *) lfirst(shardCell);
		Datum shardIdDatum = Int64GetDatum(shardId->id);

		shardIdDatumArray[shardIdIndex] = shardIdDatum;
 * TableShardReplicationFactor returns the current replication factor of the
 * given relation by looking into shard placements. It errors out if there
 * are different numbers of shard placements for different shards. It also
 * errors out if the table does not have any shards.
TableShardReplicationFactor(Oid relationId)
    uint32 replicationCount = 0;
    ListCell *shardCell = NULL;

    List *shardIntervalList = LoadShardIntervalList(relationId);
    foreach(shardCell, shardIntervalList)
        ShardInterval *shardInterval = (ShardInterval *) lfirst(shardCell);
        uint64 shardId = shardInterval->shardId;

        List *shardPlacementList = ShardPlacementList(shardId);
        uint32 shardPlacementCount = list_length(shardPlacementList);

         * Get the replication count of the first shard in the list, and error
         * out if there is a shard with different replication count.
        if (replicationCount == 0)
            replicationCount = shardPlacementCount;
        else if (replicationCount != shardPlacementCount)
            char *relationName = get_rel_name(relationId);
            ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                            errmsg("cannot find the replication factor of the "
                                   "table %s", relationName),
                            errdetail("The shard %ld has different shards replication "
                                      "counts from other shards.", shardId)));
 * master_drop_all_shards attempts to drop all shards for a given relation.
 * Unlike master_apply_delete_command, this function can be called even
 * if the table has already been dropped.
	Oid relationId = PG_GETARG_OID(0);
	text *schemaNameText = PG_GETARG_TEXT_P(1);
	text *relationNameText = PG_GETARG_TEXT_P(2);

	List *shardIntervalList = NIL;
	int droppedShardCount = 0;

	char *schemaName = text_to_cstring(schemaNameText);
	char *relationName = text_to_cstring(relationNameText);


	 * The SQL_DROP trigger calls this function even for tables that are
	 * not distributed. In that case, silently ignore and return -1.
	if (!IsDistributedTable(relationId) || !EnableDDLPropagation)

	CheckTableSchemaNameForDrop(relationId, &schemaName, &relationName);

	 * master_drop_all_shards is typically called from the DROP TABLE trigger,
	 * but could be called by a user directly. Make sure we have an
	 * AccessExclusiveLock to prevent any other commands from running on this table
	 * concurrently.
	LockRelationOid(relationId, AccessExclusiveLock);

	shardIntervalList = LoadShardIntervalList(relationId);
	droppedShardCount = DropShards(relationId, schemaName, relationName,

 * CreateColocatedShards creates shards for the target relation colocated with
 * the source relation.
CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
	char targetShardStorageType = 0;
	List *existingShardList = NIL;
	List *sourceShardIntervalList = NIL;
	ListCell *sourceShardCell = NULL;
	bool colocatedShard = true;
	List *insertedShardPlacements = NIL;

	/* make sure that tables are hash partitioned */

	 * In contrast to append/range partitioned tables it makes more sense to
	 * require ownership privileges - shards for hash-partitioned tables are
	 * only created once, not continually during ingest as for the other
	 * partitioning types.

	/* we plan to add shards: get an exclusive lock on target relation oid */
	LockRelationOid(targetRelationId, ExclusiveLock);

	/* we don't want source table to get dropped before we colocate with it */
	LockRelationOid(sourceRelationId, AccessShareLock);

	/* prevent placement changes of the source relation until we colocate with them */
	sourceShardIntervalList = LoadShardIntervalList(sourceRelationId);
	LockShardListMetadata(sourceShardIntervalList, ShareLock);

	/* validate that shards haven't already been created for this table */
	existingShardList = LoadShardList(targetRelationId);
	if (existingShardList != NIL)
		char *targetRelationName = get_rel_name(targetRelationId);
						errmsg("table \"%s\" has already had shards created for it",

	targetShardStorageType = ShardStorageType(targetRelationId);

	foreach(sourceShardCell, sourceShardIntervalList)
		ShardInterval *sourceShardInterval = (ShardInterval *) lfirst(sourceShardCell);
		uint64 sourceShardId = sourceShardInterval->shardId;
		uint64 newShardId = GetNextShardId();
		ListCell *sourceShardPlacementCell = NULL;

		int32 shardMinValue = DatumGetInt32(sourceShardInterval->minValue);
		int32 shardMaxValue = DatumGetInt32(sourceShardInterval->maxValue);
		text *shardMinValueText = IntegerToText(shardMinValue);
		text *shardMaxValueText = IntegerToText(shardMaxValue);
		List *sourceShardPlacementList = ShardPlacementList(sourceShardId);

		InsertShardRow(targetRelationId, newShardId, targetShardStorageType,
					   shardMinValueText, shardMaxValueText);

		foreach(sourceShardPlacementCell, sourceShardPlacementList)
			ShardPlacement *sourcePlacement =
				(ShardPlacement *) lfirst(sourceShardPlacementCell);
			uint32 groupId = sourcePlacement->groupId;
			const RelayFileState shardState = FILE_FINALIZED;
			const uint64 shardSize = 0;
			uint64 shardPlacementId = 0;
			ShardPlacement *shardPlacement = NULL;

			 * Optimistically add the shard placement row to pg_dist_shard_placement; in
			 * case of any error it will be rolled back.
			shardPlacementId = InsertShardPlacementRow(newShardId, INVALID_PLACEMENT_ID,
													   shardState, shardSize, groupId);

			shardPlacement = LoadShardPlacement(newShardId, shardPlacementId);
			insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement);
Exemple #6
 * master_create_worker_shards creates empty shards for the given table based
 * on the specified number of initial shards. The function first gets a list of
 * candidate nodes and issues DDL commands on the nodes to create empty shard
 * placements on those nodes. The function then updates metadata on the master
 * node to make this shard (and its placements) visible. Note that the function
 * assumes the table is hash partitioned and calculates the min/max hash token
 * ranges for each shard, giving them an equal split of the hash space.
	text *tableNameText = PG_GETARG_TEXT_P(0);
	int32 shardCount = PG_GETARG_INT32(1);
	int32 replicationFactor = PG_GETARG_INT32(2);

	Oid distributedTableId = ResolveRelationId(tableNameText);
	char relationKind = get_rel_relkind(distributedTableId);
	char *tableName = text_to_cstring(tableNameText);
	char shardStorageType = '\0';
	int32 shardIndex = 0;
	List *workerNodeList = NIL;
	List *ddlCommandList = NIL;
	int32 workerNodeCount = 0;
	uint32 placementAttemptCount = 0;
	uint32 hashTokenIncrement = 0;
	List *existingShardList = NIL;

	/* make sure table is hash partitioned */

	/* validate that shards haven't already been created for this table */
	existingShardList = LoadShardIntervalList(distributedTableId);
	if (existingShardList != NIL)
						errmsg("table \"%s\" has already had shards created for it",

	/* make sure that at least one shard is specified */
	if (shardCount <= 0)
						errmsg("shardCount must be positive")));

	/* make sure that at least one replica is specified */
	if (replicationFactor <= 0)
						errmsg("replicationFactor must be positive")));

	/* calculate the split of the hash space */
	hashTokenIncrement = UINT_MAX / shardCount;

	/* load and sort the worker node list for deterministic placement */
	workerNodeList = ParseWorkerNodeFile(WORKER_LIST_FILENAME);
	workerNodeList = SortList(workerNodeList, CompareWorkerNodes);

	/* make sure we don't process cancel signals until all shards are created */

	/* retrieve the DDL commands for the table */
	ddlCommandList = TableDDLCommandList(distributedTableId);

	workerNodeCount = list_length(workerNodeList);
	if (replicationFactor > workerNodeCount)
						errmsg("replicationFactor (%d) exceeds number of worker nodes "
							   "(%d)", replicationFactor, workerNodeCount),
						errhint("Add more worker nodes or try again with a lower "
								"replication factor.")));

	/* if we have enough nodes, add an extra placement attempt for backup */
	placementAttemptCount = (uint32) replicationFactor;
	if (workerNodeCount > replicationFactor)

	/* set shard storage type according to relation type */
	if (relationKind == RELKIND_FOREIGN_TABLE)
		shardStorageType = SHARD_STORAGE_FOREIGN;
		shardStorageType = SHARD_STORAGE_TABLE;

	for (shardIndex = 0; shardIndex < shardCount; shardIndex++)
		uint64 shardId = NextSequenceId(SHARD_ID_SEQUENCE_NAME);
		int32 placementCount = 0;
		uint32 placementIndex = 0;
		uint32 roundRobinNodeIndex = shardIndex % workerNodeCount;

		List *extendedDDLCommands = ExtendedDDLCommandList(distributedTableId, shardId,

		/* initialize the hash token space for this shard */
		text *minHashTokenText = NULL;
		text *maxHashTokenText = NULL;
		int32 shardMinHashToken = INT_MIN + (shardIndex * hashTokenIncrement);
		int32 shardMaxHashToken = shardMinHashToken + hashTokenIncrement - 1;

		/* if we are at the last shard, make sure the max token value is INT_MAX */
		if (shardIndex == (shardCount - 1))
			shardMaxHashToken = INT_MAX;

		for (placementIndex = 0; placementIndex < placementAttemptCount; placementIndex++)
			int32 candidateNodeIndex =
				(roundRobinNodeIndex + placementIndex) % workerNodeCount;
			WorkerNode *candidateNode = (WorkerNode *) list_nth(workerNodeList,
			char *nodeName = candidateNode->nodeName;
			uint32 nodePort = candidateNode->nodePort;

			bool created = ExecuteRemoteCommandList(nodeName, nodePort,
			if (created)
				uint64 shardPlacementId = 0;
				ShardState shardState = STATE_FINALIZED;

				shardPlacementId = NextSequenceId(SHARD_PLACEMENT_ID_SEQUENCE_NAME);
				InsertShardPlacementRow(shardPlacementId, shardId, shardState,
										nodeName, nodePort);
				ereport(WARNING, (errmsg("could not create shard on \"%s:%u\"",
										 nodeName, nodePort)));

			if (placementCount >= replicationFactor)

		/* check if we created enough shard replicas */
		if (placementCount < replicationFactor)
			ereport(ERROR, (errmsg("could not satisfy specified replication factor"),
							errdetail("Created %d shard replicas, less than the "
									  "requested replication factor of %d.",
									  placementCount, replicationFactor)));

		/* insert the shard metadata row along with its min/max values */
		minHashTokenText = IntegerToText(shardMinHashToken);
		maxHashTokenText = IntegerToText(shardMaxHashToken);
		InsertShardRow(distributedTableId, shardId, shardStorageType,
					   minHashTokenText, maxHashTokenText);

	if (QueryCancelPending)
		ereport(WARNING, (errmsg("cancel requests are ignored during shard creation")));
		QueryCancelPending = false;


 * master_apply_delete_command takes in a delete command, finds shards that
 * match the criteria defined in the delete command, drops the found shards from
 * the worker nodes, and updates the corresponding metadata on the master node.
 * This function drops a shard if and only if all rows in the shard satisfy
 * the conditions in the delete command. Note that this function only accepts
 * conditions on the partition key and if no condition is provided then all
 * shards are deleted.
 * We mark shard placements that we couldn't drop as to be deleted later. If a
 * shard satisfies the given conditions, we delete it from shard metadata table
 * even though related shard placements are not deleted.
	text *queryText = PG_GETARG_TEXT_P(0);
	char *queryString = text_to_cstring(queryText);
	char *relationName = NULL;
	char *schemaName = NULL;
	Oid relationId = InvalidOid;
	List *shardIntervalList = NIL;
	List *deletableShardIntervalList = NIL;
	List *queryTreeList = NIL;
	Query *deleteQuery = NULL;
	Node *whereClause = NULL;
	Node *deleteCriteria = NULL;
	Node *queryTreeNode = NULL;
	DeleteStmt *deleteStatement = NULL;
	int droppedShardCount = 0;
	LOCKMODE lockMode = 0;
	char partitionMethod = 0;
	bool failOK = false;
#if (PG_VERSION_NUM >= 100000)
	RawStmt *rawStmt = (RawStmt *) ParseTreeRawStmt(queryString);
	queryTreeNode = rawStmt->stmt;
	queryTreeNode = ParseTreeNode(queryString);


	if (!IsA(queryTreeNode, DeleteStmt))
		ereport(ERROR, (errmsg("query \"%s\" is not a delete statement",

	deleteStatement = (DeleteStmt *) queryTreeNode;

	schemaName = deleteStatement->relation->schemaname;
	relationName = deleteStatement->relation->relname;

	 * We take an exclusive lock while dropping shards to prevent concurrent
	 * writes. We don't want to block SELECTs, which means queries might fail
	 * if they access a shard that has just been dropped.
	lockMode = ExclusiveLock;

	relationId = RangeVarGetRelid(deleteStatement->relation, lockMode, failOK);

	/* schema-prefix if it is not specified already */
	if (schemaName == NULL)
		Oid schemaId = get_rel_namespace(relationId);
		schemaName = get_namespace_name(schemaId);

	EnsureTablePermissions(relationId, ACL_DELETE);

#if (PG_VERSION_NUM >= 100000)
	queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0, NULL);
	queryTreeList = pg_analyze_and_rewrite(queryTreeNode, queryString, NULL, 0);
	deleteQuery = (Query *) linitial(queryTreeList);

	/* get where clause and flatten it */
	whereClause = (Node *) deleteQuery->jointree->quals;
	deleteCriteria = eval_const_expressions(NULL, whereClause);

	partitionMethod = PartitionMethod(relationId);
	if (partitionMethod == DISTRIBUTE_BY_HASH)
						errmsg("cannot delete from hash distributed table with this "
						errdetail("Delete statements on hash-partitioned tables "
								  "are not supported with master_apply_delete_command."),
						errhint("Use master_modify_multiple_shards command instead.")));
	else if (partitionMethod == DISTRIBUTE_BY_NONE)
						errmsg("cannot delete from distributed table"),
						errdetail("Delete statements on reference tables "
								  "are not supported.")));

	CheckPartitionColumn(relationId, deleteCriteria);

	shardIntervalList = LoadShardIntervalList(relationId);

	/* drop all shards if where clause is not present */
	if (deleteCriteria == NULL)
		deletableShardIntervalList = shardIntervalList;
		ereport(DEBUG2, (errmsg("dropping all shards for \"%s\"", relationName)));
		deletableShardIntervalList = ShardsMatchingDeleteCriteria(relationId,

	droppedShardCount = DropShards(relationId, schemaName, relationName,

Exemple #8
 * CitusCopyFrom implements the COPY table_name FROM ... for hash-partitioned
 * and range-partitioned tables.
CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
	Oid tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
	char *relationName = get_rel_name(tableId);
	Relation distributedRelation = NULL;
	char partitionMethod = '\0';
	Var *partitionColumn = NULL;
	TupleDesc tupleDescriptor = NULL;
	uint32 columnCount = 0;
	Datum *columnValues = NULL;
	bool *columnNulls = NULL;
	TypeCacheEntry *typeEntry = NULL;
	FmgrInfo *hashFunction = NULL;
	FmgrInfo *compareFunction = NULL;

	int shardCount = 0;
	List *shardIntervalList = NULL;
	ShardInterval **shardIntervalCache = NULL;
	bool useBinarySearch = false;

	HTAB *shardConnectionHash = NULL;
	ShardConnections *shardConnections = NULL;
	List *connectionList = NIL;

	EState *executorState = NULL;
	MemoryContext executorTupleContext = NULL;
	ExprContext *executorExpressionContext = NULL;

	CopyState copyState = NULL;
	CopyOutState copyOutState = NULL;
	FmgrInfo *columnOutputFunctions = NULL;
	uint64 processedRowCount = 0;

	/* disallow COPY to/from file or program except for superusers */
	if (copyStatement->filename != NULL && !superuser())
		if (copyStatement->is_program)
					 errmsg("must be superuser to COPY to or from an external program"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));
					 errmsg("must be superuser to COPY to or from a file"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));

	partitionColumn = PartitionColumn(tableId, 0);
	partitionMethod = PartitionMethod(tableId);
	if (partitionMethod != DISTRIBUTE_BY_RANGE && partitionMethod != DISTRIBUTE_BY_HASH)
						errmsg("COPY is only supported for hash- and "
							   "range-partitioned tables")));

	/* resolve hash function for partition column */
	typeEntry = lookup_type_cache(partitionColumn->vartype, TYPECACHE_HASH_PROC_FINFO);
	hashFunction = &(typeEntry->hash_proc_finfo);

	/* resolve compare function for shard intervals */
	compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod);

	/* allocate column values and nulls arrays */
	distributedRelation = heap_open(tableId, RowExclusiveLock);
	tupleDescriptor = RelationGetDescr(distributedRelation);
	columnCount = tupleDescriptor->natts;
	columnValues = palloc0(columnCount * sizeof(Datum));
	columnNulls = palloc0(columnCount * sizeof(bool));

	/* load the list of shards and verify that we have shards to copy into */
	shardIntervalList = LoadShardIntervalList(tableId);
	if (shardIntervalList == NIL)
		if (partitionMethod == DISTRIBUTE_BY_HASH)
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",
							errhint("Run master_create_worker_shards to create shards "
									"and try again.")));
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",

	/* prevent concurrent placement changes and non-commutative DML statements */

	/* initialize the shard interval cache */
	shardCount = list_length(shardIntervalList);
	shardIntervalCache = SortedShardIntervalArray(shardIntervalList);

	/* determine whether to use binary search */
	if (partitionMethod != DISTRIBUTE_BY_HASH ||
		!IsUniformHashDistribution(shardIntervalCache, shardCount))
		useBinarySearch = true;

	/* initialize copy state to read from COPY data source */
	copyState = BeginCopyFrom(distributedRelation,

	executorState = CreateExecutorState();
	executorTupleContext = GetPerTupleMemoryContext(executorState);
	executorExpressionContext = GetPerTupleExprContext(executorState);

	copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData));
	copyOutState->binary = true;
	copyOutState->fe_msgbuf = makeStringInfo();
	copyOutState->rowcontext = executorTupleContext;

	columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor, copyOutState->binary);

	 * Create a mapping of shard id to a connection for each of its placements.
	 * The hash should be initialized before the PG_TRY, since it is used in
	 * PG_CATCH. Otherwise, it may be undefined in the PG_CATCH (see sigsetjmp
	 * documentation).
	shardConnectionHash = CreateShardConnectionHash();

	/* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */
		ErrorContextCallback errorCallback;

		/* set up callback to identify error line number */
		errorCallback.callback = CopyFromErrorCallback;
		errorCallback.arg = (void *) copyState;
		errorCallback.previous = error_context_stack;
		error_context_stack = &errorCallback;

		/* ensure transactions have unique names on worker nodes */

		while (true)
			bool nextRowFound = false;
			Datum partitionColumnValue = 0;
			ShardInterval *shardInterval = NULL;
			int64 shardId = 0;
			bool shardConnectionsFound = false;
			MemoryContext oldContext = NULL;


			oldContext = MemoryContextSwitchTo(executorTupleContext);

			/* parse a row from the input */
			nextRowFound = NextCopyFrom(copyState, executorExpressionContext,
										columnValues, columnNulls, NULL);

			if (!nextRowFound)


			/* find the partition column value */

			if (columnNulls[partitionColumn->varattno - 1])
								errmsg("cannot copy row with NULL value "
									   "in partition column")));

			partitionColumnValue = columnValues[partitionColumn->varattno - 1];

			/* find the shard interval and id for the partition column value */
			shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache,
											  shardCount, partitionMethod,
											  compareFunction, hashFunction,
			if (shardInterval == NULL)
								errmsg("could not find shard for partition column "

			shardId = shardInterval->shardId;


			/* get existing connections to the shard placements, if any */
			shardConnections = GetShardConnections(shardConnectionHash,
			if (!shardConnectionsFound)
				/* open connections and initiate COPY on shard placements */
				OpenCopyTransactions(copyStatement, shardConnections);

				/* send binary headers to shard placements */

			/* replicate row to shard placements */
			AppendCopyRowData(columnValues, columnNulls, tupleDescriptor,
							  copyOutState, columnOutputFunctions);
			SendCopyDataToAll(copyOutState->fe_msgbuf, shardConnections->connectionList);

			processedRowCount += 1;

		connectionList = ConnectionList(shardConnectionHash);

		/* send binary footers to all shard placements */
		SendCopyDataToAll(copyOutState->fe_msgbuf, connectionList);

		/* all lines have been copied, stop showing line number in errors */
		error_context_stack = errorCallback.previous;

		/* close the COPY input on all shard placements */
		EndRemoteCopy(connectionList, true);

		if (CopyTransactionManager == TRANSACTION_MANAGER_2PC)

		heap_close(distributedRelation, NoLock);

		/* check for cancellation one last time before committing */
		List *abortConnectionList = NIL;

		/* roll back all transactions */
		abortConnectionList = ConnectionList(shardConnectionHash);
		EndRemoteCopy(abortConnectionList, false);


	 * Ready to commit the transaction, this code is below the PG_TRY block because
	 * we do not want any of the transactions rolled back if a failure occurs. Instead,
	 * they should be rolled forward.

	if (completionTag != NULL)
		snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
				 "COPY " UINT64_FORMAT, processedRowCount);