nnet-am-decodable-simple.h
Go to the documentation of this file.
1 // nnet3/nnet-am-decodable-simple.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_AM_DECODABLE_SIMPLE_H_
21 #define KALDI_NNET3_NNET_AM_DECODABLE_SIMPLE_H_
22 
23 #include <vector>
24 #include "base/kaldi-common.h"
25 #include "gmm/am-diag-gmm.h"
26 #include "hmm/transition-model.h"
27 #include "itf/decodable-itf.h"
28 #include "nnet3/nnet-optimize.h"
29 #include "nnet3/nnet-compute.h"
30 #include "nnet3/am-nnet-simple.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
35 
36 // See also the decodable object in decodable-simple-looped.h, which is better
37 // and faster in most situations, including TDNNs and LSTMs (but not for
38 // BLSTMs).
39 
40 
41 // Note: the 'simple' in the name means it applies to networks
42 // for which IsSimpleNnet(nnet) would return true.
55 
57  extra_left_context(0),
58  extra_right_context(0),
59  extra_left_context_initial(-1),
60  extra_right_context_final(-1),
61  frame_subsampling_factor(1),
62  frames_per_chunk(50),
63  acoustic_scale(0.1),
64  debug_computation(false) {
65  compiler_config.cache_capacity += frames_per_chunk;
66  }
67 
68  void Register(OptionsItf *opts) {
69  opts->Register("extra-left-context", &extra_left_context,
70  "Number of frames of additional left-context to add on top "
71  "of the neural net's inherent left context (may be useful in "
72  "recurrent setups");
73  opts->Register("extra-right-context", &extra_right_context,
74  "Number of frames of additional right-context to add on top "
75  "of the neural net's inherent right context (may be useful in "
76  "recurrent setups");
77  opts->Register("extra-left-context-initial", &extra_left_context_initial,
78  "If >= 0, overrides the --extra-left-context value at the "
79  "start of an utterance.");
80  opts->Register("extra-right-context-final", &extra_right_context_final,
81  "If >= 0, overrides the --extra-right-context value at the "
82  "end of an utterance.");
83  opts->Register("frame-subsampling-factor", &frame_subsampling_factor,
84  "Required if the frame-rate of the output (e.g. in 'chain' "
85  "models) is less than the frame-rate of the original "
86  "alignment.");
87  opts->Register("acoustic-scale", &acoustic_scale,
88  "Scaling factor for acoustic log-likelihoods (caution: is a no-op "
89  "if set in the program nnet3-compute");
90  opts->Register("frames-per-chunk", &frames_per_chunk,
91  "Number of frames in each chunk that is separately evaluated "
92  "by the neural net. Measured before any subsampling, if the "
93  "--frame-subsampling-factor options is used (i.e. counts "
94  "input frames");
95  opts->Register("debug-computation", &debug_computation, "If true, turn on "
96  "debug for the actual computation (very verbose!)");
97 
98  // register the optimization options with the prefix "optimization".
99  ParseOptions optimization_opts("optimization", opts);
100  optimize_config.Register(&optimization_opts);
101 
102  // register the compute options with the prefix "computation".
103  ParseOptions compute_opts("computation", opts);
104  compute_config.Register(&compute_opts);
105  }
106 
107  void CheckAndFixConfigs(int32 nnet_modulus) {
108  static bool warned_frames_per_chunk = false;
109  if (frame_subsampling_factor < 1 || frames_per_chunk < 1) {
110  KALDI_ERR << "--frame-subsampling-factor and "
111  << "--frames-per-chunk must be > 0";
112  }
113  KALDI_ASSERT(nnet_modulus > 0);
114  int32 n = Lcm(frame_subsampling_factor, nnet_modulus);
115 
116  if (frames_per_chunk % n != 0) {
117  // round up to the nearest multiple of n.
118  int32 new_frames_per_chunk = n * ((frames_per_chunk + n - 1) / n);
119  if (!warned_frames_per_chunk) {
120  warned_frames_per_chunk = true;
121  if (nnet_modulus == 1) {
122  // simpler error message.
123  KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk
124  << " to " << new_frames_per_chunk
125  << " to make it a multiple of "
126  << "--frame-subsampling-factor="
128  } else {
129  KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk
130  << " to " << new_frames_per_chunk << " due to "
131  << "--frame-subsampling-factor=" << frame_subsampling_factor
132  << " and "
133  << "nnet shift-invariance modulus = " << nnet_modulus;
134  }
135  }
136  frames_per_chunk = new_frames_per_chunk;
137  }
138  }
139 };
140 
141 /*
142  This class handles the neural net computation; it's mostly accessed
143  via other wrapper classes.
144 
145  Note: this class used to be called NnetDecodableBase.
146 
147  It can accept just input features, or input features plus iVectors. */
149  public:
182  const Nnet &nnet,
183  const VectorBase<BaseFloat> &priors,
184  const MatrixBase<BaseFloat> &feats,
185  CachingOptimizingCompiler *compiler,
186  const VectorBase<BaseFloat> *ivector = NULL,
187  const MatrixBase<BaseFloat> *online_ivectors = NULL,
188  int32 online_ivector_period = 1);
189 
190 
191  // returns the number of frames of likelihoods. The same as feats_.NumRows()
192  // in the normal case (but may be less if opts_.frame_subsampling_factor !=
193  // 1).
194  inline int32 NumFrames() const { return num_subsampled_frames_; }
195 
196  inline int32 OutputDim() const { return output_dim_; }
197 
198  // Gets the output for a particular frame, with 0 <= frame < NumFrames().
199  // 'output' must be correctly sized (with dimension OutputDim()).
200  void GetOutputForFrame(int32 frame, VectorBase<BaseFloat> *output);
201 
202  // Gets the output for a particular frame and pdf_id, with
203  // 0 <= subsampled_frame < NumFrames(),
204  // and 0 <= pdf_id < OutputDim().
205  inline BaseFloat GetOutput(int32 subsampled_frame, int32 pdf_id) {
206  if (subsampled_frame < current_log_post_subsampled_offset_ ||
207  subsampled_frame >= current_log_post_subsampled_offset_ +
208  current_log_post_.NumRows())
209  EnsureFrameIsComputed(subsampled_frame);
210  return current_log_post_(subsampled_frame -
211  current_log_post_subsampled_offset_,
212  pdf_id);
213  }
214  private:
216 
217  // This call is made to ensure that we have the log-probs for this frame
218  // cached in current_log_post_.
// 'subsampled_frame' uses the same subsampled units as GetOutput(), i.e.
// 0 <= subsampled_frame < NumFrames(); GetOutput() calls this when the frame
// is outside the range currently held in current_log_post_.
219  void EnsureFrameIsComputed(int32 subsampled_frame);
220 
221  // This function does the actual nnet computation; it is called from
222  // EnsureFrameIsComputed. Any padding at file start/end is done by
223  // the caller of this function (so the input should exceed the output
224  // by a suitable amount of context). It puts its output in current_log_post_.
// NOTE(review): presumably input_t_start is the t value of the first row of
// input_feats, and output_t_start / num_subsampled_frames describe the
// requested output range; 'ivector' is presumably empty when iVectors are not
// in use -- confirm against the implementation.
225  void DoNnetComputation(int32 input_t_start,
226  const MatrixBase<BaseFloat> &input_feats,
227  const VectorBase<BaseFloat> &ivector,
228  int32 output_t_start,
229  int32 num_subsampled_frames);
230 
231  // Gets the iVector that will be used for this chunk of frames, if we are
232  // using iVectors (else does nothing). Note: num_output_frames is
233  // interpreted as a number of t values, which in the subsampled case is not
234  // the same as the number of subsampled frames (it would be larger by a
235  // factor of opts_.frame_subsampling_factor).
236  void GetCurrentIvector(int32 output_t_start,
237  int32 num_output_frames,
238  Vector<BaseFloat> *ivector);
239 
240  // called from constructor
// This takes no arguments, unlike the options struct's
// CheckAndFixConfigs(int32 nnet_modulus). NOTE(review): presumably it forwards
// to that one with the nnet's modulus -- confirm in the .cc file.
241  void CheckAndFixConfigs();
242 
243  // returns dimension of the provided iVectors if supplied, or 0 otherwise.
244  int32 GetIvectorDim() const;
245 
247  const Nnet &nnet_;
251  // the log priors (or the empty vector if the priors are not set in the model)
254  // note: num_subsampled_frames_ will equal feats_.NumRows() in the normal case
255  // when opts_.frame_subsampling_factor == 1.
257 
258  // ivector_ is the iVector if we're using iVectors that are estimated in batch
259  // mode.
261 
262  // online_ivector_feats_ is the iVectors if we're using online-estimated ones.
264  // online_ivector_period_ helps us interpret online_ivector_feats_; it's the
265  // number of frames by which consecutive rows of online_ivector_feats_ are separated.
267 
268  // a reference to a compiler passed in via the constructor, which may be
269  // declared at the top level of the program so that we don't have to recompile
270  // computations each time.
272 
273  // The current log-posteriors that we got from the last time we
274  // ran the computation.
276  // The time-offset of the current log-posteriors. Note: if
277  // opts_.frame_subsampling_factor > 1, this will be measured in subsampled
278  // frames.
280 };
281 
283  public:
319  const TransitionModel &trans_model,
320  const AmNnetSimple &am_nnet,
321  const MatrixBase<BaseFloat> &feats,
322  const VectorBase<BaseFloat> *ivector = NULL,
323  const MatrixBase<BaseFloat> *online_ivectors = NULL,
324  int32 online_ivector_period = 1,
325  CachingOptimizingCompiler *compiler = NULL);
326 
327 
328  virtual BaseFloat LogLikelihood(int32 frame, int32 transition_id);
329 
330  virtual inline int32 NumFramesReady() const {
331  return decodable_nnet_.NumFrames();
332  }
333 
334  virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
335 
336  virtual bool IsLastFrame(int32 frame) const {
337  KALDI_ASSERT(frame < NumFramesReady());
338  return (frame == NumFramesReady() - 1);
339  }
340 
341  private:
343  // This compiler object is only used if the 'compiler'
344  // argument to the constructor is NULL.
348 };
349 
350 
352  public:
390  const NnetSimpleComputationOptions &opts,
391  const TransitionModel &trans_model,
392  const AmNnetSimple &am_nnet,
393  const MatrixBase<BaseFloat> &feats,
394  const VectorBase<BaseFloat> *ivector = NULL,
395  const MatrixBase<BaseFloat> *online_ivectors = NULL,
396  int32 online_ivector_period = 1);
397 
398 
399  virtual BaseFloat LogLikelihood(int32 frame, int32 transition_id);
400 
401  virtual inline int32 NumFramesReady() const {
402  return decodable_nnet_->NumFrames();
403  }
404 
405  virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
406 
407  virtual bool IsLastFrame(int32 frame) const {
408  KALDI_ASSERT(frame < NumFramesReady());
409  return (frame == NumFramesReady() - 1);
410  }
411 
413  private:
415  void DeletePointers();
416 
419 
423 
425 };
426 
427 
428 
429 } // namespace nnet3
430 } // namespace kaldi
431 
432 #endif // KALDI_NNET3_NNET_AM_DECODABLE_SIMPLE_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: nnet-optimize.h:84
virtual int32 NumIndices() const
Returns the number of states in the acoustic model (they will be indexed one-based, i.e.
const MatrixBase< BaseFloat > * online_ivector_feats_
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
kaldi::int32 int32
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
virtual bool IsLastFrame(int32 frame) const
Returns true if this is the last frame.
CachingOptimizingCompilerOptions compiler_config
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
I Lcm(I m, I n)
Returns the least common multiple of two integers.
Definition: kaldi-math.h:318
const MatrixBase< BaseFloat > & feats_
const VectorBase< BaseFloat > * ivector_
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
virtual int32 NumIndices() const
Returns the number of states in the acoustic model (they will be indexed one-based, i.e.
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:147
virtual bool IsLastFrame(int32 frame) const
Returns true if this is the last frame.
void Register(OptionsItf *opts)
Definition: nnet-compute.h:42
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
BaseFloat GetOutput(int32 subsampled_frame, int32 pdf_id)