// Orders two guidance actions by comparing their `type` members.
// Throws DowncastException when `other_base` is not a SimpleGuidanceAction.
bool SimpleGuidanceAction::operator<(const Action& other_base) const {
   // The pointer form of dynamic_cast yields null on a type mismatch, so the
   // failure can be tested directly instead of catching std::bad_cast.
   const SimpleGuidanceAction* rhs = dynamic_cast<const SimpleGuidanceAction*>(&other_base);
   if (!rhs) {
     throw DowncastException("Action", "SimpleGuidanceAction");
   }
   return type < rhs->type;
 }
示例#2
0
 // Orders two light-world actions by comparing their `type` members.
 // Throws DowncastException when `other_base` is not a LightWorldAction.
 bool LightWorldAction::operator<(const Action& other_base) const {
   // A null result from the pointer-form dynamic_cast signals a type mismatch;
   // it is translated into the project's DowncastException.
   const LightWorldAction* rhs = dynamic_cast<const LightWorldAction*>(&other_base);
   if (!rhs) {
     throw DowncastException("Action", "LightWorldAction");
   }
   return type < rhs->type;
 }
 // Returns true when the state sits on cell (grid_size/2, grid_size/2).
 // Throws DowncastException when `state_base` is not a SimpleGuidanceState.
 bool SimpleGuidanceModel::isTerminalState(const State::ConstPtr& state_base) const {
   const boost::shared_ptr<const SimpleGuidanceState> grid_state =
     boost::dynamic_pointer_cast<const SimpleGuidanceState>(state_base);
   if (!grid_state) {
     throw DowncastException("State", "SimpleGuidanceState");
   }
   // Both coordinates must match the goal cell; bail out on the first mismatch.
   if (grid_state->x != params_.grid_size/2) {
     return false;
   }
   return grid_state->y == params_.grid_size/2;
 }
示例#4
0
 // Returns true when the agent stands on (goal_x, goal_y) and the goal has
 // been unlocked. Throws DowncastException when `state_base` is not a
 // LightWorldState.
 bool LightWorldModel::isTerminalState(const State::ConstPtr& state_base) const {
   const boost::shared_ptr<const LightWorldState> lw_state =
     boost::dynamic_pointer_cast<const LightWorldState>(state_base);
   if (!lw_state) {
     throw DowncastException("State", "LightWorldState");
   }
   // Note: an additional terminal condition (goal still locked and
   // unlock_attempts_left == 0) was commented out in the original code and
   // remains disabled here.
   if (lw_state->x != params_.goal_x) {
     return false;
   }
   if (lw_state->y != params_.goal_y) {
     return false;
   }
   return lw_state->goal_unlocked;
 }
示例#5
0
 // Fills `actions` with every action available at `state_base`.
 //
 // `actions` is cleared first. For a non-terminal state the list holds the
 // navigation actions that keep the agent inside the grid, followed by either
 // the pickup action (key not yet picked up) or the unlock action (key held,
 // goal still locked, attempts remaining).
 //
 // Throws DowncastException when `state_base` is not a LightWorldState, and
 // std::runtime_error when no action was produced. Nothing is added for a
 // terminal state, so calling this on a terminal state raises that error.
 void LightWorldModel::getActionsAtState(const State::ConstPtr& state_base,
                                         std::vector<Action::ConstPtr>& actions) const {
   actions.clear();
   boost::shared_ptr<const LightWorldState> state = boost::dynamic_pointer_cast<const LightWorldState>(state_base);
   if (!state) {
     throw DowncastException("State", "LightWorldState");
   }
   if (!isTerminalState(state)) {
     // Navigation: only offer moves that stay within [0, grid_size - 1].
     if (state->x > 0) {
       actions.push_back(left_action);
     }
     if (state->x < params_.grid_size - 1) {
       actions.push_back(right_action);
     }
     if (state->y > 0) {
       actions.push_back(down_action);
     }
     if (state->y < params_.grid_size - 1) {
       actions.push_back(up_action);
     }
     if (!state->key_picked_up) {
       actions.push_back(pickup_action);
     } else if (!state->goal_unlocked && state->unlock_attempts_left > 0) {
       // Can only unlock when you have the key.
       actions.push_back(unlock_action);
     }
   }
   if (actions.empty()) {
     std::stringstream ss;
     ss << *state;
     throw std::runtime_error("Bug! Found 0 actions at state " + ss.str());
   }
 }
  // Fills `actions` with every action available at `state_base`.
  //
  // `actions` is cleared first. For a non-terminal state located on one of
  // the (x, y) entries of `roundabouts_`, the four directional actions are
  // offered; anywhere else only the no-op action is offered.
  //
  // Throws DowncastException when `state_base` is not a SimpleGuidanceState,
  // and std::runtime_error when no action was produced. Nothing is added for
  // a terminal state, so calling this on a terminal state raises that error.
  void SimpleGuidanceModel::getActionsAtState(const State::ConstPtr& state_base,
                                              std::vector<Action::ConstPtr>& actions) const {
    actions.clear();
    boost::shared_ptr<const SimpleGuidanceState> state = boost::dynamic_pointer_cast<const SimpleGuidanceState>(state_base);
    if (!state) {
      throw DowncastException("State", "SimpleGuidanceState");
    }
    if (!isTerminalState(state)) {

      // Check if we're at a roundabout.
      bool at_roundabout = false;
      // Convert the size once so the loop compares int against int.
      const int num_roundabouts = static_cast<int>(roundabouts_.size());
      for (int roundabout_idx = 0; roundabout_idx < num_roundabouts; ++roundabout_idx) {
        const std::pair<int, int>& roundabout = roundabouts_[roundabout_idx];
        if ((state->x == roundabout.first) && (state->y == roundabout.second)) {
          at_roundabout = true;
          break;
        }
      }

      if (!at_roundabout) {
        actions.push_back(noop_action);
      } else {
        actions.push_back(left_action);
        actions.push_back(right_action);
        actions.push_back(up_action);
        actions.push_back(down_action);
      }

    }
    if (actions.empty()) {
      std::stringstream ss;
      ss << *state;
      throw std::runtime_error("Bug! Found 0 actions at state " + ss.str());
    }
  }
