/*
 * CitusCopyFrom implements the COPY table_name FROM ... for hash-partitioned
 * and range-partitioned tables.
 *
 * Each parsed row is routed to the shard whose interval contains the row's
 * partition column value, and forwarded to all placements of that shard using
 * the binary COPY protocol. On any error the in-progress remote COPYs and
 * transactions are aborted; on success the remote transactions are committed
 * (optionally via 2PC) after the PG_TRY block. If completionTag is non-NULL
 * it receives "COPY <n>" with the number of rows processed.
 */
void
CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
{
	Oid tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
	char *relationName = get_rel_name(tableId);
	Relation distributedRelation = NULL;
	char partitionMethod = '\0';
	Var *partitionColumn = NULL;
	TupleDesc tupleDescriptor = NULL;
	uint32 columnCount = 0;
	Datum *columnValues = NULL;
	bool *columnNulls = NULL;
	TypeCacheEntry *typeEntry = NULL;
	FmgrInfo *hashFunction = NULL;
	FmgrInfo *compareFunction = NULL;
	int shardCount = 0;
	List *shardIntervalList = NULL;
	ShardInterval **shardIntervalCache = NULL;
	bool useBinarySearch = false;
	HTAB *shardConnectionHash = NULL;
	ShardConnections *shardConnections = NULL;
	List *connectionList = NIL;
	EState *executorState = NULL;
	MemoryContext executorTupleContext = NULL;
	ExprContext *executorExpressionContext = NULL;
	CopyState copyState = NULL;
	CopyOutState copyOutState = NULL;
	FmgrInfo *columnOutputFunctions = NULL;
	uint64 processedRowCount = 0;

	/* disallow COPY to/from file or program except for superusers */
	if (copyStatement->filename != NULL && !superuser())
	{
		if (copyStatement->is_program)
		{
			ereport(ERROR,
					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
					 errmsg("must be superuser to COPY to or from an external program"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));
		}
		else
		{
			ereport(ERROR,
					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
					 errmsg("must be superuser to COPY to or from a file"),
					 errhint("Anyone can COPY to stdout or from stdin. "
							 "psql's \\copy command also works for anyone.")));
		}
	}

	partitionColumn = PartitionColumn(tableId, 0);
	partitionMethod = PartitionMethod(tableId);
	if (partitionMethod != DISTRIBUTE_BY_RANGE && partitionMethod != DISTRIBUTE_BY_HASH)
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("COPY is only supported for hash- and "
							   "range-partitioned tables")));
	}

	/* resolve hash function for partition column */
	typeEntry = lookup_type_cache(partitionColumn->vartype, TYPECACHE_HASH_PROC_FINFO);
	hashFunction = &(typeEntry->hash_proc_finfo);

	/* resolve compare function for shard intervals */
	compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod);

	/* allocate column values and nulls arrays */
	distributedRelation = heap_open(tableId, RowExclusiveLock);
	tupleDescriptor = RelationGetDescr(distributedRelation);
	columnCount = tupleDescriptor->natts;
	columnValues = palloc0(columnCount * sizeof(Datum));
	columnNulls = palloc0(columnCount * sizeof(bool));

	/* load the list of shards and verify that we have shards to copy into */
	shardIntervalList = LoadShardIntervalList(tableId);
	if (shardIntervalList == NIL)
	{
		if (partitionMethod == DISTRIBUTE_BY_HASH)
		{
			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",
									  relationName),
							errhint("Run master_create_worker_shards to create shards "
									"and try again.")));
		}
		else
		{
			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
							errmsg("could not find any shards into which to copy"),
							errdetail("No shards exist for distributed table \"%s\".",
									  relationName)));
		}
	}

	/* prevent concurrent placement changes and non-commutative DML statements */
	LockAllShards(shardIntervalList);

	/* initialize the shard interval cache */
	shardCount = list_length(shardIntervalList);
	shardIntervalCache = SortedShardIntervalArray(shardIntervalList);

	/*
	 * Determine whether to use binary search. For a uniformly distributed
	 * hash-partitioned table the target shard can be computed directly;
	 * otherwise FindShardInterval falls back to a binary search over the
	 * sorted interval array.
	 */
	if (partitionMethod != DISTRIBUTE_BY_HASH ||
		!IsUniformHashDistribution(shardIntervalCache, shardCount))
	{
		useBinarySearch = true;
	}

	/* initialize copy state to read from COPY data source */
	copyState = BeginCopyFrom(distributedRelation, copyStatement->filename,
							  copyStatement->is_program, copyStatement->attlist,
							  copyStatement->options);

	executorState = CreateExecutorState();
	executorTupleContext = GetPerTupleMemoryContext(executorState);
	executorExpressionContext = GetPerTupleExprContext(executorState);

	/* rows are forwarded to worker placements using the binary COPY format */
	copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData));
	copyOutState->binary = true;
	copyOutState->fe_msgbuf = makeStringInfo();
	copyOutState->rowcontext = executorTupleContext;

	columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor, copyOutState->binary);

	/*
	 * Create a mapping of shard id to a connection for each of its placements.
	 * The hash should be initialized before the PG_TRY, since it is used in
	 * PG_CATCH. Otherwise, it may be undefined in the PG_CATCH (see sigsetjmp
	 * documentation).
	 */
	shardConnectionHash = CreateShardConnectionHash();

	/* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */
	PG_TRY();
	{
		ErrorContextCallback errorCallback;

		/* set up callback to identify error line number */
		errorCallback.callback = CopyFromErrorCallback;
		errorCallback.arg = (void *) copyState;
		errorCallback.previous = error_context_stack;
		error_context_stack = &errorCallback;

		/* ensure transactions have unique names on worker nodes */
		InitializeDistributedTransaction();

		while (true)
		{
			bool nextRowFound = false;
			Datum partitionColumnValue = 0;
			ShardInterval *shardInterval = NULL;
			int64 shardId = 0;
			bool shardConnectionsFound = false;
			MemoryContext oldContext = NULL;

			ResetPerTupleExprContext(executorState);

			/* per-tuple context so row parsing allocations are reset each iteration */
			oldContext = MemoryContextSwitchTo(executorTupleContext);

			/* parse a row from the input */
			nextRowFound = NextCopyFrom(copyState, executorExpressionContext,
										columnValues, columnNulls, NULL);

			if (!nextRowFound)
			{
				MemoryContextSwitchTo(oldContext);
				break;
			}

			CHECK_FOR_INTERRUPTS();

			/* find the partition column value */
			if (columnNulls[partitionColumn->varattno - 1])
			{
				ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
								errmsg("cannot copy row with NULL value "
									   "in partition column")));
			}

			partitionColumnValue = columnValues[partitionColumn->varattno - 1];

			/* find the shard interval and id for the partition column value */
			shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache,
											  shardCount, partitionMethod,
											  compareFunction, hashFunction,
											  useBinarySearch);
			if (shardInterval == NULL)
			{
				ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
								errmsg("could not find shard for partition column "
									   "value")));
			}

			shardId = shardInterval->shardId;

			MemoryContextSwitchTo(oldContext);

			/* get existing connections to the shard placements, if any */
			shardConnections = GetShardConnections(shardConnectionHash, shardId,
												   &shardConnectionsFound);
			if (!shardConnectionsFound)
			{
				/* open connections and initiate COPY on shard placements */
				OpenCopyTransactions(copyStatement, shardConnections);

				/* send binary headers to shard placements */
				resetStringInfo(copyOutState->fe_msgbuf);
				AppendCopyBinaryHeaders(copyOutState);
				SendCopyDataToAll(copyOutState->fe_msgbuf,
								  shardConnections->connectionList);
			}

			/* replicate row to shard placements */
			resetStringInfo(copyOutState->fe_msgbuf);
			AppendCopyRowData(columnValues, columnNulls, tupleDescriptor,
							  copyOutState, columnOutputFunctions);
			SendCopyDataToAll(copyOutState->fe_msgbuf,
							  shardConnections->connectionList);

			processedRowCount += 1;
		}

		connectionList = ConnectionList(shardConnectionHash);

		/* send binary footers to all shard placements */
		resetStringInfo(copyOutState->fe_msgbuf);
		AppendCopyBinaryFooters(copyOutState);
		SendCopyDataToAll(copyOutState->fe_msgbuf, connectionList);

		/* all lines have been copied, stop showing line number in errors */
		error_context_stack = errorCallback.previous;

		/* close the COPY input on all shard placements */
		EndRemoteCopy(connectionList, true);

		if (CopyTransactionManager == TRANSACTION_MANAGER_2PC)
		{
			PrepareRemoteTransactions(connectionList);
		}

		EndCopyFrom(copyState);

		/* NoLock: the RowExclusiveLock is retained until transaction end */
		heap_close(distributedRelation, NoLock);

		/* check for cancellation one last time before committing */
		CHECK_FOR_INTERRUPTS();
	}
	PG_CATCH();
	{
		List *abortConnectionList = NIL;

		/*
		 * Roll back all transactions. The connection list is re-derived from
		 * shardConnectionHash (set before PG_TRY) rather than read from
		 * connectionList, which may not have been assigned yet when the
		 * error was thrown.
		 */
		abortConnectionList = ConnectionList(shardConnectionHash);
		EndRemoteCopy(abortConnectionList, false);
		AbortRemoteTransactions(abortConnectionList);
		CloseConnections(abortConnectionList);

		PG_RE_THROW();
	}
	PG_END_TRY();

	/*
	 * Ready to commit the transaction, this code is below the PG_TRY block because
	 * we do not want any of the transactions rolled back if a failure occurs. Instead,
	 * they should be rolled forward.
	 */
	CommitRemoteTransactions(connectionList);
	CloseConnections(connectionList);

	if (completionTag != NULL)
	{
		snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
				 "COPY " UINT64_FORMAT, processedRowCount);
	}
}
/*
 * OpenTransactionsToAllShardPlacements opens connections to all placements
 * using the provided shard identifier list and returns it as a shard ID ->
 * ShardConnections hash. connectionFlags can be used to specify whether
 * the command is FOR_DML or FOR_DDL.
 *
 * Errors out if a shard has no finalized placements or if a placement's
 * worker node cannot be found. Each opened connection is claimed
 * exclusively and marked critical, so a failure on any individual
 * connection fails the whole distributed transaction.
 */
