コード例 #1
0
ファイル: multi_copy.c プロジェクト: cuulee/citus
/*
 * CitusCopyFrom implements the COPY table_name FROM ... for hash-partitioned
 * and range-partitioned tables.
 */
void
CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
{
	Oid tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
	char *relationName = get_rel_name(tableId);
	Relation distributedRelation = NULL;
	char partitionMethod = '\0';
	Var *partitionColumn = NULL;
	TupleDesc tupleDescriptor = NULL;
	uint32 columnCount = 0;
	Datum *columnValues = NULL;
	bool *columnNulls = NULL;
	TypeCacheEntry *typeEntry = NULL;
	FmgrInfo *hashFunction = NULL;
	FmgrInfo *compareFunction = NULL;

	int shardCount = 0;
	List *shardIntervalList = NULL;
	ShardInterval **shardIntervalCache = NULL;
	bool useBinarySearch = false;

	HTAB *shardConnectionHash = NULL;
	ShardConnections *shardConnections = NULL;
	List *connectionList = NIL;

	EState *executorState = NULL;
	MemoryContext executorTupleContext = NULL;
	ExprContext *executorExpressionContext = NULL;

	CopyState copyState = NULL;
	CopyOutState copyOutState = NULL;
	FmgrInfo *columnOutputFunctions = NULL;
	uint64 processedRowCount = 0;

	/* disallow COPY to/from file or program except for superusers */
	if (copyStatement->filename != NULL && !superuser())
	{
		if (copyStatement->is_program)
		{
			ereport(ERROR,
					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
					 errmsg("must be superuser to COPY to or from an external program"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));
		}
		else
		{
			ereport(ERROR,
					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
					 errmsg("must be superuser to COPY to or from a file"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));
		}
	}

	partitionColumn = PartitionColumn(tableId, 0);
	partitionMethod = PartitionMethod(tableId);
	if (partitionMethod != DISTRIBUTE_BY_RANGE && partitionMethod != DISTRIBUTE_BY_HASH)
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("COPY is only supported for hash- and "
							   "range-partitioned tables")));
	}

	/* resolve hash function for partition column */
	typeEntry = lookup_type_cache(partitionColumn->vartype, TYPECACHE_HASH_PROC_FINFO);
	hashFunction = &(typeEntry->hash_proc_finfo);

	/* resolve compare function for shard intervals */
	compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod);

	/* allocate column values and nulls arrays */
	distributedRelation = heap_open(tableId, RowExclusiveLock);
	tupleDescriptor = RelationGetDescr(distributedRelation);
	columnCount = tupleDescriptor->natts;
	columnValues = palloc0(columnCount * sizeof(Datum));
	columnNulls = palloc0(columnCount * sizeof(bool));

	/* load the list of shards and verify that we have shards to copy into */
	shardIntervalList = LoadShardIntervalList(tableId);
	if (shardIntervalList == NIL)
	{
		if (partitionMethod == DISTRIBUTE_BY_HASH)
		{
			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",
									  relationName),
							errhint("Run master_create_worker_shards to create shards "
									"and try again.")));
		}
		else
		{
			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",
									  relationName)));
		}
	}

	/* prevent concurrent placement changes and non-commutative DML statements */
	LockAllShards(shardIntervalList);

	/* initialize the shard interval cache */
	shardCount = list_length(shardIntervalList);
	shardIntervalCache = SortedShardIntervalArray(shardIntervalList);

	/* determine whether to use binary search */
	if (partitionMethod != DISTRIBUTE_BY_HASH ||
		!IsUniformHashDistribution(shardIntervalCache, shardCount))
	{
		useBinarySearch = true;
	}

	/* initialize copy state to read from COPY data source */
	copyState = BeginCopyFrom(distributedRelation,
							  copyStatement->filename,
							  copyStatement->is_program,
							  copyStatement->attlist,
							  copyStatement->options);

	executorState = CreateExecutorState();
	executorTupleContext = GetPerTupleMemoryContext(executorState);
	executorExpressionContext = GetPerTupleExprContext(executorState);

	copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData));
	copyOutState->binary = true;
	copyOutState->fe_msgbuf = makeStringInfo();
	copyOutState->rowcontext = executorTupleContext;

	columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor, copyOutState->binary);

	/*
	 * Create a mapping of shard id to a connection for each of its placements.
	 * The hash should be initialized before the PG_TRY, since it is used and
	 * PG_CATCH. Otherwise, it may be undefined in the PG_CATCH (see sigsetjmp
	 * documentation).
	 */
	shardConnectionHash = CreateShardConnectionHash();

	/* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */
	PG_TRY();
	{
		ErrorContextCallback errorCallback;

		/* set up callback to identify error line number */
		errorCallback.callback = CopyFromErrorCallback;
		errorCallback.arg = (void *) copyState;
		errorCallback.previous = error_context_stack;
		error_context_stack = &errorCallback;

		/* ensure transactions have unique names on worker nodes */
		InitializeDistributedTransaction();

		while (true)
		{
			bool nextRowFound = false;
			Datum partitionColumnValue = 0;
			ShardInterval *shardInterval = NULL;
			int64 shardId = 0;
			bool shardConnectionsFound = false;
			MemoryContext oldContext = NULL;

			ResetPerTupleExprContext(executorState);

			oldContext = MemoryContextSwitchTo(executorTupleContext);

			/* parse a row from the input */
			nextRowFound = NextCopyFrom(copyState, executorExpressionContext,
										columnValues, columnNulls, NULL);

			if (!nextRowFound)
			{
				MemoryContextSwitchTo(oldContext);
				break;
			}

			CHECK_FOR_INTERRUPTS();

			/* find the partition column value */

			if (columnNulls[partitionColumn->varattno - 1])
			{
				ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
								errmsg("cannot copy row with NULL value "
									   "in partition column")));
			}

			partitionColumnValue = columnValues[partitionColumn->varattno - 1];

			/* find the shard interval and id for the partition column value */
			shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache,
											  shardCount, partitionMethod,
											  compareFunction, hashFunction,
											  useBinarySearch);
			if (shardInterval == NULL)
			{
				ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
								errmsg("could not find shard for partition column "
									   "value")));
			}

			shardId = shardInterval->shardId;

			MemoryContextSwitchTo(oldContext);

			/* get existing connections to the shard placements, if any */
			shardConnections = GetShardConnections(shardConnectionHash,
												   shardId,
												   &shardConnectionsFound);
			if (!shardConnectionsFound)
			{
				/* open connections and initiate COPY on shard placements */
				OpenCopyTransactions(copyStatement, shardConnections);

				/* send binary headers to shard placements */
				resetStringInfo(copyOutState->fe_msgbuf);
				AppendCopyBinaryHeaders(copyOutState);
				SendCopyDataToAll(copyOutState->fe_msgbuf,
								  shardConnections->connectionList);
			}

			/* replicate row to shard placements */
			resetStringInfo(copyOutState->fe_msgbuf);
			AppendCopyRowData(columnValues, columnNulls, tupleDescriptor,
							  copyOutState, columnOutputFunctions);
			SendCopyDataToAll(copyOutState->fe_msgbuf, shardConnections->connectionList);

			processedRowCount += 1;
		}

		connectionList = ConnectionList(shardConnectionHash);

		/* send binary footers to all shard placements */
		resetStringInfo(copyOutState->fe_msgbuf);
		AppendCopyBinaryFooters(copyOutState);
		SendCopyDataToAll(copyOutState->fe_msgbuf, connectionList);

		/* all lines have been copied, stop showing line number in errors */
		error_context_stack = errorCallback.previous;

		/* close the COPY input on all shard placements */
		EndRemoteCopy(connectionList, true);

		if (CopyTransactionManager == TRANSACTION_MANAGER_2PC)
		{
			PrepareRemoteTransactions(connectionList);
		}

		EndCopyFrom(copyState);
		heap_close(distributedRelation, NoLock);

		/* check for cancellation one last time before committing */
		CHECK_FOR_INTERRUPTS();
	}
	PG_CATCH();
	{
		List *abortConnectionList = NIL;

		/* roll back all transactions */
		abortConnectionList = ConnectionList(shardConnectionHash);
		EndRemoteCopy(abortConnectionList, false);
		AbortRemoteTransactions(abortConnectionList);
		CloseConnections(abortConnectionList);

		PG_RE_THROW();
	}
	PG_END_TRY();

	/*
	 * Ready to commit the transaction, this code is below the PG_TRY block because
	 * we do not want any of the transactions rolled back if a failure occurs. Instead,
	 * they should be rolled forward.
	 */
	CommitRemoteTransactions(connectionList);
	CloseConnections(connectionList);

	if (completionTag != NULL)
	{
		snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
				 "COPY " UINT64_FORMAT, processedRowCount);
	}
}
コード例 #2
0
ファイル: multi_shard_transaction.c プロジェクト: zmyer/citus
/*
 * OpenTransactionsToAllShardPlacements opens connections to all placements
 * using the provided shard identifier list and returns it as a shard ID ->
 * ShardConnections hash. connectionFlags can be used to specify whether
 * the command is FOR_DML or FOR_DDL.
 */