示例#7
0
  // Lexicographic ordering on (x, y).
  // Throws DowncastException when `other_base` is not a GridState.
  bool GridState::operator<(const State& other_base) const {
    const GridState* rhs = dynamic_cast<const GridState*>(&other_base);
    if (!rhs) {
      throw DowncastException("State", "GridState");
    }

    // x is the primary key; y only decides when the x values do not differ.
    if (x < rhs->x) return true;
    if (rhs->x < x) return false;

    return y < rhs->y;
  }
  // Lexicographic ordering on (x, y, prev_action).
  // Throws DowncastException when `other_base` is not a SimpleGuidanceState.
  bool SimpleGuidanceState::operator<(const State& other_base) const {
    const SimpleGuidanceState* rhs = dynamic_cast<const SimpleGuidanceState*>(&other_base);
    if (!rhs) {
      throw DowncastException("State", "SimpleGuidanceState");
    }

    // Each key is consulted only when all earlier keys fail to separate the
    // two states.
    if (x < rhs->x) return true;
    if (rhs->x < x) return false;

    if (y < rhs->y) return true;
    if (rhs->y < y) return false;

    // Last key: the result is simply whether our prev_action sorts first.
    return prev_action < rhs->prev_action;
  }
示例#9
0
  // Lexicographic ordering on
  // (x, y, key_picked_up, goal_unlocked, unlock_attempts_left).
  // Throws DowncastException when `other_base` is not a LightWorldState.
  bool LightWorldState::operator<(const State& other_base) const {
    const LightWorldState* rhs = dynamic_cast<const LightWorldState*>(&other_base);
    if (!rhs) {
      throw DowncastException("State", "LightWorldState");
    }

    // Each key is consulted only when all earlier keys fail to separate the
    // two states.
    if (x < rhs->x) return true;
    if (rhs->x < x) return false;

    if (y < rhs->y) return true;
    if (rhs->y < y) return false;

    if (key_picked_up < rhs->key_picked_up) return true;
    if (rhs->key_picked_up < key_picked_up) return false;

    if (goal_unlocked < rhs->goal_unlocked) return true;
    if (rhs->goal_unlocked < goal_unlocked) return false;

    // Last key decides the result outright.
    return unlock_attempts_left < rhs->unlock_attempts_left;
  }
