RLModule.cpp

Go to the documentation of this file.
00001 /************************************************************************
00002 * Verve                                                                 *
00003 * Copyright (C) 2004-2006                                               *
00004 * Tyler Streeter  tylerstreeter@gmail.com                               *
00005 * All rights reserved.                                                  *
00006 * Web: http://verve-agents.sourceforge.net                              *
00007 *                                                                       *
00008 * This library is free software; you can redistribute it and/or         *
00009 * modify it under the terms of EITHER:                                  *
00010 *   (1) The GNU Lesser General Public License as published by the Free  *
00011 *       Software Foundation; either version 2.1 of the License, or (at  *
00012 *       your option) any later version. The text of the GNU Lesser      *
00013 *       General Public License is included with this library in the     *
00014 *       file license-LGPL.txt.                                          *
00015 *   (2) The BSD-style license that is included with this library in     *
00016 *       the file license-BSD.txt.                                       *
00017 *                                                                       *
00018 * This library is distributed in the hope that it will be useful,       *
00019 * but WITHOUT ANY WARRANTY; without even the implied warranty of        *
00020 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the files    *
00021 * license-LGPL.txt and license-BSD.txt for more details.                *
00022 ************************************************************************/
00023 
00024 #include <fstream>
00025 
00026 #include "RLModule.h"
00027 #include "Observation.h"
00028 #include "Population.h"
00029 #include "Projection.h"
00030 #include "RBFNeuron.h"
00031 #include "RBFPopulation.h"
00032 #include "TDProjection.h"
00033 #include "TDConnection.h"
00034 #include "WinnerTakeAllPopulation.h"
00035 
00036 namespace verve
00037 {
00038         RLModule::RLModule(const Observation& obs, bool isDynamicRBFEnabled, 
00039                 unsigned int numActions)
00040         {
00041                 // Make the stored input data match the size of the given 
00042                 // Observation.
00043                 mLatestInputData.init(obs);
00044 
00045                 // Create the state representation Population.
00046                 mStateRepresentation = new RBFPopulation();
00047                 mStateRepresentation->init(mLatestInputData, isDynamicRBFEnabled);
00048                 mAllPopulations.push_back(mStateRepresentation);
00049 
00050                 // Create the actor Population.
00051                 mActorPopulation = new WinnerTakeAllPopulation();
00052                 mActorPopulation->init(numActions);
00053                 mAllPopulations.push_back(mActorPopulation);
00054 
00055                 // Create the critic Population.
00056                 mCriticPopulation = new Population();
00057                 mCriticPopulation->init(1);
00058                 mAllPopulations.push_back(mCriticPopulation);
00059 
00060                 // Create a Projection from the state representation to the actor.  
00061                 // Start all policy Connections with zero weights to give each 
00062                 // action an equal initial selection probability.
00063                 mStateRepresentation->projectTD(mActorPopulation, 
00064                         POLICY_TDCONNECTION, WEIGHTS_NEAR_0, 
00065                         mStateRepresentation->computeMaxActivationSum());
00066 
00067                 // Create a Projection from the state representation to the critic.  
00068                 // Initialize weights to zero.
00069                 mStateRepresentation->projectTD(mCriticPopulation, 
00070                         VALUE_FUNCTION_TDCONNECTION, WEIGHTS_NEAR_0, 
00071                         mStateRepresentation->computeMaxActivationSum());
00072 
00073                 mFirstStep = true;
00074                 mTDError = 0;
00075                 mOldValueEstimation = 0;
00076                 mNewValueEstimation = 0;
00077                 mETraceTimeConstant = 0;
00078                 mTDDiscountTimeConstant = 0;
00079                 mTDDiscountFactor = 0;
00080                 mValueFunctionLearningTimeConstant = 0;
00081                 mValueFunctionLearningFactor = 0;
00082                 mPolicyLearningMultiplier = 0;
00083         }
00084 
00085         RLModule::~RLModule()
00086         {
00087                 // Destroy Populations, including the Neurons and Projections 
00088                 // contained within them.
00089                 while (!mAllPopulations.empty())
00090                 {
00091                         delete mAllPopulations.back();
00092                         mAllPopulations.pop_back();
00093                 }
00094         }
00095 
00096         void RLModule::resetShortTermMemory()
00097         {
00098                 mLatestInputData.zeroInputData();
00099 
00100                 unsigned int size = (unsigned int)mAllPopulations.size();
00101                 for (unsigned int i = 0; i < size; ++i)
00102                 {
00103                         mAllPopulations[i]->resetShortTermMemory();
00104                 }
00105 
00106                 // Empty the lists of active TDConnections.
00107                 mActiveValueFunctionTDConnections.clearList();
00108                 mActivePolicyTDConnections.clearList();
00109 
00110                 mFirstStep = true;
00111                 mTDError = 0;
00112                 mOldValueEstimation = 0;
00113                 mNewValueEstimation = 0;
00114         }
00115 
00116         unsigned int RLModule::update(const Observation& obs, real reinforcement)
00117         {
00118                 // Keep a copy of the latest input data.
00119                 mLatestInputData.copyInputData(obs.getDiscreteInputData(), 
00120                         obs.getContinuousInputData());
00121 
00122                 // The order of events here should be functionally equivalent to the 
00123                 // pseudocode listed on page 174 of Sutton's & Barto's book 
00124                 // Reinforcement Learning.  Here is the modified pseudocode being 
00125                 // implemented (V and A are the critic and actor, respectively): 
00126                 // 
00127                 // Increase eligibility traces using s.
00128                 // Compute V(s).
00129                 // Compute V(s').
00130                 // Compute action a using A(s').
00131                 // TD error = r' + gamma * V(s') - V(s).
00132                 // Train A and V using TD error.
00133                 // Decay eligibility traces.
00134                 // Replace s with s'.
00135                 // (Take action a, letting it affect the environment which then 
00136                 // updates r' and s' for next time.)
00137 
00138                 unsigned int actionIndex = 0;
00139 
00140                 if (mFirstStep)
00141                 {
00142                         // If this is the very first step, we simply choose a new 
00143                         // action based on the latest Observation.  This is to make sure 
00144                         // we don't try to use the state representation s before it is 
00145                         // actually valid: it won't be valid until the next update.
00146 
00147                         // Update the state to represent s'.  We assume here that 
00148                         // learning is enabled (which lets the state representation add 
00149                         // new RBFs if necessary).
00150                         mStateRepresentation->updateFiringRatesRBF(mLatestInputData, 
00151                                 true);
00152 
00153                         // Compute A(s').
00154                         actionIndex = updateActorOutput();
00155 
00156                         // Now the state will correctly represent s for the next update.
00157                         mFirstStep = false;
00158                 }
00159                 else
00160                 {
00161                         // The state at this point should represent s (which is the same 
00162                         // as s' from the previous update).
00163 
00164                         // Check for new active (i.e. eligible) TDConnections.  This 
00165                         // must be called after the state and actor have been updated 
00166                         // (in that order) because the actor's input TDConnections 
00167                         // depend on the pre- and post-synaptic firing rates.  Also, 
00168                         // we should not decay the eligibility traces for the new 
00169                         // active TDConnections until after we increase them at least 
00170                         // once so they don't get removed immediately.
00171                         updateActiveTDConnectionList();
00172 
00173                         // Increase eligibility traces, still using s.
00174                         mActiveValueFunctionTDConnections.increaseETraces();
00175                         mActivePolicyTDConnections.increaseETraces();
00176 
00177                         // Compute V(s).
00178                         mOldValueEstimation = updateCriticOutput();
00179 
00180                         // Update the state to represent s'.  This will add any 
00181                         // newly-active TDConnections to the active list and 
00182                         // initially increase their eligibility traces.
00183                         mStateRepresentation->updateFiringRatesRBF(mLatestInputData, 
00184                                 true);
00185 
00186                         // Compute V(s') and A(s').
00187                         mNewValueEstimation = updateCriticOutput();
00188                         actionIndex = updateActorOutput();
00189 
00190                         // TD error = r' + gamma * V(s') - V(s).
00191                         mTDError = reinforcement + mTDDiscountFactor * 
00192                                 mNewValueEstimation - mOldValueEstimation;
00193 
00194                         // Train A and V using TD error.  Note that the current 
00195                         // eligibility traces are still from state s.
00196                         trainTDRule();
00197 
00198                         // Decay eligibility traces.
00199                         mActiveValueFunctionTDConnections.decayETraces();
00200                         mActivePolicyTDConnections.decayETraces();
00201                 }
00202 
00203                 // The state representation (currently s') will implicitly become 
00204                 // s the next time this function is called.  Before this is called 
00205                 // again, we assume the Agent will take action a, letting it affect 
00206                 // the environment, which then updates r' and s' for the next 
00207                 // update.
00208 
00209                 return actionIndex;
00210         }
00211 
00212         unsigned int RLModule::updatePolicyOnly(const Observation& obs)
00213         {
00214                 // We simply choose a new action based on the given Observation.
00215 
00216                 // Keep a copy of the latest input data.
00217                 mLatestInputData.copyInputData(obs.getDiscreteInputData(), 
00218                         obs.getContinuousInputData());
00219 
00220                 // Update the state to represent s'.
00221                 mStateRepresentation->updateFiringRatesRBF(mLatestInputData, false);
00222 
00223                 // Compute and return A(s').
00224                 return updateActorOutput();
00225         }
00226 
00227         void RLModule::changeStepSize(real newValue)
00228         {
00229                 setETraceTimeConstant(mETraceTimeConstant, newValue);
00230                 setTDDiscountTimeConstant(mTDDiscountTimeConstant, newValue);
00231                 setTDLearningRate(mValueFunctionLearningTimeConstant, 
00232                         mPolicyLearningMultiplier, newValue);
00233         }
00234 
00235         void RLModule::setETraceTimeConstant(real timeConstant, real stepSize)
00236         {
00237                 mETraceTimeConstant = timeConstant;
00238                 real ETraceDecayFactor = globals::calcDecayConstant(
00239                         mETraceTimeConstant, stepSize);
00240 
00241                 if (mStateRepresentation)
00242                 {
00243                         mStateRepresentation->setPostETraceDecayFactors(
00244                                 ETraceDecayFactor);
00245                 }
00246         }
00247 
00248         void RLModule::setTDDiscountTimeConstant(real timeConstant, real stepSize)
00249         {
00250                 mTDDiscountTimeConstant = timeConstant;
00251                 mTDDiscountFactor = globals::calcDecayConstant(
00252                         mTDDiscountTimeConstant, stepSize);
00253 
00254                 if (mStateRepresentation)
00255                 {
00256                         mStateRepresentation->setPostTDDiscountFactors(
00257                                 mTDDiscountFactor);
00258                 }
00259         }
00260 
00261         void RLModule::setTDLearningRate(real valueFunctionTimeConstant, 
00262                 real policyLearningMultiplier, real stepSize)
00263         {
00264                 mValueFunctionLearningTimeConstant = valueFunctionTimeConstant;
00265                 mPolicyLearningMultiplier = policyLearningMultiplier;
00266                 mValueFunctionLearningFactor = 1 - globals::calcDecayConstant(
00267                         mValueFunctionLearningTimeConstant, stepSize);
00268 
00269                 // The learning factor should be normalized as follows: 
00270                 // 
00271                 // learning factor = learning factor / # of active features
00272                 // 
00273                 // This method allows us to change the number of active features 
00274                 // in the state representation without making learning unstable.  
00275                 // Since we're using an RBF state representation, the number of 
00276                 // active features is equal to the total sum of RBF activation.
00277                 mValueFunctionLearningFactor = mValueFunctionLearningFactor / 
00278                         mStateRepresentation->computeMaxActivationSum();
00279         }
00280 
00281         real RLModule::getTDError()
00282         {
00283                 return mTDError;
00284         }
00285 
00286         void RLModule::resetState(const Observation& obs)
00287         {
00288                 // Keep a copy of the latest input data.
00289                 mLatestInputData.copyInputData(obs.getDiscreteInputData(), 
00290                         obs.getContinuousInputData());
00291 
00292                 mStateRepresentation->updateFiringRatesRBF(mLatestInputData, 
00293                         false);
00294         }
00295 
00296         real RLModule::computeValueEstimation(const Observation& obs)
00297         {
00298                 // Note: We don't want to enable any learning here because this 
00299                 // function should have no residual effects.
00300 
00301                 // Use this input data structure to send data to the state 
00302                 // representation.
00303                 RBFInputData tempInputData;
00304                 tempInputData.init(obs);
00305 
00306                 // Temporarily update the state representation with the given 
00307                 // Observation (with learning disabled).
00308                 mStateRepresentation->updateFiringRatesRBF(tempInputData, false);
00309 
00310                 real valueEstimation = updateCriticOutput();
00311 
00312                 // Put the state back to what it was before this function was 
00313                 // called.
00314                 mStateRepresentation->updateFiringRatesRBF(mLatestInputData, 
00315                         false);
00316 
00317                 return valueEstimation;
00318         }
00319 
00320         void RLModule::saveValueData(unsigned int continuousResolution, 
00321                 const std::string& filename)
00322         {
00323                 if (mLatestInputData.numDiscInputs == 0 
00324                         && mLatestInputData.numContInputs == 0)
00325                 {
00326                         // This Agent has no inputs.
00327                         return;
00328                 }
00329 
00330                 // Check if we need to auto generate a unique filename.
00331                 std::string nameStr = filename;
00332                 if (nameStr.empty())
00333                 {
00334                         static unsigned int count = 0;
00335                         char newName[64];
00336                         sprintf(newName, "agentValueData%d.dat", count);
00337                         nameStr = newName;
00338                         ++count;
00339                 }
00340 
00341                 std::ofstream file(nameStr.c_str());
00342 
00343                 // Output a '#' for the header line.  '#' lines are ignored by 
00344                 // gnuplot.
00345                 file << "# ";
00346 
00347                 // For the header, output the number of distinct points being 
00348                 // checked along every input dimension.
00349                 for (unsigned int i = 0; i < mLatestInputData.numDiscInputs; ++i)
00350                 {
00351                         file << mLatestInputData.discNumOptionsData[i] << " ";
00352                 }
00353 
00354                 for (unsigned int i = 0; i < mLatestInputData.numContInputs; ++i)
00355                 {
00356                         file << continuousResolution << " ";
00357                 }
00358 
00359                 file << std::endl;
00360 
00361                 // Iterate over every possible combination of inputs.
00362                 RBFInputData inputdata;
00363                 inputdata.init(mLatestInputData.numDiscInputs, 
00364                         mLatestInputData.discNumOptionsData, 
00365                         mLatestInputData.discInputData, 
00366                         mLatestInputData.numContInputs, continuousResolution, 
00367                         mLatestInputData.contCircularData, 
00368                         mLatestInputData.contInputData);
00369                 unsigned numStates = inputdata.computeNumUniqueStates(
00370                         continuousResolution);
00371                 for (unsigned int i = 0; i < numStates; ++i)
00372                 {
00373                         inputdata.setToUniqueState(i, numStates, continuousResolution);
00374 
00375                         // Print the state input data.
00376                         for (unsigned int i = 0; i < mLatestInputData.numDiscInputs; ++i)
00377                         {
00378                                 file << inputdata.discInputData[i] << " ";
00379                         }
00380 
00381                         for (unsigned int i = 0; i < mLatestInputData.numContInputs; ++i)
00382                         {
00383                                 file << inputdata.contInputData[i] << " ";
00384                         }
00385 
00386                         // Compute and print the estimated value.  Temporarily update 
00387                         // the state representation with this data (with learning 
00388                         // disabled because this function should have no residual 
00389                         // effects).
00390                         mStateRepresentation->updateFiringRatesRBF(inputdata, false);
00391                         file << updateCriticOutput() << std::endl;
00392                 }
00393 
00394                 file.close();
00395 
00396                 // Put the state back to what it was before this function was 
00397                 // called.
00398                 mStateRepresentation->updateFiringRatesRBF(mLatestInputData, 
00399                         false);
00400         }
00401 
00402         void RLModule::saveStateRBFData(const std::string& filename)
00403         {
00404                 if (mLatestInputData.numDiscInputs == 0 
00405                         && mLatestInputData.numContInputs == 0)
00406                 {
00407                         // This Agent has no inputs.
00408                         return;
00409                 }
00410 
00411                 // Check if we need to auto generate a unique filename.
00412                 std::string nameStr = filename;
00413                 if (nameStr.empty())
00414                 {
00415                         static unsigned int count = 0;
00416                         char newName[64];
00417                         sprintf(newName, "agentStateRBFData%d.dat", count);
00418                         nameStr = newName;
00419                         ++count;
00420                 }
00421 
00422                 std::ofstream file(nameStr.c_str());
00423 
00424                 // Save each RBF position on a separate line with discrete data 
00425                 // before continuous data.
00426                 unsigned int numRBFs = mStateRepresentation->getNumNeurons();
00427                 for (unsigned int i = 0; i < numRBFs; ++i)
00428                 {
00429                         RBFNeuron* n = static_cast<RBFNeuron*>(
00430                                 mStateRepresentation->getNeuron(i));
00431 
00432                         unsigned int numDiscreteDimensions = 
00433                                 n->getNumDiscreteDimensions();
00434                         for (unsigned int dim = 0; dim < numDiscreteDimensions; ++dim)
00435                         {
00436                                 file << n->getDiscretePosition()[dim] << " ";
00437                         }
00438 
00439                         unsigned int numContinuousDimensions = 
00440                                 n->getNumContinuousDimensions();
00441                         for (unsigned int dim = 0; dim < numContinuousDimensions; ++dim)
00442                         {
00443                                 file << n->getContinuousPosition()[dim] << " ";
00444                         }
00445 
00446                         file << std::endl;
00447                 }
00448 
00449                 file.close();
00450         }
00451 
00452         void RLModule::updateActiveTDConnectionList()
00453         {
00454                 // This checks for new active TDConnections based on the current 
00455                 // list of active Neurons.
00456 
00457                 unsigned int numActiveNeurons = 
00458                         mStateRepresentation->getNumActiveNeurons();
00459                 for (unsigned int n = 0; n < numActiveNeurons; ++n)
00460                 {
00461                         Neuron* activeNeuron = 
00462                                 mStateRepresentation->getActiveNeuron(n);
00463 
00464                         // Look at the axons from the active Neuron.
00465                         unsigned int numAxons = activeNeuron->getNumAxons();
00466                         for (unsigned int a = 0; a < numAxons; ++a)
00467                         {
00468                                 // Assume here that the axons are TDConnections.  This 
00469                                 // will always be true since we only add Neurons to the 
00470                                 // active list that have TDConnection axons.
00471                                 TDConnection* axon = static_cast<TDConnection*>(
00472                                         activeNeuron->getAxon(a));
00473 
00474                                 // Ignore TDConnections already in the list.
00475                                 if (axon->isInActiveList())
00476                                 {
00477                                         continue;
00478                                 }
00479 
00480                                 // If the TDConnection is eligible (according to its own 
00481                                 // method for determining eligibility), we will add it to 
00482                                 // the active TDConnection list.
00483                                 switch (axon->getTDConnectionType())
00484                                 {
00485                                         case VALUE_FUNCTION_TDCONNECTION:
00486                                                 mActiveValueFunctionTDConnections.
00487                                                         addNewActiveConnection(axon);
00488                                                 break;
00489                                         case POLICY_TDCONNECTION:
00490                                                 // We already know the pre-synaptic Neuron is active, 
00491                                                 // so we'll check the post-synaptic Neuron.  This 
00492                                                 // assumes the post-synaptic Neuron never has a 
00493                                                 // negative firing rate.
00494                                                 if (axon->getPostNeuron()->getFiringRate() > 0)
00495                                                 {
00496                                                         mActivePolicyTDConnections.
00497                                                                 addNewActiveConnection(axon);
00498                                                 }
00499                                                 break;
00500                                         default:
00501                                                 assert(false);
00502                                 }
00503                         }
00504                 }
00505         }
00506 
00507         void RLModule::trainTDRule()
00508         {
00509                 mActiveValueFunctionTDConnections.trainConnections(
00510                         mValueFunctionLearningFactor * mTDError);
00511                 mActivePolicyTDConnections.trainConnections(
00512                         mValueFunctionLearningFactor * mPolicyLearningMultiplier * 
00513                         mTDError);
00514         }
00515 
00516         real RLModule::updateCriticOutput()
00517         {
00518                 mCriticPopulation->updateFiringRatesLinear();
00519                 return mCriticPopulation->getNeuron(0)->getFiringRate();
00520         }
00521 
00522         unsigned int RLModule::updateActorOutput()
00523         {
00524                 mActorPopulation->updateFiringRatesWTA();
00525                 int actionIndex = mActorPopulation->getActiveOutput();
00526                 assert(-1 != actionIndex);
00527                 return actionIndex;
00528         }
00529 }

Generated on Tue Jan 24 21:46:37 2006 for Verve by  doxygen 1.4.6-NO