#include <cassert>
#include <cstdio>
#include <fstream>

#include "RLModule.h"
#include "Observation.h"
#include "Population.h"
#include "Projection.h"
#include "RBFNeuron.h"
#include "RBFPopulation.h"
#include "TDProjection.h"
#include "TDConnection.h"
#include "WinnerTakeAllPopulation.h"
namespace verve
{
    RLModule::RLModule(const Observation& obs, bool isDynamicRBFEnabled,
        unsigned int numActions)
    {
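        // Store the structure of the Observation for later input handling.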
        mLatestInputData.init(obs);

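        // Create the state representation: an RBFPopulation that converts
        // each Observation into a pattern of RBF activations.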
        mStateRepresentation = new RBFPopulation();
        mStateRepresentation->init(mLatestInputData, isDynamicRBFEnabled);
        mAllPopulations.push_back(mStateRepresentation);

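        // Create the actor: a winner-take-all Population with one Neuron
        // per action.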
        mActorPopulation = new WinnerTakeAllPopulation();
        mActorPopulation->init(numActions);
        mAllPopulations.push_back(mActorPopulation);

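        // Create the critic: a single Neuron whose firing rate estimates
        // the value of the current state.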
        mCriticPopulation = new Population();
        mCriticPopulation->init(1);
        mAllPopulations.push_back(mCriticPopulation);

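        // Connect the state representation to the actor with policy
        // TDConnections, starting with weights near zero.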
        mStateRepresentation->projectTD(mActorPopulation,
            POLICY_TDCONNECTION, WEIGHTS_NEAR_0,
            mStateRepresentation->computeMaxActivationSum());

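        // Connect the state representation to the critic with value
        // function TDConnections, starting with weights near zero.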
        mStateRepresentation->projectTD(mCriticPopulation,
            VALUE_FUNCTION_TDCONNECTION, WEIGHTS_NEAR_0,
            mStateRepresentation->computeMaxActivationSum());

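        // Initialize the remaining members.  The learning parameters are
        // set later through the various setter methods.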
        mFirstStep = true;
        mTDError = 0;
        mOldValueEstimation = 0;
        mNewValueEstimation = 0;
        mETraceTimeConstant = 0;
        mTDDiscountTimeConstant = 0;
        mTDDiscountFactor = 0;
        mValueFunctionLearningTimeConstant = 0;
        mValueFunctionLearningFactor = 0;
        mPolicyLearningMultiplier = 0;
    }

    RLModule::~RLModule()
    {
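        // Destroy all Populations created in the constructor.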
        while (!mAllPopulations.empty())
        {
            delete mAllPopulations.back();
            mAllPopulations.pop_back();
        }
    }

    void RLModule::resetShortTermMemory()
    {
        mLatestInputData.zeroInputData();

        unsigned int size = (unsigned int)mAllPopulations.size();
        for (unsigned int i = 0; i < size; ++i)
        {
            mAllPopulations[i]->resetShortTermMemory();
        }

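        // Empty the lists of active TDConnections.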
        mActiveValueFunctionTDConnections.clearList();
        mActivePolicyTDConnections.clearList();

        mFirstStep = true;
        mTDError = 0;
        mOldValueEstimation = 0;
        mNewValueEstimation = 0;
    }

    unsigned int RLModule::update(const Observation& obs, real reinforcement)
    {
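        // Copy the Observation's data into our local input storage.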
        mLatestInputData.copyInputData(obs.getDiscreteInputData(),
            obs.getContinuousInputData());

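        // The overall update sequence: activate the state representation
        // from the new Observation, compute a new value estimate and a new
        // action, compute the TD error from the old and new value
        // estimates, and train the value function and policy.  On the
        // first step there is no previous state, so the learning steps
        // are skipped.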
        unsigned int actionIndex = 0;

        if (mFirstStep)
        {
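            // Update the state representation, allowing dynamic RBF
            // creation since this is a learning update.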
            mStateRepresentation->updateFiringRatesRBF(mLatestInputData,
                true);

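            // Choose an action for the current state.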
            actionIndex = updateActorOutput();

            mFirstStep = false;
        }
        else
        {
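            // Add to the active lists any TDConnections from active state
            // representation Neurons that are not already listed.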
            updateActiveTDConnectionList();

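            // Increase the eligibility traces of the active connections.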
            mActiveValueFunctionTDConnections.increaseETraces();
            mActivePolicyTDConnections.increaseETraces();

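            // Compute the critic's value estimate of the previous state.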
            mOldValueEstimation = updateCriticOutput();

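            // Update the state representation from the new Observation,
            // allowing dynamic RBF creation.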
            mStateRepresentation->updateFiringRatesRBF(mLatestInputData,
                true);

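            // Compute the value estimate of the new state and choose a
            // new action.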
            mNewValueEstimation = updateCriticOutput();
            actionIndex = updateActorOutput();

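            // Compute the TD error: the reinforcement plus the discounted
            // new value estimate minus the old value estimate.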
            mTDError = reinforcement + mTDDiscountFactor *
                mNewValueEstimation - mOldValueEstimation;

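            // Train all active TDConnections using the TD error.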
            trainTDRule();

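            // Decay the eligibility traces of the active connections.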
            mActiveValueFunctionTDConnections.decayETraces();
            mActivePolicyTDConnections.decayETraces();
        }

        return actionIndex;
    }

    unsigned int RLModule::updatePolicyOnly(const Observation& obs)
    {
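        // Update the policy without any learning: copy in the new input
        // data, recompute the state representation (without dynamic RBF
        // creation), and choose an action.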
        mLatestInputData.copyInputData(obs.getDiscreteInputData(),
            obs.getContinuousInputData());

        mStateRepresentation->updateFiringRatesRBF(mLatestInputData, false);

        return updateActorOutput();
    }

    void RLModule::changeStepSize(real newValue)
    {
        setETraceTimeConstant(mETraceTimeConstant, newValue);
        setTDDiscountTimeConstant(mTDDiscountTimeConstant, newValue);
        setTDLearningRate(mValueFunctionLearningTimeConstant,
            mPolicyLearningMultiplier, newValue);
    }

    void RLModule::setETraceTimeConstant(real timeConstant, real stepSize)
    {
        mETraceTimeConstant = timeConstant;
        real ETraceDecayFactor = globals::calcDecayConstant(
            mETraceTimeConstant, stepSize);

        if (mStateRepresentation)
        {
            mStateRepresentation->setPostETraceDecayFactors(
                ETraceDecayFactor);
        }
    }

    void RLModule::setTDDiscountTimeConstant(real timeConstant, real stepSize)
    {
        mTDDiscountTimeConstant = timeConstant;
        mTDDiscountFactor = globals::calcDecayConstant(
            mTDDiscountTimeConstant, stepSize);

        if (mStateRepresentation)
        {
            mStateRepresentation->setPostTDDiscountFactors(
                mTDDiscountFactor);
        }
    }

    void RLModule::setTDLearningRate(real valueFunctionTimeConstant,
        real policyLearningMultiplier, real stepSize)
    {
        mValueFunctionLearningTimeConstant = valueFunctionTimeConstant;
        mPolicyLearningMultiplier = policyLearningMultiplier;
        mValueFunctionLearningFactor = 1 - globals::calcDecayConstant(
            mValueFunctionLearningTimeConstant, stepSize);

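        // Scale the learning factor by the maximum possible activation
        // sum of the state representation so that the effective learning
        // rate does not depend on how many RBFs are active at once.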
        mValueFunctionLearningFactor = mValueFunctionLearningFactor /
            mStateRepresentation->computeMaxActivationSum();
    }

    real RLModule::getTDError()
    {
        return mTDError;
    }

    void RLModule::resetState(const Observation& obs)
    {
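        // Copy in the new input data and recompute the state
        // representation without learning or dynamic RBF creation.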
        mLatestInputData.copyInputData(obs.getDiscreteInputData(),
            obs.getContinuousInputData());

        mStateRepresentation->updateFiringRatesRBF(mLatestInputData,
            false);
    }

    real RLModule::computeValueEstimation(const Observation& obs)
    {
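        // Temporarily activate the state representation with the given
        // Observation, read off the critic's value estimate, then restore
        // the firing rates for the latest input data.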
        RBFInputData tempInputData;
        tempInputData.init(obs);

        mStateRepresentation->updateFiringRatesRBF(tempInputData, false);

        real valueEstimation = updateCriticOutput();

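        // Restore the firing rates for the latest input data.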
        mStateRepresentation->updateFiringRatesRBF(mLatestInputData,
            false);

        return valueEstimation;
    }

    void RLModule::saveValueData(unsigned int continuousResolution,
        const std::string& filename)
    {
        if (mLatestInputData.numDiscInputs == 0
            && mLatestInputData.numContInputs == 0)
        {
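            // There are no inputs, so there is nothing to save.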
            return;
        }

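        // If no filename was given, generate a unique one.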
        std::string nameStr = filename;
        if (nameStr.empty())
        {
            static unsigned int count = 0;
            char newName[64];
            snprintf(newName, sizeof(newName), "agentValueData%u.dat",
                count);
            nameStr = newName;
            ++count;
        }

        std::ofstream file(nameStr.c_str());

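        // Write a header line listing how many distinct values will be
        // sampled along each input dimension.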
00345 file << "# ";
00346
00347
00348
00349 for (unsigned int i = 0; i < mLatestInputData.numDiscInputs; ++i)
00350 {
00351 file << mLatestInputData.discNumOptionsData[i] << " ";
00352 }
00353
00354 for (unsigned int i = 0; i < mLatestInputData.numContInputs; ++i)
00355 {
00356 file << continuousResolution << " ";
00357 }
00358
00359 file << std::endl;
00360
00361
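        // Step through every unique input state, compute its value
        // estimate, and write one line per state.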
        RBFInputData inputdata;
        inputdata.init(mLatestInputData.numDiscInputs,
            mLatestInputData.discNumOptionsData,
            mLatestInputData.discInputData,
            mLatestInputData.numContInputs, continuousResolution,
            mLatestInputData.contCircularData,
            mLatestInputData.contInputData);
        unsigned int numStates = inputdata.computeNumUniqueStates(
            continuousResolution);
        for (unsigned int i = 0; i < numStates; ++i)
        {
            inputdata.setToUniqueState(i, numStates, continuousResolution);

            // Write the input values that define this state.
            for (unsigned int j = 0; j < mLatestInputData.numDiscInputs; ++j)
            {
                file << inputdata.discInputData[j] << " ";
            }

            for (unsigned int j = 0; j < mLatestInputData.numContInputs; ++j)
            {
                file << inputdata.contInputData[j] << " ";
            }

            // Activate the state representation (without dynamic RBF
            // creation) and write the critic's value estimate.
            mStateRepresentation->updateFiringRatesRBF(inputdata, false);
            file << updateCriticOutput() << std::endl;
        }

        file.close();

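        // Restore the firing rates for the latest input data.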
        mStateRepresentation->updateFiringRatesRBF(mLatestInputData,
            false);
    }

    void RLModule::saveStateRBFData(const std::string& filename)
    {
        if (mLatestInputData.numDiscInputs == 0
            && mLatestInputData.numContInputs == 0)
        {
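            // There are no inputs, so there are no RBFs to save.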
            return;
        }

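        // If no filename was given, generate a unique one.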
        std::string nameStr = filename;
        if (nameStr.empty())
        {
            static unsigned int count = 0;
            char newName[64];
            snprintf(newName, sizeof(newName), "agentStateRBFData%u.dat",
                count);
            nameStr = newName;
            ++count;
        }

        std::ofstream file(nameStr.c_str());

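        // Write the position of each RBF, one per line: the discrete
        // dimensions first, then the continuous dimensions.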
        unsigned int numRBFs = mStateRepresentation->getNumNeurons();
        for (unsigned int i = 0; i < numRBFs; ++i)
        {
            RBFNeuron* n = static_cast<RBFNeuron*>(
                mStateRepresentation->getNeuron(i));

            unsigned int numDiscreteDimensions =
                n->getNumDiscreteDimensions();
            for (unsigned int dim = 0; dim < numDiscreteDimensions; ++dim)
            {
                file << n->getDiscretePosition()[dim] << " ";
            }

            unsigned int numContinuousDimensions =
                n->getNumContinuousDimensions();
            for (unsigned int dim = 0; dim < numContinuousDimensions; ++dim)
            {
                file << n->getContinuousPosition()[dim] << " ";
            }

            file << std::endl;
        }

        file.close();
    }

    void RLModule::updateActiveTDConnectionList()
    {
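        // Walk through every active Neuron in the state representation and
        // make sure each of its output TDConnections is in the appropriate
        // active list.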
        unsigned int numActiveNeurons =
            mStateRepresentation->getNumActiveNeurons();
        for (unsigned int n = 0; n < numActiveNeurons; ++n)
        {
            Neuron* activeNeuron =
                mStateRepresentation->getActiveNeuron(n);

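            // Check each of the Neuron's output Connections.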
            unsigned int numAxons = activeNeuron->getNumAxons();
            for (unsigned int a = 0; a < numAxons; ++a)
            {
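                // This assumes every Connection leaving the state
                // representation is a TDConnection, which makes the
                // static_cast below safe.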
                TDConnection* axon = static_cast<TDConnection*>(
                    activeNeuron->getAxon(a));

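                // Skip Connections that are already in an active list.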
                if (axon->isInActiveList())
                {
                    continue;
                }

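                // Add the Connection to the list matching its type.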
                switch (axon->getTDConnectionType())
                {
                    case VALUE_FUNCTION_TDCONNECTION:
                        mActiveValueFunctionTDConnections.
                            addNewActiveConnection(axon);
                        break;
                    case POLICY_TDCONNECTION:
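                        // Only make a policy Connection active if it leads
                        // to the winning action Neuron (i.e. its post-
                        // synaptic firing rate is non-zero); the losing
                        // Neurons did not contribute to the chosen action.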
                        if (axon->getPostNeuron()->getFiringRate() > 0)
                        {
                            mActivePolicyTDConnections.
                                addNewActiveConnection(axon);
                        }
                        break;
                    default:
                        assert(false);
                }
            }
        }
    }

    void RLModule::trainTDRule()
    {
        mActiveValueFunctionTDConnections.trainConnections(
            mValueFunctionLearningFactor * mTDError);
        mActivePolicyTDConnections.trainConnections(
            mValueFunctionLearningFactor * mPolicyLearningMultiplier *
            mTDError);
    }

    real RLModule::updateCriticOutput()
    {
        mCriticPopulation->updateFiringRatesLinear();
        return mCriticPopulation->getNeuron(0)->getFiringRate();
    }

    unsigned int RLModule::updateActorOutput()
    {
        mActorPopulation->updateFiringRatesWTA();
        int actionIndex = mActorPopulation->getActiveOutput();
        assert(-1 != actionIndex);
        return (unsigned int)actionIndex;
    }
}