示例#10
0
  // Enumerates the possible outcomes of taking `action_base` in `state_base`.
  //
  // The three output vectors are cleared and then filled in parallel: entry i
  // of `next_states`, `rewards` and `probabilities` describes one outcome.
  // For a terminal state all three are left empty.
  //
  // PICKUP / UNLOCK produce exactly one outcome with probability 1. The base
  // reward is -1; a pickup away from (key_x, key_y) yields
  // incorrect_pickup_reward, and an unlock attempt away from (lock_x, lock_y)
  // yields incorrect_unlock_reward. Every unlock attempt made while attempts
  // remain consumes one attempt.
  //
  // Any other action is treated as navigation. Each move that stays inside
  // the grid yields one outcome, emitted in the order left, right, down, up.
  // The requested direction gets (1 - nondeterminism) plus an equal share of
  // `nondeterminism`; every other valid direction gets just that share.
  // Stepping into a terminal state adds 100 to the -1 step reward.
  //
  // Throws DowncastException when `state_base` is not a LightWorldState or
  // `action_base` is not a LightWorldAction.
  void LightWorldModel::getTransitionDynamics(const State::ConstPtr& state_base,
                                              const Action::ConstPtr& action_base,
                                              std::vector<State::ConstPtr>& next_states,
                                              std::vector<float>& rewards,
                                              std::vector<float>& probabilities) const {

    boost::shared_ptr<const LightWorldState> state = boost::dynamic_pointer_cast<const LightWorldState>(state_base);
    if (!state) {
      throw DowncastException("State", "LightWorldState");
    }

    boost::shared_ptr<const LightWorldAction> action = boost::dynamic_pointer_cast<const LightWorldAction>(action_base);
    if (!action) {
      throw DowncastException("Action", "LightWorldAction");
    }

    next_states.clear();
    rewards.clear();
    probabilities.clear();

    if (isTerminalState(state)) {
      // No transitions out of a terminal state.
      return;
    }

    if (action->type == PICKUP || action->type == UNLOCK) {
      LightWorldState next_state = *state;
      float reward = -1.0f;
      if (action->type == PICKUP) {
        if (state->x == params_.key_x && state->y == params_.key_y) {
          next_state.key_picked_up = true;
        } else {
          reward = params_.incorrect_pickup_reward;
        }
      } else if (state->unlock_attempts_left > 0) {
        // UNLOCK: every attempt costs one try, whether or not it succeeds.
        next_state.unlock_attempts_left -= 1;
        if (state->x == params_.lock_x &&
            state->y == params_.lock_y) {
          next_state.goal_unlocked = true;
        } else {
          reward = params_.incorrect_unlock_reward;
        }
      }
      rewards.push_back(reward);
      probabilities.push_back(1.0f);

      int idx = getStateIndex(next_state);
      next_states.push_back(complete_state_vector_[idx]);
      return;
    }

    // We're performing a navigation action. Describe the four candidate moves
    // once so the outcome logic below is written a single time.
    struct NavMove {
      bool valid;     // The move keeps the agent inside the grid.
      bool selected;  // The move is the one the caller asked for.
      int dx;
      int dy;
    };
    const NavMove moves[4] = {
      { state->x > 0,                     action->type == LEFT,  -1,  0 },
      { state->x < params_.grid_size - 1, action->type == RIGHT,  1,  0 },
      { state->y > 0,                     action->type == DOWN,   0, -1 },
      { state->y < params_.grid_size - 1, action->type == UP,     0,  1 }
    };

    int num_valid_nav_actions = 0;
    for (int i = 0; i < 4; ++i) {
      if (moves[i].valid) {
        ++num_valid_nav_actions;
      }
    }

    for (int i = 0; i < 4; ++i) {
      if (!moves[i].valid) {
        continue;
      }
      LightWorldState next_state = *state;
      next_state.x += moves[i].dx;
      next_state.y += moves[i].dy;
      float p = moves[i].selected ?
        ((1.0f - params_.nondeterminism) + (params_.nondeterminism / num_valid_nav_actions)) :
        (params_.nondeterminism / num_valid_nav_actions);
      int idx = getStateIndex(next_state);
      float reward = -1.0f;
      if (isTerminalState(complete_state_vector_[idx])) {
        reward += 100.0f;
      }
      next_states.push_back(complete_state_vector_[idx]);
      probabilities.push_back(p);
      rewards.push_back(reward);
    }
  }
