// Start a new epoch. void BlockRandomizer::StartEpoch(const EpochConfiguration& config) { m_lastSeenChunkId = CHUNKID_MAX; m_config = config; if (config.m_totalEpochSizeInSamples == requestDataSize) { m_epochSize = m_sweepTotalNumberOfSamples; } else { m_epochSize = config.m_totalEpochSizeInSamples; } // Calculates starts of the epoch, prepares a new sweep if needed. m_epochStartPosition = m_epochSize * config.m_epochIndex; PrepareNewSweepIfNeeded(m_epochStartPosition); // Sets sequence cursor to the sequence that corresponds to the epoch start position. // If last epoch ended in the middle of a sequence, the cursor is moved to the next sequence in the sweep. size_t offsetInSweep = m_epochStartPosition % m_sweepTotalNumberOfSamples; size_t newOffset = m_sequenceRandomizer->Seek(offsetInSweep, m_sweep); m_globalSamplePosition = m_sweep * m_sweepTotalNumberOfSamples + newOffset; size_t epochStartFrame = config.m_epochIndex * m_epochSize; fprintf(stderr, "BlockRandomizer::StartEpoch: epoch %" PRIu64 ": frames [%" PRIu64 "..%" PRIu64 "] (first sequence at sample %" PRIu64 "), data subset %" PRIu64 " of %" PRIu64 "\n", config.m_epochIndex, epochStartFrame, epochStartFrame + m_epochSize, m_globalSamplePosition, config.m_workerRank, config.m_numberOfWorkers); }
void BlockRandomizer::SetCurrentSamplePosition(size_t currentSamplePosition) { PrepareNewSweepIfNeeded(currentSamplePosition); // Sets sequence cursor to the sequence that corresponds to the epoch start position. // If last epoch ended in the middle of a sequence, the cursor is moved to the next sequence in the sweep. size_t offsetInSweep = currentSamplePosition % m_sweepSizeInSamples; size_t newOffset = m_sequenceRandomizer->Seek(offsetInSweep, m_sweep); m_globalSamplePosition = m_sweep * m_sweepSizeInSamples + newOffset; // Check if we have some data, if not set to the end of epoch. if (m_config.m_workerRank >= m_chunkRandomizer->GetRandomizedChunks().size()) m_globalSamplePosition = m_epochStartPosition + m_epochSize; }
// Get next sequence descriptions that do not exceed sample count. // Returns true if epoch end is reached. bool BlockRandomizer::GetNextSequenceDescriptions(size_t sampleCount, std::vector<RandomizedSequenceDescription>& result) { assert(sampleCount != 0); PrepareNewSweepIfNeeded(m_globalSamplePosition); // Check epoch end. if (m_globalSamplePosition >= m_epochSize + m_epochStartPosition) { return true; } sampleCount = std::min(sampleCount, m_epochSize + m_epochStartPosition - m_globalSamplePosition); assert(sampleCount != 0); // Check that we do not go over the sweep. sampleCount = std::min(sampleCount, (long)m_sweepTotalNumberOfSamples - m_globalSamplePosition % m_sweepTotalNumberOfSamples); assert(sampleCount != 0); // Randomizing sequences result = m_sequenceRandomizer->GetNextSequenceDescriptions(sampleCount); return false; }
// Get next sequence descriptions for that worker that do not exceed global and local sample count. // Returns true if epoch end is reached. std::tuple<bool, bool, size_t, size_t> BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange, bool atLeastOneSequenceNeeded) { if (globalSampleCount == 0) LogicError("Global sample count must not be zero."); if (localSampleCount == 0) LogicError("Local sample count must not be zero."); PrepareNewSweepIfNeeded(m_globalSamplePosition); auto sweepPosition = m_globalSamplePosition % m_sweepSizeInSamples; auto epochEndPosition = m_epochSize + m_epochStartPosition; // Check epoch end. if (m_globalSamplePosition >= epochEndPosition) { auto reachedEndOfEpoch = true; auto reachedEndOfSweep = (m_globalSamplePosition >= m_sweepSizeInSamples) && (sweepPosition == 0); return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, 0, 0); } // Global sample count should not exceed the epoch. globalSampleCount = std::min(globalSampleCount, epochEndPosition - m_globalSamplePosition); // Global sample count should also not exceed the sweep. globalSampleCount = std::min(globalSampleCount, m_sweepSizeInSamples - sweepPosition); if (globalSampleCount == 0) LogicError("Global sample count must not result in zero."); std::function<bool(const RandomizedSequenceDescription*)> isLocalSequence = [this](const RandomizedSequenceDescription* s) { return s->m_chunk->m_chunkId % m_config.m_numberOfWorkers == m_config.m_workerRank; }; size_t actualNumberOfGlobalSamples = 0, actualNumberOfLocalSamples = 0; std::tie(actualNumberOfGlobalSamples, actualNumberOfLocalSamples) = m_sequenceRandomizer->GetNextSequenceDescriptions( globalSampleCount, localSampleCount, isLocalSequence, windowRange, result, atLeastOneSequenceNeeded); if (actualNumberOfLocalSamples > actualNumberOfGlobalSamples) LogicError("Local sample count cannot be greater than the global sample count."); if (m_verbosity >= Debug) fprintf(stderr, "BlockRandomizer::GetNextSequenceDescriptions(): getting %" PRIu64 " sequences for %" PRIu64 "/%" PRIu64 " requested local/global samples in sweep %" PRIu64 "\n", result.size(), localSampleCount, globalSampleCount, m_sweep); // set "reachedEndOfSweep" to true if the minibatch is last in a sweep auto reachedEndOfSweep = (sweepPosition + actualNumberOfGlobalSamples >= m_sweepSizeInSamples); // set "reachedEndOfEpoch" to true if the current batch is last in an epoch. auto reachedEndOfEpoch = (m_globalSamplePosition + actualNumberOfGlobalSamples >= epochEndPosition); // Update the global sample position. m_globalSamplePosition += actualNumberOfGlobalSamples; return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, actualNumberOfGlobalSamples, actualNumberOfLocalSamples); }