/// Turn an llvm.expect call used as the condition of \p SI into branch-weight
/// profile metadata attached to the switch.
///
/// The switch is only rewritten when its condition is a direct call to the
/// expect intrinsic whose second argument is a constant integer. In that case
/// one weight is produced for the default destination followed by one weight
/// per case; the destination matching the expected constant receives
/// LikelyBranchWeight and every other destination UnlikelyBranchWeight.
///
/// \returns true if metadata was attached and the condition was replaced by
/// the intrinsic's first argument, false if \p SI was left untouched.
bool LowerExpectIntrinsic::HandleSwitchExpect(SwitchInst *SI) {
  // The condition has to be a call instruction...
  CallInst *ExpectCall = dyn_cast<CallInst>(SI->getCondition());
  if (!ExpectCall)
    return false;

  // ...to a known callee that is the expect intrinsic.
  Function *Callee = ExpectCall->getCalledFunction();
  if (!Callee || Callee->getIntrinsicID() != Intrinsic::expect)
    return false;

  // Without a constant expectation there is no case to look up.
  ConstantInt *Expected = dyn_cast<ConstantInt>(ExpectCall->getArgOperand(1));
  if (!Expected)
    return false;

  SwitchInst::CaseIt ExpectedCase = SI->findCaseValue(Expected);
  const unsigned NumCases = SI->getNumCases();

  // Layout: the default destination first, then the cases in index order.
  std::vector<uint32_t> Weights;
  Weights.reserve(NumCases + 1);

  const bool DefaultIsExpected = ExpectedCase == SI->case_default();
  Weights.push_back(DefaultIsExpected ? LikelyBranchWeight
                                      : UnlikelyBranchWeight);

  for (unsigned Idx = 0; Idx != NumCases; ++Idx) {
    const bool IsExpected = Idx == ExpectedCase.getCaseIndex();
    Weights.push_back(IsExpected ? LikelyBranchWeight : UnlikelyBranchWeight);
  }

  MDBuilder Builder(ExpectCall->getContext());
  SI->setMetadata(LLVMContext::MD_prof, Builder.createBranchWeights(Weights));

  // Switch directly on the value that was wrapped by the intrinsic. The call
  // instruction itself is not erased here.
  SI->setCondition(ExpectCall->getArgOperand(0));
  return true;
}
/// Lower an llvm.expect call that feeds the condition of \p SI into
/// branch-weight profile metadata on the switch.
///
/// Nothing is changed unless the switch condition is a direct call to the
/// expect intrinsic and the intrinsic's second argument is a constant integer.
/// When it is, a weight list is built with the default destination first and
/// the cases after it in index order: the destination selected by the expected
/// constant gets LikelyBranchWeight, all others get UnlikelyBranchWeight.
///
/// \returns true if the switch was annotated and its condition replaced by the
/// intrinsic's first argument, false otherwise.
bool LowerExpectIntrinsic::HandleSwitchExpect(SwitchInst *SI) {
  CallInst *CI = dyn_cast<CallInst>(SI->getCondition());
  if (!CI)
    return false;

  Function *Fn = CI->getCalledFunction();
  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
    return false;

  Value *ArgValue = CI->getArgOperand(0);
  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
  if (!ExpectedValue)
    return false;

  SwitchInst::CaseIt Case = SI->findCaseValue(ExpectedValue);
  unsigned n = SI->getNumCases();
  std::vector<uint32_t> Weights(n + 1); // +1 for the default case.

  // Slot 0 belongs to the default destination.
  Weights[0] = Case == SI->case_default() ? LikelyBranchWeight
                                          : UnlikelyBranchWeight;
  // Slots 1..n belong to the cases, in case-index order.
  for (unsigned i = 0; i != n; ++i)
    Weights[i + 1] = i == Case.getCaseIndex() ? LikelyBranchWeight
                                              : UnlikelyBranchWeight;

  // Let MDBuilder assemble the "branch_weights" node instead of spelling out
  // the MDString / i32 operands by hand.
  SI->setMetadata(LLVMContext::MD_prof,
                  MDBuilder(CI->getContext()).createBranchWeights(Weights));

  // Switch on the wrapped value directly; the call is not erased here.
  SI->setCondition(ArgValue);
  return true;
}