Ejemplo n.º 1
0
void CActorFromActionValue::receiveError(double critic, CStateCollection *oldState, CAction *action, CActionData *)
{
	// Look up the executed action's position in the actor's action set.
	const bool isFirstAction = (actions->getIndex(action) == 0);

	// Update the existing eligibility traces using the action's duration.
	eTraces->updateETraces(action->getDuration());

	// Add a trace for the previous state with weight +0.5 when the action has
	// index 0 and -0.5 for any other index.
	const double traceWeight = isFirstAction ? 0.5 : -0.5;
	eTraces->addETrace(oldState, traceWeight);

	// Propagate the critic, scaled by the learning rate, into the value function.
	eTraces->updateVFunction(critic * getLearningRate());
}
Ejemplo n.º 2
0
void CActorFromQFunction::receiveError(double critic, CStateCollection *state, CAction *action, CActionData *)
{
	// Step 1: update the existing traces with the executed action, then add a
	// trace for the (state, action) pair that was just taken.
	DebugPrint('t',"Actor updating Etraces \n");
	eTraces->updateETraces(action);
	eTraces->addETrace(state, action);

	// Step 2: push the critic, scaled by the actor's learning rate, through the
	// traces into the Q-function.
	DebugPrint('t', "Actor updating QFunction with critic %f \n", critic);
	const double scaledCritic = critic * getLearningRate();
	eTraces->updateQFunction(scaledCritic);
}
Ejemplo n.º 3
0
void CActorFromQFunctionAndPolicy::receiveError(double critic, CStateCollection *state, CAction *Action, CActionData *)
{
	// Fill actionValues with the policy's probabilities over the Q-function's
	// action set, then pick out the entry for the executed action.
	policy->getActionProbabilities(state, qFunction->getActions(), actionValues);
	const int executedIndex = qFunction->getActions()->getIndex(Action);
	const double executedProbability = actionValues[executedIndex];

	eTraces->updateETraces(Action);

	// Trace weight is (PolicyMinimumLearningRate + 1) - p(action): actions the
	// policy already favours receive a smaller trace.
	const double traceWeight = getParameter("PolicyMinimumLearningRate") + 1.0 - executedProbability;
	eTraces->addETrace(state, Action, traceWeight);

	eTraces->updateQFunction(critic * getLearningRate());
}
Ejemplo n.º 4
0
void AdaDeltaSolver<Dtype>::applyUpdate(){
	CHECK(Dragon::get_root_solver());
	// AdaDelta does not need the base learning rate; the value returned by
	// getLearningRate() is still forwarded to computeUpdateValue() below.
	Dtype rate = getLearningRate();
	// Periodic progress log every param.display() iterations (0 disables it).
	if (param.display() && iter % param.display() == 0)
		cout << "Iteration " << iter << ", lr = AdaDelta" << endl;
	clipGradients();
	// Bind by const reference rather than copying the parameter list on every
	// update. If getLearnableParams() returns by value, the temporary's lifetime
	// is extended to this scope, so this is safe either way.
	const vector<Blob<Dtype>*>& net_params = net->getLearnableParams();
	// Cache the count as int so the loop avoids a signed/unsigned comparison
	// while still passing an int index to the per-parameter helpers.
	const int num_params = static_cast<int>(net_params.size());
	for (int i = 0; i < num_params; ++i){
		// Per-parameter pipeline: normalize -> regularize -> compute update -> apply.
		normalize(i);
		regularize(i);
		computeUpdateValue(i, rate);
		net_params[i]->update();
	}
}
Ejemplo n.º 5
0
// Server-side RPC dispatcher for the cvzMmcm_IDL interface.
// NOTE(review): the WireReader/WireWriter usage, the "__direct__" handling and
// the help() branch look like YARP IDL (thrift) generated code; if so, change
// the .thrift definition and regenerate rather than hand-editing this file.
//
// Reads one command from `connection` and matches its tag against the known
// method names. It then reads that method's arguments, invokes the method on
// this object and, unless the WireWriter is null, writes the result back as a
// list whose header count equals the number of returned values (0 for void
// methods).
//
// Returns true once a command has been dispatched and its reply written.
// Returns false if the list header or an argument cannot be read, if writing
// the reply fails, or if no known command matches the tag(s).
bool cvzMmcm_IDL::read(yarp::os::ConnectionReader& connection) {
  yarp::os::idl::WireReader reader(connection);
  reader.expectAccept();
  if (!reader.readListHeader()) { reader.fail(); return false; }
  yarp::os::ConstString tag = reader.readTag();
  // An optional leading "__direct__" tag is skipped so the command tag follows
  // it. `direct` is not otherwise used in this function.
  bool direct = (tag=="__direct__");
  if (direct) tag = reader.readTag();
  while (!reader.isError()) {
    // TODO: use quick lookup, this is just a test
    // void start() -- no arguments, empty reply list.
    if (tag == "start") {
      start();
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(0)) return false;
      }
      reader.accept();
      return true;
    }
    // void pause() -- no arguments, empty reply list.
    if (tag == "pause") {
      pause();
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(0)) return false;
      }
      reader.accept();
      return true;
    }
    // bool quit() -- no arguments, replies with one bool.
    if (tag == "quit") {
      bool _return;
      _return = quit();
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeBool(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // void setLearningRate(double l) -- fails the read if l is missing/invalid.
    if (tag == "setLearningRate") {
      double l;
      if (!reader.readDouble(l)) {
        reader.fail();
        return false;
      }
      setLearningRate(l);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(0)) return false;
      }
      reader.accept();
      return true;
    }
    // double getLearningRate() -- replies with one double.
    if (tag == "getLearningRate") {
      double _return;
      _return = getLearningRate();
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeDouble(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // void setSigma(double s) -- fails the read if s is missing/invalid.
    if (tag == "setSigma") {
      double s;
      if (!reader.readDouble(s)) {
        reader.fail();
        return false;
      }
      setSigma(s);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(0)) return false;
      }
      reader.accept();
      return true;
    }
    // double getSigma() -- replies with one double.
    if (tag == "getSigma") {
      double _return;
      _return = getSigma();
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeDouble(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // double getActivity(i32 x, i32 y, i32 z) -- arguments are read in x, y, z
    // order; replies with one double.
    if (tag == "getActivity") {
      int32_t x;
      int32_t y;
      int32_t z;
      if (!reader.readI32(x)) {
        reader.fail();
        return false;
      }
      if (!reader.readI32(y)) {
        reader.fail();
        return false;
      }
      if (!reader.readI32(z)) {
        reader.fail();
        return false;
      }
      double _return;
      _return = getActivity(x,y,z);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeDouble(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // bool saveWeightsToFile(string path) -- replies with one bool.
    if (tag == "saveWeightsToFile") {
      std::string path;
      if (!reader.readString(path)) {
        reader.fail();
        return false;
      }
      bool _return;
      _return = saveWeightsToFile(path);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeBool(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // bool loadWeightsFromFile(string path) -- replies with one bool.
    if (tag == "loadWeightsFromFile") {
      std::string path;
      if (!reader.readString(path)) {
        reader.fail();
        return false;
      }
      bool _return;
      _return = loadWeightsFromFile(path);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeBool(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // bool saveRF(string path) -- replies with one bool.
    if (tag == "saveRF") {
      std::string path;
      if (!reader.readString(path)) {
        reader.fail();
        return false;
      }
      bool _return;
      _return = saveRF(path);
      yarp::os::idl::WireWriter writer(reader);
      if (!writer.isNull()) {
        if (!writer.writeListHeader(1)) return false;
        if (!writer.writeBool(_return)) return false;
      }
      reader.accept();
      return true;
    }
    // help([functionName]) -- the argument is optional and defaults to "--all".
    // The reply is a 2-element list: the "many" tag followed by a list of strings.
    if (tag == "help") {
      std::string functionName;
      if (!reader.readString(functionName)) {
        functionName = "--all";
      }
      std::vector<std::string> _return=help(functionName);
      yarp::os::idl::WireWriter writer(reader);
        if (!writer.isNull()) {
          if (!writer.writeListHeader(2)) return false;
          if (!writer.writeTag("many",1, 0)) return false;
          // NOTE(review): the list is declared with BOTTLE_TAG_INT although its
          // elements are written with writeString(); BOTTLE_TAG_STRING would
          // match the payload. Confirm whether the tag is significant for the
          // YARP version in use before changing (and fix it in the generator).
          if (!writer.writeListBegin(BOTTLE_TAG_INT, static_cast<uint32_t>(_return.size()))) return false;
          std::vector<std::string> ::iterator _iterHelp;
          for (_iterHelp = _return.begin(); _iterHelp != _return.end(); ++_iterHelp)
          {
            if (!writer.writeString(*_iterHelp)) return false;
           }
          if (!writer.writeListEnd()) return false;
        }
      reader.accept();
      return true;
    }
    // No known command matched the tag so far. Consume the next word and join it
    // with "_" (e.g. "a b" -> "a_b") so multi-word commands can match on the
    // next iteration. Running out of input or reading an empty tag means the
    // command is unknown.
    if (reader.noMore()) { reader.fail(); return false; }
    yarp::os::ConstString next_tag = reader.readTag();
    if (next_tag=="") break;
    tag = tag + "_" + next_tag;
  }
  return false;
}