示例#11
0
  // Enumerates the possible outcomes of taking `action_base` in `state_base`.
  //
  // The three output vectors are cleared and then filled in parallel: entry i
  // of `next_states`, `rewards` and `probabilities` describes one outcome.
  // For a terminal state all three are left empty.
  //
  // With a NOOP action, all probability mass goes to the direction stored in
  // the state's `prev_action`. With any other action, the requested direction
  // gets (1 - nondeterminism) + nondeterminism/4 and each other direction
  // gets nondeterminism/4. Mass assigned to a direction blocked by the grid
  // boundary is moved onto the perpendicular directions that are not blocked.
  // Every outcome with non-zero probability is emitted (order: up, down,
  // left, right) with reward -1, and the successor records the direction
  // moved as its `prev_action`.
  //
  // Throws DowncastException when `state_base` is not a SimpleGuidanceState
  // or `action_base` is not a SimpleGuidanceAction. In builds with assertions
  // enabled, aborts if a non-terminal state ends up with no outcomes.
  void SimpleGuidanceModel::getTransitionDynamics(const State::ConstPtr& state_base,
                                              const Action::ConstPtr& action_base,
                                              std::vector<State::ConstPtr>& next_states,
                                              std::vector<float>& rewards,
                                              std::vector<float>& probabilities) const {

    boost::shared_ptr<const SimpleGuidanceState> state = boost::dynamic_pointer_cast<const SimpleGuidanceState>(state_base);
    if (!state) {
      throw DowncastException("State", "SimpleGuidanceState");
    }

    boost::shared_ptr<const SimpleGuidanceAction> action = boost::dynamic_pointer_cast<const SimpleGuidanceAction>(action_base);
    if (!action) {
      throw DowncastException("Action", "SimpleGuidanceAction");
    }

    next_states.clear();
    rewards.clear();
    probabilities.clear();

    if (!isTerminalState(state)) {

      // Default (NOOP): keep moving in the previously recorded direction.
      float p_left = (state->prev_action.type == LEFT) ? 1.0f : 0.0f;
      float p_right = (state->prev_action.type == RIGHT) ? 1.0f : 0.0f;
      float p_up = (state->prev_action.type == UP) ? 1.0f : 0.0f;
      float p_down = (state->prev_action.type == DOWN) ? 1.0f : 0.0f;

      if (action->type != NOOP) {
        // An explicit direction overrides the default; the nondeterminism
        // mass is spread evenly over all four directions.
        float success_prob = ((1.0f - params_.nondeterminism) + (params_.nondeterminism / 4));
        float failure_prob = (params_.nondeterminism / 4);
        p_left = (action->type == LEFT) ? success_prob : failure_prob;
        p_right = (action->type == RIGHT) ? success_prob : failure_prob;
        p_up = (action->type == UP) ? success_prob : failure_prob;
        p_down = (action->type == DOWN) ? success_prob : failure_prob;
      }

      bool left_blocked = (state->x == 0);
      bool right_blocked = (state->x == params_.grid_size - 1);
      bool up_blocked = (state->y == params_.grid_size - 1);
      bool down_blocked = (state->y == 0);

      // Redistribute the mass of each blocked direction onto the
      // perpendicular directions. The four steps run in a fixed order
      // (left, right, up, down); later steps see the mass moved by earlier
      // ones, so the order must not be changed.
      if (left_blocked) {
        if (up_blocked) {
          p_down += p_left;
        } else if (down_blocked) {
          p_up += p_left;
        } else {
          p_up += p_left / 2;
          p_down += p_left / 2;
        }
        p_left = 0.0f;
      }

      if (right_blocked) {
        if (up_blocked) {
          p_down += p_right;
        } else if (down_blocked) {
          p_up += p_right;
        } else {
          p_up += p_right / 2;
          p_down += p_right / 2;
        }
        p_right = 0.0f;
      }

      if (up_blocked) {
        if (left_blocked) {
          p_right += p_up;
        } else if (right_blocked) {
          p_left += p_up;
        } else {
          p_right += p_up / 2;
          p_left += p_up / 2;
        }
        p_up = 0.0f;
      }

      if (down_blocked) {
        if (left_blocked) {
          p_right += p_down;
        } else if (right_blocked) {
          p_left += p_down;
        } else {
          p_right += p_down / 2;
          p_left += p_down / 2;
        }
        p_down = 0.0f;
      }

      // Emit one outcome per direction that still carries probability mass.
      if (p_up != 0.0f) {
        SimpleGuidanceState next_state = *state;
        next_state.y += 1;
        next_state.prev_action.type = UP;

        probabilities.push_back(p_up);
        rewards.push_back(-1);
        next_states.push_back(complete_state_vector_[getStateIndex(next_state)]);
      }

      if (p_down != 0.0f) {
        SimpleGuidanceState next_state = *state;
        next_state.y -= 1;
        next_state.prev_action.type = DOWN;

        probabilities.push_back(p_down);
        rewards.push_back(-1);
        next_states.push_back(complete_state_vector_[getStateIndex(next_state)]);
      }

      if (p_left != 0.0f) {
        SimpleGuidanceState next_state = *state;
        next_state.x -= 1;
        next_state.prev_action.type = LEFT;

        probabilities.push_back(p_left);
        rewards.push_back(-1);
        next_states.push_back(complete_state_vector_[getStateIndex(next_state)]);
      }

      if (p_right != 0.0f) {
        SimpleGuidanceState next_state = *state;
        next_state.x += 1;
        next_state.prev_action.type = RIGHT;

        probabilities.push_back(p_right);
        rewards.push_back(-1);
        next_states.push_back(complete_state_vector_[getStateIndex(next_state)]);
      }

      if (probabilities.empty()) {
        // A non-terminal state must have at least one successor. Print the
        // offending pair, then trip the assertion. The previous condition
        // (size() == 0) was always true here, so the check never fired.
        std::cout << *state << " " << *action << std::endl;
        assert(!probabilities.empty() && "no successor states for a non-terminal state");
      }
    }
  }