HTAB *
OpenTransactionsToAllShardPlacements(List *shardIntervalList, int connectionFlags)
{
	HTAB *shardConnectionHash = NULL;
	ListCell *shardIntervalCell = NULL;
	List *newConnectionList = NIL;

	shardConnectionHash = CreateShardConnectionHash(CurrentMemoryContext);

	/* open connections to shards which don't have connections yet */
	foreach(shardIntervalCell, shardIntervalList)
	{
		ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
		uint64 shardId = shardInterval->shardId;
		ShardConnections *shardConnections = NULL;
		bool shardConnectionsFound = false;
		List *shardPlacementList = NIL;
		ListCell *placementCell = NULL;

		shardConnections = GetShardHashConnections(shardConnectionHash, shardId,
												   &shardConnectionsFound);
		if (shardConnectionsFound)
		{
			continue;
		}

		shardPlacementList = FinalizedShardPlacementList(shardId);
		if (shardPlacementList == NIL)
		{
			/* going to have to have some placements to do any work */
			ereport(ERROR, (errmsg("could not find any shard placements for the shard "
								   UINT64_FORMAT, shardId)));
		}

		foreach(placementCell, shardPlacementList)
		{
			ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(placementCell);
			MultiConnection *connection = NULL;

			WorkerNode *workerNode = FindWorkerNode(shardPlacement->nodeName,
													shardPlacement->nodePort);
			if (workerNode == NULL)
			{
				ereport(ERROR, (errmsg("could not find worker node %s:%d",
									   shardPlacement->nodeName,
									   shardPlacement->nodePort)));
			}

			connection = StartPlacementConnection(connectionFlags,
												  shardPlacement,
												  NULL);

			ClaimConnectionExclusively(connection);

			shardConnections->connectionList = lappend(shardConnections->connectionList,
													   connection);

			newConnectionList = lappend(newConnectionList, connection);

			/*
			 * Every individual failure should cause entire distributed
			 * transaction to fail.
			 */
			MarkRemoteTransactionCritical(connection);
		}
コード例 #3
0
/*
 * OpenTransactionsForAllTasks opens a connection for each task,
 * taking into account which shards are read and modified by the task
 * to select the appopriate connection, or error out if no appropriate
 * connection can be found. The set of connections is returned as an
 * anchor shard ID -> ShardConnections hash.
 */
HTAB *
OpenTransactionsForAllTasks(List *taskList, int connectionFlags)
{
	HTAB *shardConnectionHash = NULL;
	ListCell *taskCell = NULL;
	List *newConnectionList = NIL;

	shardConnectionHash = CreateShardConnectionHash(CurrentMemoryContext);

	connectionFlags |= CONNECTION_PER_PLACEMENT;

	/* open connections to shards which don't have connections yet */
	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		ShardPlacementAccessType accessType = PLACEMENT_ACCESS_SELECT;
		uint64 shardId = task->anchorShardId;
		ShardConnections *shardConnections = NULL;
		bool shardConnectionsFound = false;
		List *shardPlacementList = NIL;
		ListCell *placementCell = NULL;

		shardConnections = GetShardHashConnections(shardConnectionHash, shardId,
												   &shardConnectionsFound);
		if (shardConnectionsFound)
		{
			continue;
		}

		shardPlacementList = FinalizedShardPlacementList(shardId);
		if (shardPlacementList == NIL)
		{
			/* going to have to have some placements to do any work */
			ereport(ERROR, (errmsg("could not find any shard placements for the shard "
								   UINT64_FORMAT, shardId)));
		}

		if (task->taskType == MODIFY_TASK)
		{
			accessType = PLACEMENT_ACCESS_DML;
		}
		else
		{
			/* can only open connections for DDL and DML commands */
			Assert(task->taskType == DDL_TASK || VACUUM_ANALYZE_TASK);

			accessType = PLACEMENT_ACCESS_DDL;
		}

		foreach(placementCell, shardPlacementList)
		{
			ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(placementCell);
			ShardPlacementAccess placementModification;
			List *placementAccessList = NIL;
			MultiConnection *connection = NULL;

			WorkerNode *workerNode = FindWorkerNode(shardPlacement->nodeName,
													shardPlacement->nodePort);
			if (workerNode == NULL)
			{
				ereport(ERROR, (errmsg("could not find worker node %s:%d",
									   shardPlacement->nodeName,
									   shardPlacement->nodePort)));
			}

			/* add placement access for modification */
			placementModification.placement = shardPlacement;
			placementModification.accessType = accessType;

			placementAccessList = lappend(placementAccessList, &placementModification);

			if (accessType == PLACEMENT_ACCESS_DDL)
			{
				List *placementDDLList = BuildPlacementDDLList(shardPlacement->groupId,
															   task->relationShardList);

				/*
				 * All relations appearing inter-shard DDL commands should be marked
				 * with DDL access.
				 */
				placementAccessList = list_concat(placementAccessList, placementDDLList);
			}
			else
			{
				List *placementSelectList =
					BuildPlacementSelectList(shardPlacement->groupId,
											 task->relationShardList);

				/* add additional placement accesses for subselects (e.g. INSERT .. SELECT) */
				placementAccessList =
					list_concat(placementAccessList, placementSelectList);
			}

			/*
			 * Find a connection that sees preceding writes and cannot self-deadlock,
			 * or error out if no such connection exists.
			 */
			connection = StartPlacementListConnection(connectionFlags,
													  placementAccessList, NULL);

			ClaimConnectionExclusively(connection);

			shardConnections->connectionList = lappend(shardConnections->connectionList,
													   connection);

			newConnectionList = lappend(newConnectionList, connection);

			/*
			 * Every individual failure should cause entire distributed
			 * transaction to fail.
			 */
			MarkRemoteTransactionCritical(connection);
		}