#include <nnet-rbm.h>


Public Types
| Type | Member | Description |
|---|---|---|
| enum | RbmNodeType { Bernoulli, Gaussian } | |

Public Types inherited from Component
| Type | Member | Description |
|---|---|---|
| enum | ComponentType { kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform, kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax, kSigmoid, kTanh, kParametricRelu, kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice, kCopy, kTranspose, kBlockLinearity, kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent } | Component type identification mechanism. |

Public Member Functions
| Type | Member | Description |
|---|---|---|
| | RbmBase (int32 dim_in, int32 dim_out) | |
| virtual void | Reconstruct (const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)=0 | |
| virtual void | RbmUpdate (const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)=0 | |
| virtual RbmNodeType | VisType () const =0 | |
| virtual RbmNodeType | HidType () const =0 | |
| virtual void | WriteAsNnet (std::ostream &os, bool binary) const =0 | |
| void | SetRbmTrainOptions (const RbmTrainOptions &opts) | Set training hyper-parameters on the network and its UpdatableComponent(s). |
| const RbmTrainOptions & | GetRbmTrainOptions () const | Get training hyper-parameters from the network. |

Public Member Functions inherited from Component
| Type | Member | Description |
|---|---|---|
| | Component (int32 input_dim, int32 output_dim) | Generic interface of a component. |
| virtual | ~Component () | |
| virtual Component * | Copy () const =0 | Copy component (deep copy). |
| virtual ComponentType | GetType () const =0 | Get type identification of the component. |
| virtual bool | IsUpdatable () const | Check if component has 'Updatable' interface (trainable components). |
| virtual bool | IsMultistream () const | Check if component has 'Recurrent' interface (trainable and recurrent). |
| int32 | InputDim () const | Get the dimension of the input. |
| int32 | OutputDim () const | Get the dimension of the output. |
| void | Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) | Perform forward-pass propagation 'in' -> 'out'. |
| void | Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff) | Perform backward-pass propagation 'out_diff' -> 'in_diff'. |
| void | Write (std::ostream &os, bool binary) const | Write the component to a stream. |
| virtual std::string | Info () const | Print some additional info (after <ComponentName> and the dims). |
| virtual std::string | InfoGradient () const | Print some additional info about the gradient (after <...> and the dims). |

Protected Attributes
| Type | Member | Description |
|---|---|---|
| RbmTrainOptions | rbm_opts_ | |

Protected Attributes inherited from Component
| Type | Member | Description |
|---|---|---|
| int32 | input_dim_ | Dimension of the input of the Component. |
| int32 | output_dim_ | Dimension of the output of the Component. |

Private Member Functions
| Type | Member | Description |
|---|---|---|
| void | Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff) | |
| void | BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) | Backward pass transformation (to be implemented by the derived class). |

Additional Inherited Members

Static Public Member Functions inherited from Component
| Type | Member | Description |
|---|---|---|
| static const char * | TypeToMarker (ComponentType t) | Converts component type to marker. |
| static ComponentType | MarkerToType (const std::string &s) | Converts marker to component type (case insensitive). |
| static Component * | Init (const std::string &conf_line) | Initialize component from a line in a config file. |
| static Component * | Read (std::istream &is, bool binary) | Read the component from a stream (static method). |

Static Public Attributes inherited from Component
| Type | Member | Description |
|---|---|---|
| static const struct key_value | kMarkerMap [] | The table with pairs of Component types and markers (defined in nnet-component.cc). |

Protected Member Functions inherited from Component
| Type | Member | Description |
|---|---|---|
| virtual void | PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0 | Abstract interface for propagation/backpropagation. |
| virtual void | InitData (std::istream &is) | Virtual interface for initialization and I/O. |
| virtual void | ReadData (std::istream &is, bool binary) | Reads the component content. |
| virtual void | WriteData (std::ostream &os, bool binary) const | Writes the component content. |

Detailed Description
Definition at line 35 of file nnet-rbm.h.

Member Enumeration Documentation
enum RbmBase::RbmNodeType
| Enumerator | Description |
|---|---|
| Bernoulli | |
| Gaussian | |
Definition at line 37 of file nnet-rbm.h.
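
The enumerators name the two unit types a derived RBM can report through VisType() and HidType(). As a rough illustration, and assuming the usual RBM convention rather than anything stated on this page, the mean activation of a unit given its total input x is the logistic sigmoid for Bernoulli units and the identity for unit-variance Gaussian units. The enum and MeanActivation() below are hypothetical stand-ins, not Kaldi code:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for RbmBase::RbmNodeType (illustration only).
enum RbmNodeType { Bernoulli, Gaussian };

// Mean activation of one unit given its total input x (weighted sum + bias).
// Assumed convention: Bernoulli units -> logistic sigmoid,
// unit-variance Gaussian units -> identity.
double MeanActivation(RbmNodeType type, double x) {
  switch (type) {
    case Bernoulli:
      return 1.0 / (1.0 + std::exp(-x));
    case Gaussian:
      return x;  // the mean of a unit-variance Gaussian unit is its input
  }
  assert(false && "unknown RbmNodeType");
  return 0.0;
}
```

For example, MeanActivation(Bernoulli, 0.0) evaluates to 0.5, while MeanActivation(Gaussian, 1.5) returns 1.5 unchanged.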

Constructor & Destructor Documentation
RbmBase::RbmBase (int32 dim_in, int32 dim_out)
Definition at line 42 of file nnet-rbm.h.
References RbmBase::HidType(), RbmBase::RbmUpdate(), RbmBase::Reconstruct(), RbmBase::VisType(), and RbmBase::WriteAsNnet().
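
Member Function Documentation
void RbmBase::Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff) [inline, private]
Definition at line 81 of file nnet-rbm.h.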
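
void RbmBase::BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) [inline, private, virtual]
Backward pass transformation (to be implemented by the derived class).
Implements Component.
Definition at line 86 of file nnet-rbm.h.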
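
RbmBase redeclares Backpropagate() and BackpropagateFnc() as private, even though Component::Backpropagate() is public. In C++ this hides the names only for code that calls them through the RBM's own static type; a call through a base-class reference still compiles, and virtual dispatch still reaches the derived override. The sketch below uses hypothetical toy classes (ToyComponent, ToyRbm), not Kaldi's, to show that behavior; it makes no claim about what RbmBase's private bodies actually do:

```cpp
#include <cassert>
#include <string>

// Toy classes (hypothetical) mimicking the access pattern on this page: the
// base exposes a public non-virtual Backpropagate() that forwards to a
// protected virtual BackpropagateFnc(); the RBM-like subclass redeclares both
// as private.
class ToyComponent {
 public:
  virtual ~ToyComponent() {}
  std::string Backpropagate() { return BackpropagateFnc(); }

 protected:
  virtual std::string BackpropagateFnc() = 0;
};

class ToyRbm : public ToyComponent {
 private:
  // Hides ToyComponent::Backpropagate for callers using a ToyRbm directly:
  // `ToyRbm r; r.Backpropagate();` is rejected by the compiler.
  std::string Backpropagate() { return "hidden"; }

  // Still overrides the virtual: access is checked at compile time against
  // the static type of the call, but dispatch happens at run time.
  std::string BackpropagateFnc() override { return "no-op"; }
};
```

For instance, with `ToyRbm rbm; ToyComponent &base = rbm;`, the call `base.Backpropagate()` compiles and returns "no-op", while `rbm.Backpropagate()` does not compile.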

const RbmTrainOptions & RbmBase::GetRbmTrainOptions () const [inline]
Get training hyper-parameters from the network.
Definition at line 71 of file nnet-rbm.h.
References RbmBase::rbm_opts_.

virtual RbmNodeType RbmBase::HidType () const [pure virtual]
Implemented in Rbm.
Referenced by RbmBase::RbmBase() and Rbm::WriteAsNnet().
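
virtual void RbmBase::RbmUpdate (const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid) [pure virtual]
Implemented in Rbm.
Referenced by RbmBase::RbmBase().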
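
The argument names of RbmUpdate() suggest the positive (data-driven) and negative (reconstruction-driven) phases of contrastive divergence. A minimal sketch of such an update is below. It assumes plain row-major matrices, a hid_dim x vis_dim weight matrix, and a bare learning-rate step without the momentum or weight penalty a real trainer would likely add. Mat and CdUpdate() are hypothetical names, and this is not presented as Rbm's actual implementation:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix: rows = frames, cols = units.
struct Mat {
  std::size_t rows, cols;
  std::vector<double> data;
  double at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Plain CD-style step (no momentum, no L2 penalty), assuming pos_* are the
// data-driven statistics and neg_* come from the reconstruction.
// w is hid_dim x vis_dim, row-major; statistics are averaged over frames.
void CdUpdate(const Mat &pos_vis, const Mat &pos_hid,
              const Mat &neg_vis, const Mat &neg_hid, double learn_rate,
              std::vector<double> *w, std::vector<double> *vis_bias,
              std::vector<double> *hid_bias) {
  const std::size_t n = pos_vis.rows, v = pos_vis.cols, h = pos_hid.cols;
  assert(pos_hid.rows == n && neg_vis.rows == n && neg_hid.rows == n);
  assert(neg_vis.cols == v && neg_hid.cols == h);
  assert(w->size() == h * v && vis_bias->size() == v && hid_bias->size() == h);
  const double scale = learn_rate / static_cast<double>(n);
  for (std::size_t r = 0; r < n; ++r) {
    for (std::size_t j = 0; j < h; ++j) {
      // Weight gradient: <h v^T>_pos - <h v^T>_neg.
      for (std::size_t i = 0; i < v; ++i) {
        (*w)[j * v + i] += scale * (pos_hid.at(r, j) * pos_vis.at(r, i) -
                                    neg_hid.at(r, j) * neg_vis.at(r, i));
      }
      (*hid_bias)[j] += scale * (pos_hid.at(r, j) - neg_hid.at(r, j));
    }
    for (std::size_t i = 0; i < v; ++i) {
      (*vis_bias)[i] += scale * (pos_vis.at(r, i) - neg_vis.at(r, i));
    }
  }
}
```

With one frame, pos_vis = [1, 0], pos_hid = [1], neg_vis = [0, 1], neg_hid = [0.5] and a learning rate of 0.5, the weights move from [0, 0] to [0.5, -0.25].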
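
virtual void RbmBase::Reconstruct (const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs) [pure virtual]
Implemented in Rbm.
Referenced by RbmBase::RbmBase().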
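
Reconstruct() maps hidden states back to visible-layer values (the output argument is named vis_probs). Assuming a row-major hid_dim x vis_dim weight matrix W and a visible bias, one plausible reconstruction computes hid_state * W + vis_bias and then applies the activation for the visible unit type. ReconstructSketch() and VisUnits are hypothetical names used only for this sketch:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical visible-unit type selector.
enum VisUnits { kBernoulliVis, kGaussianVis };

// hid_state: num_frames x hid_dim, row-major.
// w: hid_dim x vis_dim, row-major. Returns num_frames x vis_dim, row-major.
std::vector<double> ReconstructSketch(const std::vector<double> &hid_state,
                                      std::size_t num_frames,
                                      const std::vector<double> &w,
                                      const std::vector<double> &vis_bias,
                                      VisUnits vis_type) {
  const std::size_t v = vis_bias.size();
  assert(num_frames > 0 && hid_state.size() % num_frames == 0);
  const std::size_t h = hid_state.size() / num_frames;
  assert(w.size() == h * v);
  std::vector<double> vis(num_frames * v);
  for (std::size_t r = 0; r < num_frames; ++r) {
    for (std::size_t i = 0; i < v; ++i) {
      double x = vis_bias[i];  // linear input to visible unit i
      for (std::size_t j = 0; j < h; ++j) {
        x += hid_state[r * h + j] * w[j * v + i];
      }
      // Bernoulli: sigmoid gives a probability; Gaussian: keep the mean.
      vis[r * v + i] =
          (vis_type == kBernoulliVis) ? 1.0 / (1.0 + std::exp(-x)) : x;
    }
  }
  return vis;
}
```

For a single frame with hid_state = [1, 0], W = [[1, 2], [3, 4]] and vis_bias = [0.5, -2], the Gaussian case returns [1.5, 0] and the Bernoulli case returns [sigmoid(1.5), 0.5].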

void RbmBase::SetRbmTrainOptions (const RbmTrainOptions &opts) [inline]
Set training hyper-parameters on the network and its UpdatableComponent(s).
Definition at line 67 of file nnet-rbm.h.
References RbmBase::rbm_opts_.

virtual RbmNodeType RbmBase::VisType () const [pure virtual]
Implemented in Rbm.
Referenced by RbmBase::RbmBase().

virtual void RbmBase::WriteAsNnet (std::ostream &os, bool binary) const [pure virtual]
Implemented in Rbm.
Referenced by RbmBase::RbmBase().
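
This page records that Rbm::WriteAsNnet() references HidType(). One plausible reading, stated here only as an assumption, is that the export writes the RBM as an affine layer and appends a sigmoid only when the hidden units are Bernoulli. The sketch reduces that decision to a list of component type names taken from the ComponentType enumerators (kAffineTransform, kSigmoid); ExportPlan() is hypothetical and writes nothing to a stream:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for RbmBase::RbmNodeType.
enum RbmNodeType { Bernoulli, Gaussian };

// Hypothetical export plan: the component types an RBM might be written as.
// Assumption: weights and hidden bias become one affine layer, and a sigmoid
// follows only for Bernoulli hidden units (Gaussian hidden units stay linear).
std::vector<std::string> ExportPlan(RbmNodeType hid_type) {
  assert(hid_type == Bernoulli || hid_type == Gaussian);
  std::vector<std::string> plan = {"AffineTransform"};
  if (hid_type == Bernoulli) {
    plan.push_back("Sigmoid");
  }
  return plan;
}
```

Under that assumption, ExportPlan(Bernoulli) yields {"AffineTransform", "Sigmoid"} and ExportPlan(Gaussian) yields {"AffineTransform"}.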

Member Data Documentation
RbmTrainOptions RbmBase::rbm_opts_ [protected]
Definition at line 76 of file nnet-rbm.h.
Referenced by RbmBase::GetRbmTrainOptions(), Rbm::RbmUpdate(), and RbmBase::SetRbmTrainOptions().
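
Because rbm_opts_ is protected and is read by Rbm::RbmUpdate(), hyper-parameters stored through SetRbmTrainOptions() reach the derived update without extra plumbing. The toy classes below sketch that flow. The option fields (learn_rate, momentum) and the momentum rule are assumptions for illustration, since this page does not list the members of RbmTrainOptions; ToyTrainOptions, ToyRbmBase and ToyRbm are hypothetical:

```cpp
#include <cassert>

// Hypothetical subset of training options (the real RbmTrainOptions
// members are not listed on this page).
struct ToyTrainOptions {
  double learn_rate = 0.4;
  double momentum = 0.5;
};

// Base stores options in a protected member; Set/Get are public accessors.
class ToyRbmBase {
 public:
  virtual ~ToyRbmBase() {}
  void SetTrainOptions(const ToyTrainOptions &opts) {
    assert(opts.learn_rate > 0.0);
    assert(opts.momentum >= 0.0 && opts.momentum < 1.0);
    opts_ = opts;
  }
  const ToyTrainOptions &GetTrainOptions() const { return opts_; }
  virtual double Update(double grad) = 0;  // returns the applied step

 protected:
  ToyTrainOptions opts_;
};

class ToyRbm : public ToyRbmBase {
 public:
  // Momentum-smoothed step read straight from the protected options:
  // delta = momentum * delta_prev + learn_rate * grad.
  double Update(double grad) override {
    delta_ = opts_.momentum * delta_ + opts_.learn_rate * grad;
    return delta_;
  }

 private:
  double delta_ = 0.0;
};
```

With learn_rate = 0.5 and momentum = 0.5, two calls to Update(1.0) return 0.5 and then 0.75.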