HTAB *
OpenTransactionsToAllShardPlacements(List *shardIntervalList, int connectionFlags)
{
	HTAB *shardConnectionHash = NULL;
	ListCell *shardIntervalCell = NULL;

	/* connections newly opened by this call, as opposed to reused hash entries */
	List *newConnectionList = NIL;

	shardConnectionHash = CreateShardConnectionHash(CurrentMemoryContext);

	/* open connections to shards which don't have connections yet */
	foreach(shardIntervalCell, shardIntervalList)
	{
		ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
		uint64 shardId = shardInterval->shardId;
		ShardConnections *shardConnections = NULL;
		bool shardConnectionsFound = false;
		List *shardPlacementList = NIL;
		ListCell *placementCell = NULL;

		shardConnections = GetShardHashConnections(shardConnectionHash, shardId,
												   &shardConnectionsFound);
		if (shardConnectionsFound)
		{
			/* this shard already has connections in the hash; nothing to do */
			continue;
		}

		shardPlacementList = FinalizedShardPlacementList(shardId);
		if (shardPlacementList == NIL)
		{
			/* going to have to have some placements to do any work */
			ereport(ERROR, (errmsg("could not find any shard placements for the shard "
								   UINT64_FORMAT, shardId)));
		}

		/* open a connection to every placement of this shard */
		foreach(placementCell, shardPlacementList)
		{
			ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(placementCell);
			MultiConnection *connection = NULL;
			WorkerNode *workerNode = FindWorkerNode(shardPlacement->nodeName,
													shardPlacement->nodePort);
			if (workerNode == NULL)
			{
				ereport(ERROR, (errmsg("could not find worker node %s:%d",
									   shardPlacement->nodeName,
									   shardPlacement->nodePort)));
			}

			connection = StartPlacementConnection(connectionFlags, shardPlacement,
												  NULL);

			/* reserve the connection so no other code path can interleave commands */
			ClaimConnectionExclusively(connection);

			shardConnections->connectionList = lappend(shardConnections->connectionList,
													   connection);

			newConnectionList = lappend(newConnectionList, connection);

			/*
			 * Every individual failure should cause entire distributed
			 * transaction to fail.
			 */
			MarkRemoteTransactionCritical(connection);
		}
/*
 * OpenTransactionsForAllTasks opens a connection for each task,
 * taking into account which shards are read and modified by the task
 * to select the appropriate connection, or error out if no appropriate
 * connection can be found. The set of connections is returned as an
 * anchor shard ID -> ShardConnections hash.
 *
 * CONNECTION_PER_PLACEMENT is always added to connectionFlags, so each
 * placement gets its own connection. Every opened connection is claimed
 * exclusively and marked critical.
 */
HTAB *
OpenTransactionsForAllTasks(List *taskList, int connectionFlags)
{
	HTAB *shardConnectionHash = NULL;
	ListCell *taskCell = NULL;

	/* connections newly opened by this call, as opposed to reused hash entries */
	List *newConnectionList = NIL;

	shardConnectionHash = CreateShardConnectionHash(CurrentMemoryContext);

	connectionFlags |= CONNECTION_PER_PLACEMENT;

	/* open connections to shards which don't have connections yet */
	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		ShardPlacementAccessType accessType = PLACEMENT_ACCESS_SELECT;
		uint64 shardId = task->anchorShardId;
		ShardConnections *shardConnections = NULL;
		bool shardConnectionsFound = false;
		List *shardPlacementList = NIL;
		ListCell *placementCell = NULL;

		shardConnections = GetShardHashConnections(shardConnectionHash, shardId,
												   &shardConnectionsFound);
		if (shardConnectionsFound)
		{
			/* connections for this anchor shard already exist in the hash */
			continue;
		}

		shardPlacementList = FinalizedShardPlacementList(shardId);
		if (shardPlacementList == NIL)
		{
			/* going to have to have some placements to do any work */
			ereport(ERROR, (errmsg("could not find any shard placements for the shard "
								   UINT64_FORMAT, shardId)));
		}

		if (task->taskType == MODIFY_TASK)
		{
			accessType = PLACEMENT_ACCESS_DML;
		}
		else
		{
			/* can only open connections for DDL and DML commands */

			/*
			 * NOTE(review): the second operand is a bare enum constant, so
			 * this assertion is vacuously true whenever VACUUM_ANALYZE_TASK
			 * is non-zero; it almost certainly should read
			 * "task->taskType == VACUUM_ANALYZE_TASK".
			 */
			Assert(task->taskType == DDL_TASK || VACUUM_ANALYZE_TASK);

			accessType = PLACEMENT_ACCESS_DDL;
		}

		foreach(placementCell, shardPlacementList)
		{
			ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(placementCell);

			/*
			 * Stack-allocated per iteration; placementAccessList holds its
			 * address, so the list must not outlive this loop body.
			 */
			ShardPlacementAccess placementModification;
			List *placementAccessList = NIL;
			MultiConnection *connection = NULL;

			WorkerNode *workerNode = FindWorkerNode(shardPlacement->nodeName,
													shardPlacement->nodePort);
			if (workerNode == NULL)
			{
				ereport(ERROR, (errmsg("could not find worker node %s:%d",
									   shardPlacement->nodeName,
									   shardPlacement->nodePort)));
			}

			/* add placement access for modification */
			placementModification.placement = shardPlacement;
			placementModification.accessType = accessType;

			placementAccessList = lappend(placementAccessList, &placementModification);

			if (accessType == PLACEMENT_ACCESS_DDL)
			{
				List *placementDDLList = BuildPlacementDDLList(shardPlacement->groupId,
															   task->relationShardList);

				/*
				 * All relations appearing in inter-shard DDL commands should be
				 * marked with DDL access.
				 */
				placementAccessList = list_concat(placementAccessList,
												  placementDDLList);
			}
			else
			{
				List *placementSelectList =
					BuildPlacementSelectList(shardPlacement->groupId,
											 task->relationShardList);

				/* add additional placement accesses for subselects (e.g. INSERT .. SELECT) */
				placementAccessList = list_concat(placementAccessList,
												  placementSelectList);
			}

			/*
			 * Find a connection that sees preceding writes and cannot self-deadlock,
			 * or error out if no such connection exists.
			 */
			connection = StartPlacementListConnection(connectionFlags,
													  placementAccessList, NULL);

			/* reserve the connection so no other code path can interleave commands */
			ClaimConnectionExclusively(connection);

			shardConnections->connectionList = lappend(shardConnections->connectionList,
													   connection);

			newConnectionList = lappend(newConnectionList, connection);

			/*
			 * Every individual failure should cause entire distributed
			 * transaction to fail.
			 */
			MarkRemoteTransactionCritical(connection);
		}