nnet3-latgen-faster-batch.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-latgen-faster-batch.cc
2 
3 // Copyright 2012-2016 Johns Hopkins University (author: Daniel Povey)
4 // 2014 Guoguo Chen
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 
#include "base/timer.h"
#include "base/kaldi-common.h"
#include "decoder/decoder-wrappers.h"
#include "fstext/fstext-lib.h"
#include "hmm/transition-model.h"
#include "nnet3/nnet-batch-compute.h"
#include "nnet3/nnet-utils.h"
#include "util/kaldi-thread.h"
#include "tree/context-dep.h"
#include "util/common-utils.h"
32 
33 namespace kaldi {
34 
35 void HandleOutput(bool determinize,
36  const fst::SymbolTable *word_syms,
37  nnet3::NnetBatchDecoder *decoder,
38  CompactLatticeWriter *clat_writer,
39  LatticeWriter *lat_writer) {
40  // Write out any lattices that are ready.
41  std::string output_utterance_id, sentence;
42  if (determinize) {
43  CompactLattice clat;
44  while (decoder->GetOutput(&output_utterance_id, &clat, &sentence)) {
45  if (word_syms != NULL)
46  std::cerr << output_utterance_id << ' ' << sentence << '\n';
47  clat_writer->Write(output_utterance_id, clat);
48  }
49  } else {
50  Lattice lat;
51  while (decoder->GetOutput(&output_utterance_id, &lat, &sentence)) {
52  if (word_syms != NULL)
53  std::cerr << output_utterance_id << ' ' << sentence << '\n';
54  lat_writer->Write(output_utterance_id, lat);
55  }
56  }
57 }
58 
59 } // namespace kaldi
60 
61 int main(int argc, char *argv[]) {
62  // note: making this program work with GPUs is as simple as initializing the
63  // device, but it probably won't make a huge difference in speed for typical
64  // setups.
65  try {
66  using namespace kaldi;
67  using namespace kaldi::nnet3;
68  typedef kaldi::int32 int32;
69  using fst::SymbolTable;
70  using fst::Fst;
71  using fst::StdArc;
72 
73  const char *usage =
74  "Generate lattices using nnet3 neural net model. This version is optimized\n"
75  "for GPU-based inference.\n"
76  "Usage: nnet3-latgen-faster-batch [options] <nnet-in> <fst-in> <features-rspecifier>"
77  " <lattice-wspecifier>\n";
78  ParseOptions po(usage);
79 
80  bool allow_partial = false;
81  LatticeFasterDecoderConfig decoder_opts;
82  NnetBatchComputerOptions compute_opts;
83  std::string use_gpu = "yes";
84 
85  std::string word_syms_filename;
86  std::string ivector_rspecifier,
87  online_ivector_rspecifier,
88  utt2spk_rspecifier;
89  int32 online_ivector_period = 0, num_threads = 1;
90  decoder_opts.Register(&po);
91  compute_opts.Register(&po);
92  po.Register("word-symbol-table", &word_syms_filename,
93  "Symbol table for words [for debug output]");
94  po.Register("allow-partial", &allow_partial,
95  "If true, produce output even if end state was not reached.");
96  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
97  "iVectors as vectors (i.e. not estimated online); per utterance "
98  "by default, or per speaker if you provide the --utt2spk option.");
99  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
100  "iVectors estimated online, as matrices. If you supply this,"
101  " you must set the --online-ivector-period option.");
102  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
103  "between iVectors in matrices supplied to the --online-ivectors "
104  "option");
105  po.Register("num-threads", &num_threads, "Number of decoder (i.e. "
106  "graph-search) threads. The number of model-evaluation threads "
107  "is always 1; this is optimized for use with the GPU.");
108  po.Register("use-gpu", &use_gpu,
109  "yes|no|optional|wait, only has effect if compiled with CUDA");
110 
111 #if HAVE_CUDA==1
112  CuDevice::RegisterDeviceOptions(&po);
113 #endif
114 
115  po.Read(argc, argv);
116 
117  if (po.NumArgs() != 4) {
118  po.PrintUsage();
119  exit(1);
120  }
121 
122 #if HAVE_CUDA==1
123  CuDevice::Instantiate().AllowMultithreading();
124  CuDevice::Instantiate().SelectGpuId(use_gpu);
125 #endif
126 
127  std::string model_in_rxfilename = po.GetArg(1),
128  fst_in_rxfilename = po.GetArg(2),
129  feature_rspecifier = po.GetArg(3),
130  lattice_wspecifier = po.GetArg(4);
131 
132  TransitionModel trans_model;
133  AmNnetSimple am_nnet;
134  {
135  bool binary;
136  Input ki(model_in_rxfilename, &binary);
137  trans_model.Read(ki.Stream(), binary);
138  am_nnet.Read(ki.Stream(), binary);
139  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
140  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
141  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
142  }
143 
144  bool determinize = decoder_opts.determinize_lattice;
145  CompactLatticeWriter compact_lattice_writer;
146  LatticeWriter lattice_writer;
147  if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
148  : lattice_writer.Open(lattice_wspecifier)))
149  KALDI_ERR << "Could not open table for writing lattices: "
150  << lattice_wspecifier;
151 
152  RandomAccessBaseFloatMatrixReader online_ivector_reader(
153  online_ivector_rspecifier);
155  ivector_rspecifier, utt2spk_rspecifier);
156 
157  fst::SymbolTable *word_syms = NULL;
158  if (word_syms_filename != "")
159  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
160  KALDI_ERR << "Could not read symbol table from file "
161  << word_syms_filename;
162 
163 
164  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
165 
166  Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_rxfilename);
167 
168  int32 num_success;
169  {
170  NnetBatchComputer computer(compute_opts, am_nnet.GetNnet(),
171  am_nnet.Priors());
172  NnetBatchDecoder decoder(*decode_fst, decoder_opts,
173  trans_model, word_syms, allow_partial,
174  num_threads, &computer);
175 
176  for (; !feature_reader.Done(); feature_reader.Next()) {
177  std::string utt = feature_reader.Key();
178  const Matrix<BaseFloat> &features (feature_reader.Value());
179 
180  if (features.NumRows() == 0) {
181  KALDI_WARN << "Zero-length utterance: " << utt;
182  decoder.UtteranceFailed();
183  continue;
184  }
185  const Matrix<BaseFloat> *online_ivectors = NULL;
186  const Vector<BaseFloat> *ivector = NULL;
187  if (!ivector_rspecifier.empty()) {
188  if (!ivector_reader.HasKey(utt)) {
189  KALDI_WARN << "No iVector available for utterance " << utt;
190  decoder.UtteranceFailed();
191  continue;
192  } else {
193  ivector = &ivector_reader.Value(utt);
194  }
195  }
196  if (!online_ivector_rspecifier.empty()) {
197  if (!online_ivector_reader.HasKey(utt)) {
198  KALDI_WARN << "No online iVector available for utterance " << utt;
199  decoder.UtteranceFailed();
200  continue;
201  } else {
202  online_ivectors = &online_ivector_reader.Value(utt);
203  }
204  }
205 
206  decoder.AcceptInput(utt, features, ivector, online_ivectors,
207  online_ivector_period);
208 
209  HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder,
210  &compact_lattice_writer, &lattice_writer);
211  }
212  num_success = decoder.Finished();
213  HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder,
214  &compact_lattice_writer, &lattice_writer);
215 
216  // At this point the decoder and batch-computer objects will print
217  // diagnostics from their destructors (they are going out of scope).
218  }
219  delete decode_fst;
220  delete word_syms;
221 
222 #if HAVE_CUDA==1
223  CuDevice::Instantiate().PrintProfile();
224 #endif
225 
226  return (num_success != 0 ? 0 : 1);
227  } catch(const std::exception &e) {
228  std::cerr << e.what();
229  return -1;
230  }
231 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Decoder object that uses multiple CPU threads for the graph search, plus a GPU for the neural net inf...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that suitable to be done in test time...
Definition: nnet-utils.cc:2100
bool Open(const std::string &wspecifier)
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
void HandleOutput(bool determinize, const fst::SymbolTable *word_syms, nnet3::NnetBatchDecoder *decoder, CompactLatticeWriter *clat_writer, LatticeWriter *lat_writer)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Write(const std::string &key, const T &value) const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
const T & Value(const std::string &key)
int main(int argc, char *argv[])
This class does neural net inference in a way that is optimized for GPU use: it combines chunks of mu...
bool GetOutput(std::string *utterance_id, CompactLattice *clat, std::string *sentence)
The user should call this to obtain output (This version should only be called if config...
Config class for the CollapseModel function.
Definition: nnet-utils.h:240