/************************************************************************
* Verve                                                                 *
* Copyright (C) 2004-2006                                               *
* Tyler Streeter tylerstreeter@gmail.com                                *
* All rights reserved.                                                  *
* Web: http://verve-agents.sourceforge.net                              *
*                                                                       *
* This library is free software; you can redistribute it and/or         *
* modify it under the terms of EITHER:                                  *
*   (1) The GNU Lesser General Public License as published by the Free  *
*       Software Foundation; either version 2.1 of the License, or (at  *
*       your option) any later version. The text of the GNU Lesser      *
*       General Public License is included with this library in the     *
*       file license-LGPL.txt.                                          *
*   (2) The BSD-style license that is included with this library in     *
*       the file license-BSD.txt.                                       *
*                                                                       *
* This library is distributed in the hope that it will be useful,       *
* but WITHOUT ANY WARRANTY; without even the implied warranty of        *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the files    *
* license-LGPL.txt and license-BSD.txt for more details.                *
************************************************************************/

#include "Agent.h"
#include "Observation.h"
#include "PredictiveModel.h"
#include "RLModule.h"

namespace verve
{
	VERVE_EXPORT_FUNCTION Agent* VERVE_CALL createAgent(
		const AgentDescriptor& desc)
	{
		return new Agent(desc);
	}

	//VERVE_EXPORT_FUNCTION Agent* VERVE_CALL loadAgent(
	//	const std::string& filename)
	//{
	//	Agent* a = createAgent(1, 1);
	//	a->internal_load(filename);
	//	return a;
	//}

	Agent::Agent(const AgentDescriptor& desc)
	:	mDescriptor(desc)
	{
		mActualPrevObs.init(*this);
		mPredCurrentObs.init(*this);
		mTempPlanningObs.init(*this);

		mRLModule = NULL;
		mPredictiveModel = NULL;

		// We will always need this.
		mRLModule = new RLModule(mPredCurrentObs,
			desc.isDynamicRBFEnabled(), desc.getNumOutputs());

		switch(desc.getArchitecture())
		{
			case RL:
				// Do nothing extra.
				break;
			case MODEL_RL:
				mPredictiveModel = new PredictiveModel(mPredCurrentObs,
					desc.isDynamicRBFEnabled(), desc.getNumOutputs());
				break;
			case CURIOUS_MODEL_RL:
				mPredictiveModel = new PredictiveModel(mPredCurrentObs,
					desc.isDynamicRBFEnabled(), desc.getNumOutputs());
				break;
			default:
				assert(false);
				break;
		}

		mFirstStep = true;
		mActionIndex = 0;
		mLearningEnabled = true;
		mStepSize = defaults::stepSize;
		mAgeHours = 0;
		mAgeMinutes = 0;
		mAgeSeconds = 0;
		mLastPlanningSequenceLength = 0;

		// The following step size-dependent factors must be initialized
		// in the RLModule and PredictiveModel.
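		// (Each setter below forwards the current step size, mStepSize,
		// along with its time constant to the RLModule and, if present,
		// the PredictiveModel; presumably the modules convert these time
		// constants into per-step factors. See the setter definitions
		// near the end of this file.)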
		setETraceTimeConstant(defaults::eTraceTimeConstant);
		setTDDiscountTimeConstant(defaults::TDDiscountTimeConstant);
		setTDLearningRate(defaults::valueFunctionLearningTimeConstant,
			defaults::policyLearningMultiplier);
		setModelLearningRate(defaults::modelLearningTimeConstant);
	}

	Agent::~Agent()
	{
		delete mRLModule;
		if (mPredictiveModel)
		{
			delete mPredictiveModel;
		}
	}

	void Agent::destroy()
	{
		delete this;
	}

	void Agent::resetShortTermMemory()
	{
		mRLModule->resetShortTermMemory();
		if (mPredictiveModel)
		{
			mPredictiveModel->resetShortTermMemory();
		}

		mFirstStep = true;
		mActionIndex = 0;
		mActualPrevObs.zeroInputData();
		mPredCurrentObs.zeroInputData();
		mTempPlanningObs.zeroInputData();
		mLastPlanningSequenceLength = 0;
	}

	unsigned int Agent::update(real reinforcement, const Observation& obs,
		real dt)
	{
		assert(reinforcement >= -1 && reinforcement <= 1);

		if (dt != mStepSize)
		{
			// We only need to recompute step size-dependent things when
			// the dt changes. Thus, as long as the dt is constant between
			// successive updates, we rarely need to recompute these things.
			// It also keeps the API simple (instead of having a public
			// 'setStepSize(real)' and 'update()').
			setStepSize(dt);
		}

		if (mLearningEnabled)
		{
			incrementAge();

			if (mDescriptor.getArchitecture() == RL)
			{
				mActionIndex = mRLModule->update(obs, reinforcement);
			}
			else
			{
				// Do planning.

				// 2 PLANNING METHODS (currently, METHOD 1 is used):
				// 1. RL component always uses predicted observation
				//    and reward (i.e. predictive state
				//    representation).
				// 2. RL component only uses predicted obs and reward
				//    when in planning mode; otherwise, it uses the
				//    actual values.

				if (mFirstStep)
				{
					// We must handle the first step differently
					// because we do not yet have a valid previous-step
					// Observation. We do not update the model, and we
					// simply update the RLModule with the actual
					// Observation and reward.
					mActionIndex = mRLModule->update(obs, reinforcement);
				}
				else
				{
					// We train the predictive model once per update with
					// the current Observation and reward, then we train
					// the RLModule during planning. Planning sequences
					// proceed until the prediction uncertainty is too
					// high or we exceed the max planning sequence length.
					// These sequences might have zero length at first.
					// Whether we do planning or not, at the end we need
					// to have the RLModule's policy choose a new action.

					// Get the predicted current Observation and reward
					// from the model. This also trains the model based
					// on the actual data. At this point 'mActionIndex'
					// represents the action from the previous step.
					// Along with the predicted current Observation and
					// reward, this returns a prediction uncertainty
					// estimation.
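					// (predictAndTrain fills in mPredCurrentObs along with
					// the two locals declared below: the model's predicted
					// reward and an estimate of how uncertain that
					// prediction is.)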
					real predCurrentReward = 0;
					real predictionUncertainty = 0;
					mPredictiveModel->predictAndTrain(mActualPrevObs,
						mActionIndex, obs, reinforcement,
						mPredCurrentObs, predCurrentReward,
						predictionUncertainty);

					// Perform a planning sequence to train the RLModule.
					mLastPlanningSequenceLength = planningSequence(
						mPredCurrentObs, predCurrentReward,
						predictionUncertainty);

					// Have the RLModule's policy choose a new action from
					// the predicted current Observation.
					mActionIndex = mRLModule->updatePolicyOnly(
						mPredCurrentObs);
				}
			}
		}
		else
		{
			if (mDescriptor.getArchitecture() == RL)
			{
				mActionIndex = mRLModule->updatePolicyOnly(obs);
			}
			else
			{
				if (mFirstStep)
				{
					// We must handle the first step differently
					// because we do not yet have a valid previous-step
					// Observation. We do not update the model, and we
					// simply update the RLModule's policy with the
					// actual Observation.
					mActionIndex = mRLModule->updatePolicyOnly(obs);
				}
				else
				{
					// At this point 'mActionIndex' represents the
					// action from the previous step. Do not allow
					// dynamic RBF creation here. Since we're not
					// learning here, we just ignore the predicted
					// reward and uncertainty.
					real predCurrentReward = 0;
					real predictionUncertainty = 0;
					mPredictiveModel->predict(mActualPrevObs,
						mActionIndex, mPredCurrentObs,
						predCurrentReward, predictionUncertainty,
						false);
					mActionIndex = mRLModule->updatePolicyOnly(
						mPredCurrentObs);
				}
			}
		}

		if (mFirstStep)
		{
			mFirstStep = false;
		}

		// Store a copy of the current actual Observation for next time.
		mActualPrevObs.copyInputData(obs);

		return mActionIndex;
	}

	unsigned int Agent::planningSequence(const Observation& predCurrentObs,
		real predCurrentReward, real currentUncertainty)
	{
		unsigned int numPlanningSteps = 0;

		// Continue planning as long as uncertainty is low enough. This
		// uses a predictive state representation: all inputs to the
		// RLModule are predicted.
		const real uncertaintyThreshold =
			mDescriptor.getPlanningUncertaintyThreshold();
		if (currentUncertainty < uncertaintyThreshold)
		{
			// Make sure the RLModule is fresh; it should not have any STM
			// left over from previous updates at this point.
			mRLModule->resetShortTermMemory();

			// Set up temporary data to be used during planning.
			mTempPlanningObs.copyInputData(predCurrentObs);
			unsigned int tempActionIndex = 0;

			// We continue the planning sequence until either: 1) the
			// prediction uncertainty is too high, or 2) we exceed the
			// max plan length. Note that the RLModule will only learn
			// here if the number of planning steps is at least two. This
			// is because the RLModule must have data from two subsequent
			// steps.
			while (numPlanningSteps < mDescriptor.getMaxNumPlanningSteps())
			{
				real totalReward = 0;

				if (mDescriptor.getArchitecture() == CURIOUS_MODEL_RL)
				{
					// Add in curiosity rewards.

					// The current reward equals the predicted reward plus
					// an extra curiosity reward proportional to
					// uncertainty. Clamp the total reward to +1.
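					// (Equivalently: totalReward =
					// min(1, predCurrentReward + 3 * currentUncertainty);
					// the factor of 3 is a hard-coded curiosity gain.)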
					totalReward = predCurrentReward + 3 * currentUncertainty;
					if (totalReward > 1)
					{
						totalReward = 1;
					}
				}
				else
				{
					//totalReward = currentUncertainty;

					totalReward = predCurrentReward;
				}

				// Give the RLModule the current predicted Observation and
				// reward.
				tempActionIndex = mRLModule->update(mTempPlanningObs,
					totalReward);

				if (currentUncertainty > uncertaintyThreshold)
				{
					// Predicted uncertainty is too high, so we'll stop the
					// planning sequence.
					break;
				}

				// If this is not the last step, get new predicted data
				// from the model for the next step.
				if (numPlanningSteps < mDescriptor.getMaxNumPlanningSteps())
				{
					// We are allowing RBF creation here; this is probably
					// necessary for planning trajectories that enter new
					// territory.
					real predictionUncertainty = 0;
					mPredictiveModel->predict(mTempPlanningObs,
						tempActionIndex, mTempPlanningObs,
						predCurrentReward, predictionUncertainty, true);

					// Make the current uncertainty equal to the latest
					// prediction's uncertainty estimation.
					currentUncertainty = predictionUncertainty;

					// TODO: Instead of the previous line, try accumulating
					// uncertainty here. This hasn't been tested yet, but
					// it seems more realistic.
					//currentUncertainty += predictionUncertainty;
				}

				++numPlanningSteps;
			}
		}

		return numPlanningSteps;
	}

	//void Agent::internal_load(const std::string& filename)
	//{
	//	TiXmlDocument file;
	//	bool success = file.LoadFile(filename.c_str());
	//	if (!success)
	//	{
	//		VERVE_LOGGER("warning") <<
	//			"verve::Agent::load: Failed to load XML file "
	//			<< filename << "." << std::endl;
	//		return;
	//	}

	//	// Find the root element (i.e. the 'VerveAgent' element).
	//	TiXmlElement* rootElement = file.RootElement();
	//	if (NULL == rootElement)
	//	{
	//		VERVE_LOGGER("warning") <<
	//			"verve::Agent::load: Missing root element in "
	//			<< filename << ". Ignoring file." << std::endl;
	//		return;
	//	}

	//	// Load the Agent's age attributes.
	//	mAgeHours = globals::getAttributeInt(rootElement, "ageHours");
	//	mAgeMinutes = globals::getAttributeInt(rootElement, "ageMinutes");
	//	mAgeSeconds = globals::getAttributeReal(rootElement, "ageSeconds");

	//	if (!mBrain->load(rootElement))
	//	{
	//		VERVE_LOGGER("warning") <<
	//			"verve::Agent::load: Could not load file "
	//			<< filename << std::endl;
	//	}

	//	// Update: there are no longer different types of Agents.
	//	//// Make sure this is the right type of Agent.
	//	//std::string type = globals::getAttributeString(rootElement, "type");
	//	//if ("RL" != type)
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::load: Wrong Agent type found. Expected "
	//	//		<< "type RL but found " << type << ". Ignoring file "
	//	//		<< filename << "." << std::endl;
	//	//	return;
	//	//}

	//	//success = true;

	//	//// Check if the actor element exists.
	//	//TiXmlNode* actorNode = rootElement->FirstChild("Actor");
	//	//if (!actorNode)
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::load: Actor element not found in "
	//	//		<< filename << ". Ignoring file." << std::endl;
	//	//	success = false;
	//	//}

	//	//// Check if the positive critic element exists.
	//	//TiXmlNode* posCriticNode = rootElement->FirstChild("PositiveCritic");
	//	//if (!posCriticNode)
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::load: PositiveCritic element not found in "
	//	//		<< filename << ". Ignoring file." << std::endl;
	//	//	success = false;
	//	//}

	//	//// Check if the negative critic element exists.
	//	//TiXmlNode* negCriticNode = rootElement->FirstChild("NegativeCritic");
	//	//if (!negCriticNode)
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::load: NegativeCritic element not found in "
	//	//		<< filename << ". Ignoring file." << std::endl;
	//	//	success = false;
	//	//}

	//	//if (success)
	//	//{
	//	//	// Load the actor and critic NeuralArchitectures.

	//	//	success = mActor->load(actorNode,
	//	//		defaults::actorMembraneTimeConstant,
	//	//		defaults::actorLearningRatePercent,
	//	//		defaults::actorETraceDecayPercent,
	//	//		defaults::actorMaxExplorationNoise);

	//	//	if (!success)
	//	//	{
	//	//		VERVE_LOGGER("warning") <<
	//	//			"verve::Agent::load: Could not load file "
	//	//			<< filename << std::endl;
	//	//	}

	//	//	success = mPositiveCritic->load(posCriticNode,
	//	//		defaults::criticMembraneTimeConstant,
	//	//		defaults::criticLearningRatePercent,
	//	//		defaults::criticETraceDecayPercent, 0);
	//	//	if (!success)
	//	//	{
	//	//		VERVE_LOGGER("warning") <<
	//	//			"verve::Agent::load: Could not load file "
	//	//			<< filename << std::endl;
	//	//	}

	//	//	success = mNegativeCritic->load(negCriticNode,
	//	//		defaults::criticMembraneTimeConstant,
	//	//		defaults::criticLearningRatePercent,
	//	//		defaults::criticETraceDecayPercent, 0);
	//	//	if (!success)
	//	//	{
	//	//		VERVE_LOGGER("warning") <<
	//	//			"verve::Agent::load: Could not load file "
	//	//			<< filename << std::endl;
	//	//	}
	//	//}
	//}

	//void Agent::save(const std::string& filename)
	//{
	//	// Check if we need to auto generate a unique filename.
	//	std::string nameStr = filename;
	//	if (nameStr.empty())
	//	{
	//		static unsigned int count = 0;
	//		char newName[64];
	//		sprintf(newName, "agent%d_age-%dh.xml", count, (int)mAgeHours);
	//		nameStr = newName;
	//		++count;
	//	}

	//	TiXmlDocument file(nameStr.c_str());
	//
	//	// Add the XML declaration.
	//	TiXmlDeclaration declaration("1.0", "", "yes");
	//	file.InsertEndChild(declaration);
	//
	//	// Create the root element.
	//	TiXmlElement rootElement("VerveAgent");

	//	// Set the Agent's age attributes.
	//	rootElement.SetAttribute("ageHours", (int)mAgeHours);
	//	rootElement.SetAttribute("ageMinutes", (int)mAgeMinutes);
	//	rootElement.SetDoubleAttribute("ageSeconds", mAgeSeconds);

	//	// Update: there are no longer different types of Agents.
	//	//// Set the Agent's type attribute.
	//	//rootElement.SetAttribute("type", "RL");

	//	//// Create the actor element and add it to the root element.
	//	//TiXmlElement actorElement("Actor");
	//	//if (mActor->save(&actorElement))
	//	//{
	//	//	rootElement.InsertEndChild(actorElement);
	//	//}
	//	//else
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::save: Could not save file "
	//	//		<< nameStr << std::endl;
	//	//	return;
	//	//}

	//	//// Create the positive critic element and add it to the root element.
	//	//TiXmlElement posCriticElement("PositiveCritic");
	//	//if (mPositiveCritic->save(&posCriticElement))
	//	//{
	//	//	rootElement.InsertEndChild(posCriticElement);
	//	//}
	//	//else
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::save: Could not save file "
	//	//		<< nameStr << std::endl;
	//	//	return;
	//	//}

	//	//// Create the negative critic element and add it to the root element.
	//	//TiXmlElement negCriticElement("NegativeCritic");
	//	//if (mNegativeCritic->save(&negCriticElement))
	//	//{
	//	//	rootElement.InsertEndChild(negCriticElement);
	//	//}
	//	//else
	//	//{
	//	//	VERVE_LOGGER("warning") <<
	//	//		"verve::Agent::save: Could not save file "
	//	//		<< nameStr << std::endl;
	//	//	return;
	//	//}

	//	// Fill the root element.
	//	if (!mBrain->save(&rootElement))
	//	{
	//		VERVE_LOGGER("warning") <<
	//			"verve::Agent::save: Could not save file "
	//			<< nameStr << std::endl;
	//		return;
	//	}

	//	// Add the root element to the file.
	//	file.InsertEndChild(rootElement);
	//
	//	// Now save the document to a file.
	//	if (false == file.SaveFile())
	//	{
	//		VERVE_LOGGER("warning") <<
	//			"verve::Agent::save: Failed to save XML file "
	//			<< nameStr << "." << std::endl;
	//	}
	//}

	unsigned int Agent::getNumDiscreteSensors()const
	{
		return mDescriptor.getNumDiscreteSensors();
	}

	unsigned int Agent::getNumContinuousSensors()const
	{
		return mDescriptor.getNumContinuousSensors();
	}

	void Agent::setStepSize(real value)
	{
		mStepSize = value;

		// These components must be notified of the step size change.
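		// (changeStepSize presumably recomputes each module's internal
		// step size-dependent factors from its stored time constants.)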
		mRLModule->changeStepSize(value);
		if (mPredictiveModel)
		{
			mPredictiveModel->changeStepSize(value);
		}
	}

	void Agent::setETraceTimeConstant(real timeConstant)
	{
		mRLModule->setETraceTimeConstant(timeConstant, mStepSize);
	}

	void Agent::setTDDiscountTimeConstant(real timeConstant)
	{
		mRLModule->setTDDiscountTimeConstant(timeConstant, mStepSize);
	}

	void Agent::setTDLearningRate(real valueFunctionTimeConstant,
		real policyLearningMultiplier)
	{
		mRLModule->setTDLearningRate(valueFunctionTimeConstant,
			policyLearningMultiplier, mStepSize);
	}

	void Agent::setModelLearningRate(real timeConstant)
	{
		if (mPredictiveModel)
		{
			mPredictiveModel->setDeltaLearningRate(timeConstant, mStepSize);
		}
	}

	void Agent::setLearningEnabled(bool enabled)
	{
		mLearningEnabled = enabled;
	}

	real Agent::computeValueEstimation(const Observation& obs)
	{
		return mRLModule->computeValueEstimation(obs);
	}

	const AgentDescriptor* Agent::getDescriptor()const
	{
		return &mDescriptor;
	}

	long unsigned int Agent::getAge()const
	{
		return (3600 * mAgeHours + 60 * mAgeMinutes +
			(long unsigned int)mAgeSeconds);
	}

	std::string Agent::getAgeString()const
	{
		char temp[32];
		sprintf(temp, "%dh, %dm, %fs", mAgeHours, mAgeMinutes, mAgeSeconds);
		return std::string(temp);
	}

	real Agent::getTDError()const
	{
		return mRLModule->getTDError();
	}

	real Agent::getModelMSE()const
	{
		if (mPredictiveModel)
		{
			return mPredictiveModel->getPredictionMSE();
		}
		else
		{
			return 0;
		}
	}

	unsigned int Agent::getLastPlanLength()const
	{
		return mLastPlanningSequenceLength;
	}

	void Agent::saveValueData(unsigned int continuousResolution,
		const std::string& filename)
	{
		mRLModule->saveValueData(continuousResolution, filename);
	}

	void Agent::saveStateRBFData(const std::string& filename)
	{
		mRLModule->saveStateRBFData(filename);
	}

	void Agent::incrementAge()
	{
		mAgeSeconds += mStepSize;
		if (mAgeSeconds >= 60)
		{
			mAgeMinutes += 1;
			mAgeSeconds -= 60;

			if (60 == mAgeMinutes)
			{
				mAgeHours += 1;
				mAgeMinutes = 0;
			}
		}
	}
}
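
// Example usage sketch (not part of the library): a minimal sense-act loop
// written against the Agent interface defined above. The AgentDescriptor
// setup and the simulation loop shown here are assumptions; consult
// AgentDescriptor.h and Observation.h for the actual configuration calls.
//
//	verve::AgentDescriptor desc;          // assumed default-constructible
//	verve::Agent* agent = verve::createAgent(desc);
//	verve::Observation obs;
//	obs.init(*agent);
//	verve::real reward = 0;
//	verve::real dt = (verve::real)0.1;
//	while (simulationIsRunning())         // hypothetical loop condition
//	{
//		// ...fill obs with sensor data and compute a reward in [-1, 1]...
//		unsigned int action = agent->update(reward, obs, dt);
//		// ...apply 'action' to the environment...
//	}
//	agent->destroy();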