游戏开发论坛

 找回密码
 立即注册
搜索
查看: 1261|回复: 0

训练神经清单,把每个成员的作用都理解了 ,就O了

[复制链接]

3

主题

12

帖子

12

积分

新手上路

Rank: 1

积分
12
发表于 2009-1-8 12:23:00 | 显示全部楼层 |阅读模式
#pragma once

#include <vector>

using namespace std;

enum
{
    NLT_INPUT,
    NLT_HIDDEN,
    NLT_OUTPUT
};

enum
{
    ACT_LOGISTIC,
    ACT_BIPOLAR,
    ACT_STEP,
    ACT_TANH,
    ACT_SOFTMAX,
    ACT_LINEAR
};

struct Neuron
{
    vector<float> m_weights;   //connection strengths
    vector<float> m_lastDelta;//used for inertia in updating the weights while learning
    float m_output;           //the fired potential of the neuron
    float m_error;            //the error gradient of the potential from the expected
                              //potential; used when learning
};


class NLayer
{
public:
    NLayer(int nNeurons, int nInputs, int type = NLT_INPUT);
    void Propagate(int type, NLayer& nextLayer);
    void BackPropagate(int type, NLayer& nextLayer);
    void AdjustWeights(NLayer& inputs,float lrate = 0.1f, float momentum = 0.9f);
   
    //activation functions
    float ActLogistic(float value);
    float ActStep(float value);
    float ActTanh(float value);
    float ActBipolarSigmoid(float value);
    void  ActSoftmax(NLayer& outputs);

    //inverse functions for backprop
    float DerLogistic(float value);
    float DerTanh(float value);
    float DerBipolarSigmoid(float value);
   
    //data
    vector<Neuron*> m_neurons;
    int             m_type;
    float           m_threshold;
};


class NeuralNet
{
   
public:
   
    NeuralNet(int nIns,int nOuts,int nHiddenLays,int nNodesinHiddenLays);
    void Init();
   
    //access methods   
    void Use(vector<float> &inputs,vector<float> &outputs);
    void Train(vector<float> &inputs,vector<float> &outputs);
    float GetError()    {return m_error;}
    void WriteWeights();
    void ReadWeights();

protected:
    //internal methods
    void AddLayer(int nNeurons,int nInputs,int type);
    void SetInputs(vector<float>& inputs);
    void FindError(vector<float>& outputs);
    void Propagate();
    void BackPropagate();

    //data
    vector<NLayer>  m_layers;
    NLayer*         m_inputLayer;
    NLayer*         m_outputLayer;

    float           m_learningRate;
    float           m_momentum;
    float           m_error;

    int             m_nInputs;
    int             m_nOutputs;
    int             m_nLayers;
    int             m_nHiddenNodesperLayer;
    int             m_actType;
    int             m_outputActType;
};

[em1] [em1]
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

作品发布|文章投稿|广告合作|关于本站|游戏开发论坛 ( 闽ICP备17032699号-3 )

GMT+8, 2026-1-20 13:37

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表