// Rewrites the loads and stores in |func| whose pointer operand is a
// non-pointer access chain on a targeted variable. FindTargetVars() is run
// first to settle which variables are targets. Returns true if at least one
// load or store was rewritten.
bool LocalAccessChainConvertPass::ConvertLocalAccessChains(Function* func) {
  FindTargetVars(func);
  // Replace access chains of all targeted variables with equivalent
  // extract and insert sequences
  bool modified = false;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    // Stores replaced in this block. They are only queued here; they are
    // handed to DCEInst after the walk over the block has finished.
    std::vector<Instruction*> dead_instructions;
    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
      switch (ii->opcode()) {
        case SpvOpLoad: {
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          // The load itself is passed to ReplaceAccessChainLoad; nothing is
          // queued for deletion and the iterator is left where it is.
          ReplaceAccessChainLoad(ptrInst, &*ii);
          modified = true;
        } break;
        case SpvOpStore: {
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          std::vector<std::unique_ptr<Instruction>> newInsts;
          uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
          GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
          dead_instructions.push_back(&*ii);
          // Step past the original store and splice the replacement in front
          // of the instruction that follows it.
          ++ii;
          ii = ii.InsertBefore(std::move(newInsts));
          // NOTE(review): presumably InsertBefore returns the first inserted
          // instruction and the replacement is three instructions long, so
          // these two steps plus the loop's own ++ii resume right after the
          // replacement -- confirm against GenAccessChainStoreReplacement.
          ++ii;
          ++ii;
          modified = true;
        } break;
        default:
          break;
      }
    }

    // Drain the queue. The callback removes any instruction DCEInst reports
    // from the pending list (presumably the ones it deletes along the way --
    // confirm against DCEInst) so it is not handed to DCEInst a second time.
    while (!dead_instructions.empty()) {
      Instruction* inst = dead_instructions.back();
      dead_instructions.pop_back();
      DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
        auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
                           other_inst);
        if (i != dead_instructions.end()) {
          dead_instructions.erase(i);
        }
      });
    }
  }
  return modified;
}
// Walks every instruction of |func| and, for each load or store whose pointer
// operand is a non-pointer access chain on a targeted variable, generates a
// replacement sequence and splices it in after the original instruction.
// FindTargetVars() is run first. Returns true if anything was replaced.
bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) {
  FindTargetVars(func);
  bool changed = false;
  for (auto blk = func->begin(); blk != func->end(); ++blk) {
    for (auto inst = blk->begin(); inst != blk->end(); ++inst) {
      const SpvOp opcode = inst->opcode();
      const bool is_load = (opcode == SpvOpLoad);
      if (!is_load && opcode != SpvOpStore) continue;

      // Only access chains rooted at a targeted variable are converted.
      uint32_t var_id;
      ir::Instruction* chain = GetPtr(&*inst, &var_id);
      if (!IsNonPtrAccessChain(chain->opcode()) || !IsTargetVar(var_id))
        continue;

      std::vector<std::unique_ptr<ir::Instruction>> replacement;
      if (is_load) {
        // Generate the replacement, then redirect the load to its result id.
        const uint32_t repl_id =
            GenAccessChainLoadReplacement(chain, &replacement);
        ReplaceAndDeleteLoad(&*inst, repl_id);
      } else {
        // Generate the replacement for the stored value, then retire the
        // original store and its access chain.
        const uint32_t val_id = inst->GetSingleWordInOperand(kStoreValIdInIdx);
        GenAccessChainStoreReplacement(chain, val_id, &replacement);
        def_use_mgr_->KillInst(&*inst);
        DeleteIfUseless(chain);
      }

      // Step past the original instruction, splice the replacement in front
      // of its successor, and advance into the inserted sequence.
      ++inst;
      inst = inst.InsertBefore(&replacement);
      ++inst;
      // A store advances one instruction further than a load; presumably its
      // replacement sequence is one instruction longer -- TODO confirm.
      if (!is_load) ++inst;
      changed = true;
    }
  }
  return changed;
}
// Scans every load and store in |func|. A variable that passes IsTargetVar()
// but is referenced in a way this pass does not handle is moved from
// |seen_target_vars_| to |seen_non_target_vars_|.
void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
  for (auto blk = func->begin(); blk != func->end(); ++blk) {
    for (auto inst = blk->begin(); inst != blk->end(); ++inst) {
      const SpvOp opcode = inst->opcode();
      if (opcode != SpvOpStore && opcode != SpvOpLoad) continue;

      uint32_t var_id;
      ir::Instruction* ptr_inst = GetPtr(&*inst, &var_id);
      if (!IsTargetVar(var_id)) continue;

      const SpvOp ptr_op = ptr_inst->opcode();
      // The checks run in this order and stop at the first one that fails.
      const bool ruled_out =
          // Variables with non-supported refs, e.g. function calls.
          !HasOnlySupportedRefs(var_id) ||
          // Variables with nested access chains.
          // TODO(): Convert nested access chains
          (IsNonPtrAccessChain(ptr_op) &&
           ptr_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) !=
               var_id) ||
          // Variables accessed with non-constant indices.
          !IsConstantIndexAccessChain(ptr_inst);
      if (!ruled_out) continue;

      seen_non_target_vars_.insert(var_id);
      seen_target_vars_.erase(var_id);
    }
  }
}
// Performs store/load, load/load and store/store elimination inside each
// basic block of |func|, for variables accepted by IsTargetVar() and
// HasOnlySupportedRefs(). Instructions found to be redundant are collected
// and killed only after every block has been scanned. Returns true if any
// instruction was marked for removal.
bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
    Function* func) {
  // Perform local store/load, load/load and store/store elimination
  // on each block
  bool modified = false;
  // Instructions to kill once the scan of the whole function is done.
  std::vector<Instruction*> instructions_to_kill;
  // Stores that a later partial load depends on; a following whole-variable
  // store must not mark these for deletion.
  std::unordered_set<Instruction*> instructions_to_save;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    // The store/load tracking is per block: start each block with no
    // remembered stores or loads.
    var2store_.clear();
    var2load_.clear();
    // |next| is advanced before the current instruction is examined and the
    // loop steps with |ii = next|, so the |continue| statements inside the
    // switch (which apply to this for loop) still move to the successor.
    auto next = bi->begin();
    for (auto ii = next; ii != bi->end(); ii = next) {
      ++next;
      switch (ii->opcode()) {
        case SpvOpStore: {
          // Verify store variable is target type
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsTargetVar(varId)) continue;
          if (!HasOnlySupportedRefs(varId)) continue;
          // If a store to the whole variable, remember it for succeeding
          // loads and stores. Otherwise forget any previous store to that
          // variable.
          if (ptrInst->opcode() == SpvOpVariable) {
            // If a previous store to same variable, mark the store
            // for deletion if not still used.
            auto prev_store = var2store_.find(varId);
            if (prev_store != var2store_.end() &&
                instructions_to_save.count(prev_store->second) == 0) {
              instructions_to_kill.push_back(prev_store->second);
              modified = true;
            }

            // Check whether this store writes back the result of the load
            // currently remembered for the same variable.
            bool kill_store = false;
            auto li = var2load_.find(varId);
            if (li != var2load_.end()) {
              if (ii->GetSingleWordInOperand(kStoreValIdInIdx) ==
                  li->second->result_id()) {
                // We are storing the same value that already exists in the
                // memory location.  The store does nothing.
                kill_store = true;
              }
            }

            if (!kill_store) {
              // This store becomes the remembered store for the variable and
              // any remembered load of it is forgotten.
              var2store_[varId] = &*ii;
              var2load_.erase(varId);
            } else {
              instructions_to_kill.push_back(&*ii);
              modified = true;
            }
          } else {
            // Store through an access chain: forget the remembered store and
            // load for this variable.
            assert(IsNonPtrAccessChain(ptrInst->opcode()));
            var2store_.erase(varId);
            var2load_.erase(varId);
          }
        } break;
        case SpvOpLoad: {
          // Verify load variable is target type
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsTargetVar(varId)) continue;
          if (!HasOnlySupportedRefs(varId)) continue;
          // Id to substitute for this load's result; 0 means no substitute
          // was found.
          uint32_t replId = 0;
          if (ptrInst->opcode() == SpvOpVariable) {
            // If a load from a variable, look for a previous store or
            // load from that variable and use its value.
            auto si = var2store_.find(varId);
            if (si != var2store_.end()) {
              replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx);
            } else {
              auto li = var2load_.find(varId);
              if (li != var2load_.end()) {
                replId = li->second->result_id();
              }
            }
          } else {
            // If a partial load of a previously seen store, remember
            // not to delete the store.
            auto si = var2store_.find(varId);
            if (si != var2store_.end()) instructions_to_save.insert(si->second);
          }
          if (replId != 0) {
            // replace load's result id and delete load
            context()->KillNamesAndDecorates(&*ii);
            context()->ReplaceAllUsesWith(ii->result_id(), replId);
            instructions_to_kill.push_back(&*ii);
            modified = true;
          } else {
            // No substitute: a whole-variable load becomes the remembered
            // load for the variable.
            if (ptrInst->opcode() == SpvOpVariable)
              var2load_[varId] = &*ii;  // register load
          }
        } break;
        case SpvOpFunctionCall: {
          // Conservatively assume all locals are redefined for now.
          // TODO(): Handle more optimally
          var2store_.clear();
          var2load_.clear();
        } break;
        default:
          break;
      }
    }
  }

  // Kill everything that was marked, now that no block is being iterated.
  for (Instruction* inst : instructions_to_kill) {
    context()->KillInst(inst);
  }

  return modified;
}