Example #1
0
/// Lower an IR return into PTX DAG nodes (calling-convention driven variant).
///
/// Kernel functions (PTX_Kernel) must return void and are terminated with a
/// PTXISD::EXIT node. For device functions (PTX_Device) the RetCC_PTX
/// calling convention assigns a register to every returned value. Each value
/// is then:
///  - copied into that register,
///  - marked live-out,
///  - recorded with the function info's addRetReg.
/// The function is finally terminated with a PTXISD::RET node, which is glued
/// to the last copy when at least one copy was emitted. No limit on the
/// number of returned values is enforced here.
///
/// @param Chain    Incoming token chain.
/// @param CallConv Calling convention of the function being lowered.
/// @param isVarArg Must be false; varargs are rejected as unreachable.
/// @param Outs     Output-argument descriptors analyzed by RetCC_PTX.
/// @param OutVals  The returned values, indexed like the assigned locations.
/// @param dl       Debug location attached to every created node.
/// @param DAG      The selection DAG being built.
/// @return The terminating EXIT (kernels) or RET (device functions) node.
SDValue PTXTargetLowering::LowerReturn(SDValue Chain,
                                       CallingConv::ID CallConv,
                                       bool isVarArg,
                                       const SmallVectorImpl<ISD::OutputArg> &Outs,
                                       const SmallVectorImpl<SDValue> &OutVals,
                                       DebugLoc dl,
                                       SelectionDAG &DAG) const {
  if (isVarArg)
    llvm_unreachable("PTX does not support varargs");

  // Kernels return nothing; they simply exit.
  if (CallConv == CallingConv::PTX_Kernel) {
    assert(Outs.size() == 0 && "Kernel must return void.");
    return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
  }

  // Besides kernels, only device functions are handled.
  if (CallConv != CallingConv::PTX_Device)
    llvm_unreachable("Unsupported calling convention.");

  MachineFunction &MF = DAG.getMachineFunction();
  PTXMachineFunctionInfo *FuncInfo = MF.getInfo<PTXMachineFunctionInfo>();

  // Let the return calling convention decide where each value goes.
  SmallVector<CCValAssign, 16> Locs;
  CCState RetState(CallConv, isVarArg, MF, getTargetMachine(), Locs,
                   *DAG.getContext());
  RetState.AnalyzeReturn(Outs, RetCC_PTX);

  // Each copy is glued to the previous one so the copies stay stuck
  // together up to the final RET.
  SDValue Glue;
  for (unsigned Idx = 0, NumLocs = Locs.size(); Idx != NumLocs; ++Idx) {
    CCValAssign &Loc = Locs[Idx];
    assert(Loc.isRegLoc() && "CCValAssign must be RegLoc");

    unsigned RetReg = Loc.getLocReg();
    MF.getRegInfo().addLiveOut(RetReg);

    Chain = DAG.getCopyToReg(Chain, dl, RetReg, OutVals[Idx], Glue);
    Glue = Chain.getValue(1);

    FuncInfo->addRetReg(RetReg);
  }

  // With no copies there is nothing to glue the RET to.
  if (!Glue.getNode())
    return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
  return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Glue);
}
/// Lower an IR return into PTX DAG nodes (parameter-manager variant).
///
/// Kernel functions (PTX_Kernel) must return void and are terminated with a
/// PTXISD::EXIT node. Device functions (PTX_Device) may return at most one
/// value. How that value is handed back depends on the subtarget:
///  - Param space (ST.useParamSpaceForDeviceArgs()): a return parameter whose
///    size is the value's bit width is registered with the PTXParamManager.
///    The value is written to that parameter through PTXISD::STORE_PARAM.
///  - Registers (otherwise): each returned value is copied into a fresh
///    virtual register of the class matching its type. The register is
///    exposed with PTXISD::WRITE_PARAM and recorded with addRetReg.
/// The function is then terminated with an unglued PTXISD::RET node.
///
/// @param Chain    Incoming token chain.
/// @param CallConv Calling convention of the function being lowered.
/// @param isVarArg Must be false; varargs are rejected as unreachable.
/// @param Outs     Output-argument descriptors (their VT selects the register
///                 class in the register path).
/// @param OutVals  The returned values, parallel to Outs.
/// @param dl       Debug location attached to every created node.
/// @param DAG      The selection DAG being built.
/// @return The terminating EXIT (kernels) or RET (device functions) node.
SDValue PTXTargetLowering::
  LowerReturn(SDValue Chain,
              CallingConv::ID CallConv,
              bool isVarArg,
              const SmallVectorImpl<ISD::OutputArg> &Outs,
              const SmallVectorImpl<SDValue> &OutVals,
              DebugLoc dl,
              SelectionDAG &DAG) const {
  if (isVarArg) llvm_unreachable("PTX does not support varargs");

  switch (CallConv) {
    default:
      llvm_unreachable("Unsupported calling convention.");
    case CallingConv::PTX_Kernel:
      assert(Outs.size() == 0 && "Kernel must return void.");
      return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
    case CallingConv::PTX_Device:
      // Only PTX_Device reaches the code below, so this one check covers
      // both the param-space and the register lowering paths.
      assert(Outs.size() <= 1 && "Can at most return one value.");
      break;
  }

  MachineFunction& MF = DAG.getMachineFunction();
  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
  PTXParamManager &PM = MFI->getParamManager();
  const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();

  if (ST.useParamSpaceForDeviceArgs()) {
    if (Outs.size() == 1) {
      // Allocate a return parameter as wide as the value and store into it.
      unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
      unsigned Param = PM.addReturnParam(ParamSize);
      const std::string &ParamName = PM.getParamName(Param);
      SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
                                                       MVT::Other);
      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
                          ParamValue, OutVals[0]);
    }
  } else {
    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
      EVT                  RegVT = Outs[i].VT;
      TargetRegisterClass* TRC = 0;

      // Determine which register class we need
      if (RegVT == MVT::i1) {
        TRC = PTX::RegPredRegisterClass;
      } else if (RegVT == MVT::i16) {
        TRC = PTX::RegI16RegisterClass;
      } else if (RegVT == MVT::i32) {
        TRC = PTX::RegI32RegisterClass;
      } else if (RegVT == MVT::i64) {
        TRC = PTX::RegI64RegisterClass;
      } else if (RegVT == MVT::f32) {
        TRC = PTX::RegF32RegisterClass;
      } else if (RegVT == MVT::f64) {
        TRC = PTX::RegF64RegisterClass;
      } else {
        llvm_unreachable("Unknown parameter type");
      }

      unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);

      // The copy is ordered only through the token chain; it is not glued.
      SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]);
      SDValue OutReg = DAG.getRegister(Reg, RegVT);

      Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);

      MFI->addRetReg(Reg);
    }
  }

  // Neither path produces glue, so RET depends on the chain alone.
  return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
}