// Commons repeated loads of the same variable inside each basic block of
// |func|.
//
// A load is a candidate only when its pointer is directly an OpVariable, the
// variable passes both IsUniformVar() and IsSamplerOrImageVar(), and the
// load's result id is not rejected by HasUnsupportedDecorates(). The first
// candidate load of a variable in a block is recorded in |uniform2load_id_|;
// each later candidate load of the same variable in that block is handed to
// ReplaceAndDeleteLoad() with the recorded id as its replacement.
//
// Returns true if at least one load was replaced.
bool CommonUniformElimPass::CommonUniformLoadElimBlock(ir::Function* func) {
  bool changed = false;
  for (auto blockItr = func->begin(); blockItr != func->end(); ++blockItr) {
    // The table is scoped to a single block: forget loads from the last one.
    uniform2load_id_.clear();
    for (auto instItr = blockItr->begin(); instItr != blockItr->end();
         ++instItr) {
      if (instItr->opcode() != SpvOpLoad) continue;
      uint32_t loadedVarId;
      ir::Instruction* basePtr = GetPtr(&*instItr, &loadedVarId);
      // Checks are evaluated left to right and stop at the first failure.
      const bool isCandidate =
          basePtr->opcode() == SpvOpVariable &&
          IsUniformVar(loadedVarId) &&
          IsSamplerOrImageVar(loadedVarId) &&
          !HasUnsupportedDecorates(instItr->result_id());
      if (!isCandidate) continue;
      const auto seen = uniform2load_id_.find(loadedVarId);
      if (seen == uniform2load_id_.end()) {
        // First load of this variable in the block: just remember it.
        uniform2load_id_[loadedVarId] = instItr->result_id();
        continue;
      }
      // A previous load of the same variable exists in this block; reuse it.
      const uint32_t earlierLoadId = seen->second;
      ReplaceAndDeleteLoad(&*instItr, earlierLoadId, basePtr);
      changed = true;
    }
  }
  return changed;
}
// Replaces loads that go through an access chain into a uniform variable.
//
// A load is converted only when its pointer is a non-pointer access chain
// whose base operand is the variable itself (i.e. the chain is not nested),
// the variable passes IsUniformVar(), the chain passes
// IsConstantIndexAccessChain(), and neither the load's nor the chain's result
// id is rejected by HasUnsupportedDecorates(). For each such load,
// GenACLoadRepl() produces the replacement instructions and the id that
// stands in for the load result; the load is passed to ReplaceAndDeleteLoad()
// and the new instructions are inserted right after the load's position.
//
// Returns true if at least one load was converted.
bool CommonUniformElimPass::UniformAccessChainConvert(ir::Function* func) {
  bool changed = false;
  for (auto& block : *func) {
    for (auto instItr = block.begin(); instItr != block.end(); ++instItr) {
      if (instItr->opcode() != SpvOpLoad) continue;
      uint32_t baseVarId;
      ir::Instruction* chainInst = GetPtr(&*instItr, &baseVarId);
      if (!IsNonPtrAccessChain(chainInst->opcode())) continue;
      // Checks are evaluated left to right and stop at the first failure.
      const bool convertible =
          // Do not convert nested access chains
          chainInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) ==
              baseVarId &&
          IsUniformVar(baseVarId) &&
          IsConstantIndexAccessChain(chainInst) &&
          !HasUnsupportedDecorates(instItr->result_id()) &&
          !HasUnsupportedDecorates(chainInst->result_id());
      if (!convertible) continue;
      std::vector<std::unique_ptr<ir::Instruction>> replacementInsts;
      uint32_t replId;
      GenACLoadRepl(chainInst, &replacementInsts, &replId);
      ReplaceAndDeleteLoad(&*instItr, replId, chainInst);
      // Step past the old load so the new instructions land directly after
      // its position, then step once more before the loop's own increment.
      ++instItr;
      instItr = instItr.InsertBefore(&replacementInsts);
      ++instItr;
      changed = true;
    }
  }
  return changed;
}
// Merges duplicate single-index OpCompositeExtract instructions in |func|.
//
// Pass 1 records, in |comp2idx2inst_|, every extract with at most two
// in-operands (composite id + one index) whose result id is not rejected by
// HasUnsupportedDecorates(), keyed by composite id and then by index.
//
// Pass 2 walks the function again; for every instruction whose result id is a
// recorded composite, and for every index extracted two or more times, it
// clones the first recorded extract under a fresh id, inserts the clone right
// after the current position, redirects all uses of the old extracts to the
// clone, and kills the old extracts.
//
// Returns true if any extracts were merged.
bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) {
  // Find all composite ids with duplicate extracts.
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
      if (ii->opcode() != SpvOpCompositeExtract)
        continue;
      // TODO(greg-lunarg): Support multiple indices
      if (ii->NumInOperands() > 2)
        continue;
      if (HasUnsupportedDecorates(ii->result_id()))
        continue;
      uint32_t compId = ii->GetSingleWordInOperand(kExtractCompositeIdInIdx);
      uint32_t idx = ii->GetSingleWordInOperand(kExtractIdx0InIdx);
      comp2idx2inst_[compId][idx].push_back(&*ii);
    }
  }
  // For all defs of ids with duplicate extracts, insert new extracts
  // after def, and replace and delete old extracts
  bool modified = false;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
      const auto cItr = comp2idx2inst_.find(ii->result_id());
      if (cItr == comp2idx2inst_.end())
        continue;
      // Bind by const reference: iterating by value would copy the whole
      // per-index instruction list on every iteration. |comp2idx2inst_| is
      // not modified inside this loop, so the reference stays valid.
      for (const auto& idxItr : cItr->second) {
        // A single extract of this index has nothing to be merged with.
        if (idxItr.second.size() < 2)
          continue;
        uint32_t replId = TakeNextId();
        // Clone the first duplicate before any of the originals are killed.
        std::unique_ptr<ir::Instruction> newExtract(
            new ir::Instruction(*idxItr.second.front()));
        newExtract->SetResultId(replId);
        def_use_mgr_->AnalyzeInstDefUse(&*newExtract);
        // Step forward so the clone is inserted after the current position.
        ++ii;
        ii = ii.InsertBefore(std::move(newExtract));
        for (auto instItr : idxItr.second) {
          uint32_t resId = instItr->result_id();
          KillNamesAndDecorates(resId);
          (void)def_use_mgr_->ReplaceAllUsesWith(resId, replId);
          def_use_mgr_->KillInst(instItr);
        }
        modified = true;
      }
    }
  }
  return modified;
}
// Eliminates redundant loads and stores of target variables within each basic
// block of |func|.
//
// While scanning a block, |var2store_| maps a variable id to the most recent
// whole-variable store and |var2load_| maps a variable id to the most recent
// whole-variable load that was not itself replaced. Both maps are reset at the
// start of every block and after every OpFunctionCall. Instructions found to
// be redundant are only collected during the scan and are killed once the
// whole function has been visited.
//
// Returns true if any instruction was marked for removal.
bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
    Function* func) {
  // Perform local store/load, load/load and store/store elimination
  // on each block
  bool modified = false;
  // Instructions to remove after the scan completes.
  std::vector<Instruction*> instructions_to_kill;
  // Stores that a later partial (access-chain) load reads from; these must
  // not be removed when a subsequent store to the same variable is seen.
  std::unordered_set<Instruction*> instructions_to_save;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    var2store_.clear();
    var2load_.clear();
    // |next| is advanced before the body runs, so the `continue` statements
    // inside the switch below move on to the following instruction.
    auto next = bi->begin();
    for (auto ii = next; ii != bi->end(); ii = next) {
      ++next;
      switch (ii->opcode()) {
        case SpvOpStore: {
          // Verify store variable is target type
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsTargetVar(varId)) continue;
          if (!HasOnlySupportedRefs(varId)) continue;
          // If a store to the whole variable, remember it for succeeding
          // loads and stores. Otherwise forget any previous store to that
          // variable.
          if (ptrInst->opcode() == SpvOpVariable) {
            // If a previous store to same variable, mark the store
            // for deletion if not still used.
            auto prev_store = var2store_.find(varId);
            if (prev_store != var2store_.end() &&
                instructions_to_save.count(prev_store->second) == 0) {
              instructions_to_kill.push_back(prev_store->second);
              modified = true;
            }

            bool kill_store = false;
            auto li = var2load_.find(varId);
            if (li != var2load_.end()) {
              if (ii->GetSingleWordInOperand(kStoreValIdInIdx) ==
                  li->second->result_id()) {
                // We are storing the same value that already exists in the
                // memory location.  The store does nothing.
                kill_store = true;
              }
            }

            if (!kill_store) {
              // This store becomes the current value of the variable; any
              // remembered load is now stale.
              var2store_[varId] = &*ii;
              var2load_.erase(varId);
            } else {
              instructions_to_kill.push_back(&*ii);
              modified = true;
            }
          } else {
            // Partial store through an access chain: the whole-variable
            // value is no longer known, so drop both records.
            assert(IsNonPtrAccessChain(ptrInst->opcode()));
            var2store_.erase(varId);
            var2load_.erase(varId);
          }
        } break;
        case SpvOpLoad: {
          // Verify load variable is target type
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsTargetVar(varId)) continue;
          if (!HasOnlySupportedRefs(varId)) continue;
          // Id that can stand in for this load's result; 0 means none found.
          uint32_t replId = 0;
          if (ptrInst->opcode() == SpvOpVariable) {
            // If a load from a variable, look for a previous store or
            // load from that variable and use its value.
            auto si = var2store_.find(varId);
            if (si != var2store_.end()) {
              replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx);
            } else {
              auto li = var2load_.find(varId);
              if (li != var2load_.end()) {
                replId = li->second->result_id();
              }
            }
          } else {
            // If a partial load of a previously seen store, remember
            // not to delete the store.
            auto si = var2store_.find(varId);
            if (si != var2store_.end()) instructions_to_save.insert(si->second);
          }
          if (replId != 0) {
            // replace load's result id and delete load
            context()->KillNamesAndDecorates(&*ii);
            context()->ReplaceAllUsesWith(ii->result_id(), replId);
            instructions_to_kill.push_back(&*ii);
            modified = true;
          } else {
            // No known value: a whole-variable load becomes the reference
            // for later loads/stores of the same variable.
            if (ptrInst->opcode() == SpvOpVariable)
              var2load_[varId] = &*ii;  // register load
          }
        } break;
        case SpvOpFunctionCall: {
          // Conservatively assume all locals are redefined for now.
          // TODO(): Handle more optimally
          var2store_.clear();
          var2load_.clear();
        } break;
        default:
          break;
      }
    }
  }

  // Removal is deferred until here so the scan above never kills an
  // instruction while the block iterators are still walking over it.
  for (Instruction* inst : instructions_to_kill) {
    context()->KillInst(inst);
  }

  return modified;
}
// Commons loads of the same uniform variable across the whole of |func|.
//
// Blocks are visited in the order produced by ComputeStructuredOrder(). A
// load is a candidate when its pointer is directly an OpVariable that passes
// IsUniformVar() but not IsSamplerOrImageVar(), and its result id is not
// rejected by HasUnsupportedDecorates(). |mergeBlockId| is non-zero while the
// walk is inside a control construct (it holds the id of that construct's
// merge block):
//   * A first load seen outside any construct is simply remembered.
//   * A first load seen inside a construct is re-created at |insertItr|, a
//     position in the most recent block outside of control flow, and that
//     copy is remembered instead.
//   * Every later load of a remembered variable is passed to
//     ReplaceAndDeleteLoad() with the remembered id.
//
// Returns true if at least one load was replaced.
bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) {
  // Process all blocks in structured order. This is just one way (the
  // simplest?) to keep track of the most recent block outside of control
  // flow, used to copy common instructions, guaranteed to dominate all
  // following load sites.
  std::list<ir::BasicBlock*> structuredOrder;
  ComputeStructuredOrder(func, &structuredOrder);
  uniform2load_id_.clear();
  bool modified = false;
  // Find insertion point in first block to copy non-dominating loads.
  auto insertItr = func->begin()->begin();
  while (insertItr->opcode() == SpvOpVariable ||
         insertItr->opcode() == SpvOpNop)
    ++insertItr;
  uint32_t mergeBlockId = 0;
  for (ir::BasicBlock* bp : structuredOrder) {
    // Check if we are exiting outermost control construct. If so, remember
    // new load insertion point. Trying to keep register pressure down.
    if (mergeBlockId == bp->id()) {
      mergeBlockId = 0;
      insertItr = bp->begin();
      // OpPhi instructions must precede every other instruction in a block,
      // so copied loads may only be inserted after the block's phis.
      while (insertItr->opcode() == SpvOpPhi)
        ++insertItr;
    }
    for (auto ii = bp->begin(); ii != bp->end(); ++ii) {
      if (ii->opcode() != SpvOpLoad)
        continue;
      uint32_t varId;
      ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
      if (ptrInst->opcode() != SpvOpVariable)
        continue;
      if (!IsUniformVar(varId))
        continue;
      if (IsSamplerOrImageVar(varId))
        continue;
      if (HasUnsupportedDecorates(ii->result_id()))
        continue;
      uint32_t replId;
      const auto uItr = uniform2load_id_.find(varId);
      if (uItr != uniform2load_id_.end()) {
        replId = uItr->second;
      }
      else {
        if (mergeBlockId == 0) {
          // Load is in dominating block; just remember it
          uniform2load_id_[varId] = ii->result_id();
          continue;
        }
        else {
          // Copy load into most recent dominating block and remember it
          replId = TakeNextId();
          std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction(SpvOpLoad,
            ii->type_id(), replId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}}));
          def_use_mgr_->AnalyzeInstDefUse(&*newLoad);
          insertItr = insertItr.InsertBefore(std::move(newLoad));
          // Step back onto the original insertion instruction so successive
          // copies are placed in the order they are created.
          ++insertItr;
          uniform2load_id_[varId] = replId;
        }
      }
      ReplaceAndDeleteLoad(&*ii, replId, ptrInst);
      modified = true;
    }
    // If we are outside of any control construct and entering one, remember
    // the id of the merge block
    if (mergeBlockId == 0) {
      uint32_t dummy;
      mergeBlockId = MergeBlockIdIfAny(*bp, &dummy);
    }
  }
  return modified;
}
Esempio n. 6
0
bool
TaintAnalysis::transfer(const Function& func, const DataflowNode& node_, NodeState& state, const std::vector<Lattice*>& dfInfo) {
    static size_t ncalls = 0;
    if (debug) {
        *debug <<"TaintAnalysis::transfer-" <<++ncalls <<"(func=" <<func.get_name() <<",\n"
               <<"                        node={" <<StringUtility::makeOneLine(node_.toString()) <<"},\n"
               <<"                        state={" <<state.str(this, "                            ") <<",\n"
               <<"                        dfInfo[" <<dfInfo.size() <<"]={...})\n";
    }

    SgNode *node = node_.getNode();
    assert(!dfInfo.empty());
    FiniteVarsExprsProductLattice *prodLat = dynamic_cast<FiniteVarsExprsProductLattice*>(dfInfo.front());
    bool modified = magic_tainted(node, prodLat); // some values are automatically tainted based on their name

    // Process AST nodes that transfer taintedness.  Most of these operations have one or more inputs from which a result
    // is always calculated the same way.  So we just gather up the inputs and do the calculation at the very end of this
    // function.  The other operations are handled individually within their "if" bodies.
    TaintLattice *result = NULL;                    // result pointer into the taint lattice
    std::vector<TaintLattice*> inputs;              // input pointers into the taint lattice
    if (isSgAssignInitializer(node)) {
        // as in "int a = b"
        SgAssignInitializer *xop = isSgAssignInitializer(node);
        TaintLattice *in1 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(SgExpr2Var(xop->get_operand())));
        inputs.push_back(in1);

    } else if (isSgAggregateInitializer(node)) {
        // as in "int a[1] = {b}"
        SgAggregateInitializer *xop = isSgAggregateInitializer(node);
        const SgExpressionPtrList &exprs = xop->get_initializers()->get_expressions();
        for (size_t i=0; i<exprs.size(); ++i) {
            varID in_id = SgExpr2Var(exprs[i]);
            TaintLattice *in = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in_id));
            inputs.push_back(in);
        }

    } else if (isSgInitializedName(node)) {
        SgInitializedName *xop = isSgInitializedName(node);
        if (xop->get_initializer()) {
            varID in1_id = SgExpr2Var(xop->get_initializer());
            TaintLattice *in1 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in1_id));
            inputs.push_back(in1);
        }

    } else if (isSgValueExp(node)) {
        // numeric and character constants
        SgValueExp *xop = isSgValueExp(node);
        result = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(SgExpr2Var(xop)));
        if (result)
            modified = result->set_vertex(TaintLattice::VERTEX_UNTAINTED);
        
    } else if (isSgAddressOfOp(node)) {
        // as in "&x".  The result taintedness has nothing to do with the value in x.
        /*void*/

    } else if (isSgBinaryOp(node)) {
        // as in "a + b"
        SgBinaryOp *xop = isSgBinaryOp(node);
        varID in1_id = SgExpr2Var(isSgExpression(xop->get_lhs_operand()));
        TaintLattice *in1 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in1_id));
        inputs.push_back(in1);
        varID in2_id = SgExpr2Var(isSgExpression(xop->get_rhs_operand()));
        TaintLattice *in2 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in2_id));
        inputs.push_back(in2);

        if (isSgAssignOp(node)) { // copy the rhs lattice to the lhs lattice (as well as the entire '=' expression result)
            assert(in1 && in2);
            modified = in1->meetUpdate(in2);
        }

    } else if (isSgUnaryOp(node)) {
        // as in "-a"
        SgUnaryOp *xop = isSgUnaryOp(node);
        varID in1_id = SgExpr2Var(xop->get_operand());
        TaintLattice *in1 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in1_id));
        inputs.push_back(in1);

    } else if (isSgReturnStmt(node)) {
        // as in "return a".  The result will always be dead, so we're just doing this to get some debugging output.  Most
        // of our test inputs are functions, and the test examines the function's returned taintedness.
        SgReturnStmt *xop = isSgReturnStmt(node);
        varID in1_id = SgExpr2Var(xop->get_expression());
        TaintLattice *in1 = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(in1_id));
        inputs.push_back(in1);

    }


    // Update the result lattice (unless dead) with the inputs (unless dead) by using the meedUpdate() method.  All this
    // means is that the new result will be the maximum of the old result and all inputs, where "maximum" is defined such
    // that "tainted" is greater than "untainted" (and both of them are greater than bottom/unknown).
    for (size_t i=0; i<inputs.size(); ++i)
        if (debug)
            *debug <<"TaintAnalysis::transfer: input " <<(i+1) <<" is " <<lattice_info(inputs[i]) <<"\n";
    if (!result && varID::isValidVarExp(node)) {
        varID result_id(node); // NOTE: constructor doesn't handle all SgExpression nodes, thus the next "if"
        result = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(result_id));
    }
    if (!result && isSgExpression(node)) {
        varID result_id = SgExpr2Var(isSgExpression(node));
        result = dynamic_cast<TaintLattice*>(prodLat->getVarLattice(result_id));
    }
    if (result) {
        for (size_t i=0; i<inputs.size(); ++i) {
            if (inputs[i])
                modified = result->meetUpdate(inputs[i]) || modified;
        }
    }
    if (debug)
        *debug <<"TaintAnalysis::transfer: result is " <<lattice_info(result) <<(modified?" (modified)":" (not modified)") <<"\n";

    return modified;
}