// performs the rl update at a state
//
// Applies one temporal-difference style update to every production that
// currently holds an eligibility trace in goal->id.rl_info. The productions
// recorded in prev_op_rl_rules are treated as the "just fired" set: their
// traces are bumped and their current ecr/efr values form the prediction
// that the target is compared against.
//
// Parameters:
//   my_agent   - agent supplying the RL parameters, the sysparams trace
//                flags, and the print/xml output channels.
//   op_value   - value for the newly selected operator; multiplied by
//                `discount`, it is the bootstrap part of the update target.
//   op_rl      - when temporal extension (gaps) is enabled and this is false,
//                the function does nothing at all (gap_age, hrl_age and the
//                accumulated reward are NOT reset).
//                NOTE(review): presumably means "the selected operator is
//                RL-valued"; confirm against the caller.
//   goal       - goal whose id.rl_info (rl_data) holds the eligibility
//                traces, prev_op_rl_rules, accumulated reward and age counters.
//   update_efr - when false, delta_efr is forced to 0, so only the ecr
//                component of each rule changes.
//
// Side effects: updates each traced rule's rl_ecr, rl_efr, rl_update_count,
// delta-bar-delta state, RHS value and (with meta on) its documentation
// string; rewrites the referent of preferences from its live instantiations;
// decays/prunes the trace map; and resets data->gap_age, hrl_age and reward
// whenever the update condition (!using_gaps || op_rl) holds.
void rl_perform_update( agent *my_agent, double op_value, bool op_rl, Symbol *goal, bool update_efr )
{
	// "gaps" mode: updates are deferred across non-RL operator selections.
	bool using_gaps = ( my_agent->rl_params->temporal_extension->get_value() == soar_module::on );

	if ( !using_gaps || op_rl )
	{		
		rl_data *data = goal->id.rl_info;
		
		// Nothing to credit unless some RL rules supported the previous operator.
		if ( !data->prev_op_rl_rules->empty() )
		{			
			rl_et_map::iterator iter;			
			double alpha = my_agent->rl_params->learning_rate->get_value();
			double lambda = my_agent->rl_params->et_decay_rate->get_value();
			double gamma = my_agent->rl_params->discount_rate->get_value();
			double tolerance = my_agent->rl_params->et_tolerance->get_value();
            // meta-learning rate; only consumed by the delta-bar-delta decay mode below
            double theta = my_agent->rl_params->meta_learning_rate->get_value();

			// if temporal_discount is off, don't discount for gaps
			// The exponent is hrl_age + 1, plus gap_age only when temporal_discount is on.
			unsigned int effective_age = data->hrl_age + 1;
			if (my_agent->rl_params->temporal_discount->get_value() == soar_module::on) {
				effective_age += data->gap_age;
			}
 
			// discount = gamma ^ effective_age
			double discount = pow( gamma, static_cast< double >( effective_age ) );

			// notify of gap closure
			if ( data->gap_age && using_gaps && my_agent->sysparams[ TRACE_RL_SYSPARAM ] )
			{
				char buf[256];
				SNPRINTF( buf, 254, "gap ended (%c%llu)", goal->id.name_letter, static_cast<long long unsigned>(goal->id.name_number) );

				print( my_agent, buf );
				xml_generate_warning( my_agent, buf );
			}			

			// Iterate through eligibility_traces, decay traces. If less than TOLERANCE, remove from map.
			// Each trace is scaled by lambda * discount; lambda == 0 simply drops every trace.
			if ( lambda == 0 )
			{
				if ( !data->eligibility_traces->empty() )
				{
					data->eligibility_traces->clear();
				}
			}
			else
			{
				for ( iter = data->eligibility_traces->begin(); iter != data->eligibility_traces->end(); )
				{
					iter->second *= lambda;
					iter->second *= discount;
					if ( iter->second < tolerance ) 
					{
						// Post-increment before erasing; assumes rl_et_map is a node-based
						// map where erase invalidates only the erased iterator.
						data->eligibility_traces->erase( iter++ );
					}
					else 
					{
						++iter;
					}
				}
			}
			
			// Update trace for just fired prods
			// sum_old_ecr / sum_old_efr collect the pre-update values of the just-fired
			// rules; together they form the prediction the target is compared with.
			double sum_old_ecr = 0.0;
			double sum_old_efr = 0.0;
			// NOTE(review): always true here; the enclosing block already requires a
			// non-empty prev_op_rl_rules.
			if ( !data->prev_op_rl_rules->empty() )
			{
				// The just-fired rules share a total trace increment of 1.0 equally.
				double trace_increment = ( 1.0 / static_cast<double>( data->prev_op_rl_rules->size() ) );
				rl_rule_list::iterator p;
				
				for ( p=data->prev_op_rl_rules->begin(); p!=data->prev_op_rl_rules->end(); p++ )
				{
					sum_old_ecr += (*p)->rl_ecr;
					sum_old_efr += (*p)->rl_efr;
					
					iter = data->eligibility_traces->find( (*p) );
					
					if ( iter != data->eligibility_traces->end() ) 
					{
						iter->second += trace_increment;
					}
					else 
					{
						(*data->eligibility_traces)[ (*p) ] = trace_increment;
					}
				}
			}
			
			// For each prod with a trace, perform update
			{
				double old_ecr, old_efr;
				double delta_ecr, delta_efr;
				double new_combined, new_ecr, new_efr;
                // Full TD error (target minus prediction). It is computed once per call and
                // is only used by the delta-bar-delta alpha adaptation below.
                double delta_t = (data->reward + discount * op_value) - (sum_old_ecr + sum_old_efr);
				
				for ( iter = data->eligibility_traces->begin(); iter != data->eligibility_traces->end(); iter++ )
				{	
					production *prod = iter->first;

					// get old vals
					old_ecr = prod->rl_ecr;
					old_efr = prod->rl_efr;

                    // Adjust alpha based on decay policy
                    // Miller 11/14/2011
                    // The decay modes below use rl_update_count as it was before this update
                    // (the count is incremented further down).
                    double adjusted_alpha;
                    switch (my_agent->rl_params->decay_mode->get_value())
                    {
                        case rl_param_container::exponential_decay:
                            // alpha = 1 / (n + 1); ignores the learning-rate parameter
                            adjusted_alpha = 1.0 / (prod->rl_update_count + 1.0);
                            break;
                        case rl_param_container::logarithmic_decay:
                            // alpha = 1 / (ln(n + 1) + 1); ignores the learning-rate parameter
                            adjusted_alpha = 1.0 / (log(prod->rl_update_count + 1.0) + 1.0);
                            break;
                        case rl_param_container::delta_bar_delta_decay:
                            {
                                // Per-rule adaptive step size:
                                //   beta  += theta * delta_t * h
                                //   alpha  = exp(beta)
                                //   h      = h * max(0, 1 - alpha) + alpha * delta_t
                                // Note that in this case, x_i = 1.0 for all productions that are being updated.
                                // Those values have been included here for consistency with the algorithm as described in the delta bar delta paper.
                                prod->rl_delta_bar_delta_beta = prod->rl_delta_bar_delta_beta + theta * delta_t * 1.0 * prod->rl_delta_bar_delta_h;
                                adjusted_alpha = exp(prod->rl_delta_bar_delta_beta);
                                double decay_term = 1.0 - adjusted_alpha * 1.0 * 1.0;
                                if (decay_term < 0.0) decay_term = 0.0;
                                prod->rl_delta_bar_delta_h = prod->rl_delta_bar_delta_h * decay_term + adjusted_alpha * delta_t * 1.0;
                                break;
                            }
                        case rl_param_container::normal_decay:
                        default:
                            // fixed learning rate from the learning_rate parameter
                            adjusted_alpha = alpha;
                            break;
                    }

                    // calculate updates
                    // ecr moves toward the accumulated reward; efr moves toward the discounted
                    // value of the new operator. Both errors are measured against the summed
                    // pre-update values of the just-fired rules and scaled by this rule's trace.
                    delta_ecr = ( adjusted_alpha * iter->second * ( data->reward - sum_old_ecr ) );

                    if ( update_efr )
                    {
                        delta_efr = ( adjusted_alpha * iter->second * ( ( discount * op_value ) - sum_old_efr ) );
                    }
                    else
					{
						delta_efr = 0.0;
					}					

					// calculate new vals
					new_ecr = ( old_ecr + delta_ecr );
					new_efr = ( old_efr + delta_efr );
					new_combined = ( new_ecr + new_efr );
					
					// print as necessary
					if ( my_agent->sysparams[ TRACE_RL_SYSPARAM ] ) 
					{
						std::ostringstream ss;						
						ss << "RL update " << prod->name->sc.name << " "
						   << old_ecr << " " << old_efr << " " << old_ecr + old_efr << " -> "
						   << new_ecr << " " << new_efr << " " << new_combined ;

						std::string temp_str( ss.str() );						
						print( my_agent, "%s\n", temp_str.c_str() );
						xml_generate_message( my_agent, temp_str.c_str() );

                        // Log update to file if the log file has been set
                        // (only while RL tracing is on; the file is reopened in append mode
                        // for every single rule update)
                        std::string log_path = my_agent->rl_params->update_log_path->get_value();
                        if (!log_path.empty()) {
                            std::ofstream file(log_path.c_str(), std::ios_base::app);
                            file << ss.str() << std::endl;
                            file.close();
                        }
                    }

                    // Change value of rule
                    // Drop the reference to the old RHS value symbol and install a fresh float
                    // constant holding new_combined, then record the new component values.
                    symbol_remove_ref( my_agent, rhs_value_to_symbol( prod->action_list->referent ) );
                    prod->action_list->referent = symbol_to_rhs_value( make_float_constant( my_agent, new_combined ) );
                    prod->rl_update_count += 1;
                    prod->rl_ecr = new_ecr;
                    prod->rl_efr = new_efr;

                    // change documentation
                    // With meta on, the documentation string is rebuilt as "name=value;" for
                    // each registered documentation parameter of this rule.
                    if ( my_agent->rl_params->meta->get_value() == soar_module::on )
                    {
                        if ( prod->documentation )
                        {
                            free_memory_block_for_string( my_agent, prod->documentation );
                        }
                        std::stringstream doc_ss;
                        const std::vector<std::pair<std::string, param_accessor<double> *> > &documentation_params = my_agent->rl_params->get_documentation_params();
                        for (std::vector<std::pair<std::string, param_accessor<double> *> >::const_iterator doc_params_it = documentation_params.begin();
                                doc_params_it != documentation_params.end(); ++doc_params_it) {
                            doc_ss << doc_params_it->first << "=" << doc_params_it->second->get_param(prod) << ";";
                        }
                        prod->documentation = make_memory_block_for_string(my_agent, doc_ss.str().c_str());

                        /*
						std::string rlupdates( "rlupdates=" );
						std::string val;
						to_string( static_cast< uint64_t >( prod->rl_update_count ), val );
						rlupdates.append( val );

						prod->documentation = make_memory_block_for_string( my_agent, rlupdates.c_str() );
                        */
					}

					// Change value of preferences generated by current instantiations of this rule
					// so values already in preference memory match the updated rule.
					if ( prod->instantiations )
					{
						for ( instantiation *inst = prod->instantiations; inst; inst = inst->next )
						{
							for ( preference *pref = inst->preferences_generated; pref; pref = pref->inst_next )
							{
								symbol_remove_ref( my_agent, pref->referent );
								pref->referent = make_float_constant( my_agent, new_combined );
							}
						}
					}	
				}
			}
		}

		// Start a fresh accumulation window for the next update.
		data->gap_age = 0;
		data->hrl_age = 0;
		data->reward = 0.0;
	}
}
// Example #2
// 0
// Drives one working-memory-activation step for the requested action.
//
//   wma_histories  - refreshes the decay histories of touched elements,
//                    timed by wma_timers->history.
//   wma_forgetting - unless forgetting is disabled, runs the configured
//                    forgetting policy (a naive sweep, or the priority
//                    queue for any other enabled policy). If anything was
//                    marked for forgetting, it runs a working-memory phase
//                    to apply the removals, adds the number of removed WMEs
//                    to the forgotten_wmes statistic, and (when WM-change
//                    tracing is on) brackets that phase with begin/end
//                    banners. The whole pass is timed by
//                    wma_timers->forgetting.
//
// Any other action value is a no-op.
void wma_go( agent* my_agent, wma_go_action go_action )
{
	if ( go_action == wma_histories )
	{
		my_agent->wma_timers->history->start();
		wma_update_decay_histories( my_agent );
		my_agent->wma_timers->history->stop();
		return;
	}

	if ( go_action != wma_forgetting )
	{
		return;
	}

	const wma_param_container::forgetting_choices policy = my_agent->wma_params->forgetting->get_value();
	if ( policy == wma_param_container::disabled )
	{
		return;
	}

	my_agent->wma_timers->forgetting->start();

	// The naive policy sweeps all candidates; every other enabled policy consults the priority queue.
	const bool marked_any = ( policy == wma_param_container::naive )
		? wma_forgetting_naive_sweep( my_agent )
		: wma_forgetting_update_p_queue( my_agent );

	if ( marked_any )
	{
		if ( my_agent->sysparams[ TRACE_WM_CHANGES_SYSPARAM ] )
		{
			const char *banner_open = "\n\nWMA: BEGIN FORGOTTEN WME LIST\n\n";
			print( my_agent, const_cast<char *>( banner_open ) );
			xml_generate_message( my_agent, const_cast<char *>( banner_open ) );
		}

		// Count removals by sampling the agent-wide counter around the working-memory phase.
		const uint64_t removals_before = my_agent->wme_removal_count;
		do_working_memory_phase( my_agent );
		const uint64_t removed_now = my_agent->wme_removal_count - removals_before;

		if ( removed_now > 0 )
		{
			const int64_t running_total = my_agent->wma_stats->forgotten_wmes->get_value();
			my_agent->wma_stats->forgotten_wmes->set_value( running_total + static_cast< int64_t >( removed_now ) );
		}

		// The trace flag is read again here, after the working-memory phase, as in the original.
		if ( my_agent->sysparams[ TRACE_WM_CHANGES_SYSPARAM ] )
		{
			const char *banner_close = "\nWMA: END FORGOTTEN WME LIST\n\n";
			print( my_agent, const_cast<char *>( banner_close ) );
			xml_generate_message( my_agent, const_cast<char *>( banner_close ) );
		}
	}

	my_agent->wma_timers->forgetting->stop();
}