NnetComputeProb Class Reference

This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.

#include <nnet-diagnostics.h>


Public Member Functions

 NnetComputeProb (const NnetComputeProbOptions &config, const Nnet &nnet)
 
 NnetComputeProb (const NnetComputeProbOptions &config, Nnet *nnet)
 
void Reset ()
 
void Compute (const NnetExample &eg)
 
bool PrintTotalStats () const
 
const SimpleObjectiveInfo * GetObjective (const std::string &output_name) const
 
double GetTotalObjective (double *tot_weight) const
 
const Nnet & GetDeriv () const
 
 ~NnetComputeProb ()
 

Private Member Functions

void ProcessOutputs (const NnetExample &eg, NnetComputer *computer)
 

Private Attributes

NnetComputeProbOptions config_
 
const Nnet & nnet_
 
bool deriv_nnet_owned_
 
Nnet * deriv_nnet_
 
CachingOptimizingCompiler compiler_
 
int32 num_minibatches_processed_
 
unordered_map< std::string, SimpleObjectiveInfo, StringHasher > objf_info_
 
unordered_map< std::string, PerDimObjectiveInfo, StringHasher > accuracy_info_
 

Detailed Description

This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.

Note: because the network ends in a "logsoftmax" component, the objective function is linear in the network output (the dot product of the supervision with the log-probabilities), but the printed messages reflect the fact that it is the cross-entropy objective.

Definition at line 107 of file nnet-diagnostics.h.
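A minimal standalone sketch of what the note means, written in plain C++ rather than with Kaldi types: once the output layer is a log-softmax, the per-frame objective is the dot product of the supervision row with the log-probabilities. That is linear in the network output and, for a one-hot target, equals log p(target), the negated cross-entropy. `LogSoftmax` and `LinearObjective` are hypothetical helpers for illustration, not Kaldi functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: numerically stable log-softmax of one row of logits.
std::vector<double> LogSoftmax(const std::vector<double> &logits) {
  assert(!logits.empty());
  double max_val = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double x : logits) sum += std::exp(x - max_val);
  double log_norm = max_val + std::log(sum);
  std::vector<double> out(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) out[i] = logits[i] - log_norm;
  return out;
}

// Hypothetical "linear" objective: supervision . log_probs.
// With a one-hot supervision row this is log p(target).
double LinearObjective(const std::vector<double> &supervision,
                       const std::vector<double> &log_probs,
                       std::vector<double> *deriv) {
  assert(supervision.size() == log_probs.size());
  double objf = 0.0;
  for (std::size_t i = 0; i < log_probs.size(); ++i)
    objf += supervision[i] * log_probs[i];
  // The objective is linear in log_probs, so its derivative with respect to
  // the output is just the supervision row.
  if (deriv != nullptr) *deriv = supervision;
  return objf;
}
```

Because the derivative with respect to the output is simply the supervision, no softmax Jacobian is needed at the objective; the log-softmax component takes care of that during backpropagation. This is also why PrintTotalStats() labels `kLinear` outputs as "log-likelihood".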

Constructor & Destructor Documentation

◆ NnetComputeProb() [1/2]

NnetComputeProb ( const NnetComputeProbOptions & config,
const Nnet & nnet 
)

Definition at line 26 of file nnet-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, KALDI_ERR, NnetComputeProb::nnet_, kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetNnetAsGradient(), and NnetComputeProbOptions::store_component_stats.

:
    config_(config),
    nnet_(nnet),
    deriv_nnet_owned_(true),
    deriv_nnet_(NULL),
    compiler_(nnet, config_.optimize_config, config_.compiler_config),
    num_minibatches_processed_(0) {
  if (config_.compute_deriv) {
    deriv_nnet_ = new Nnet(nnet_);
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_); // force simple update
  } else if (config_.store_component_stats) {
    KALDI_ERR << "If you set store_component_stats == true and "
              << "compute_deriv == false, use the other constructor.";
  }
}

◆ NnetComputeProb() [2/2]

NnetComputeProb ( const NnetComputeProbOptions & config,
Nnet * nnet 
)

Definition at line 45 of file nnet-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, KALDI_ASSERT, and NnetComputeProbOptions::store_component_stats.

:
    config_(config),
    nnet_(*nnet),
    deriv_nnet_owned_(false),
    deriv_nnet_(nnet),
    compiler_(*nnet, config_.optimize_config, config_.compiler_config),
    num_minibatches_processed_(0) {
  KALDI_ASSERT(config.store_component_stats && !config.compute_deriv);
}

◆ ~NnetComputeProb()

~NnetComputeProb ( )

Definition at line 64 of file nnet-diagnostics.cc.

References NnetComputeProb::deriv_nnet_, and NnetComputeProb::deriv_nnet_owned_.

{
  if (deriv_nnet_owned_)
    delete deriv_nnet_;  // delete does nothing if pointer is NULL.
}
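The two constructors differ only in who owns `deriv_nnet_`, and the destructor consults `deriv_nnet_owned_` so that it never deletes a caller's network. The following is a minimal standalone sketch of that owned-versus-borrowed pattern; `Model` and `Holder` are hypothetical stand-ins for `Nnet` and `NnetComputeProb`, with `Model` counting live instances so the ownership behavior can be observed.

```cpp
#include <cassert>

// Hypothetical stand-in for Nnet that counts live instances.
struct Model {
  static int live_count;
  Model() { ++live_count; }
  Model(const Model &) { ++live_count; }
  ~Model() { --live_count; }
};
int Model::live_count = 0;

// Mirrors the deriv_nnet_ / deriv_nnet_owned_ pair: one constructor makes
// (and owns) a copy, the other borrows a pointer supplied by the caller.
class Holder {
 public:
  explicit Holder(const Model &src) : owned_(true), ptr_(new Model(src)) {}
  explicit Holder(Model *borrowed) : owned_(false), ptr_(borrowed) {
    assert(borrowed != nullptr);
  }
  ~Holder() {
    if (owned_) delete ptr_;  // never delete a borrowed pointer
  }
  Holder(const Holder &) = delete;             // raw owning pointer: no copies
  Holder &operator=(const Holder &) = delete;
  Model *get() const { return ptr_; }

 private:
  bool owned_;
  Model *ptr_;
};
```

In new code, a `std::unique_ptr` for the owned case plus a separate non-owning raw pointer would express the same split without a flag; the flag version shown here simply mirrors the class's existing member layout.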

Member Function Documentation

◆ Compute()

void Compute ( const NnetExample & eg)

Definition at line 79 of file nnet-diagnostics.cc.

References NnetComputer::AcceptInputs(), CachingOptimizingCompiler::Compile(), NnetComputeProb::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, kaldi::nnet3::GetComputationRequest(), NnetExample::io, NnetComputeProb::nnet_, NnetComputeProb::ProcessOutputs(), NnetComputer::Run(), and NnetComputeProbOptions::store_component_stats.

Referenced by kaldi::nnet3::ComputeObjf(), main(), and kaldi::nnet3::RecomputeStats().

{
  bool need_model_derivative = config_.compute_deriv,
      store_component_stats = config_.store_component_stats;
  ComputationRequest request;
  GetComputationRequest(nnet_, eg, need_model_derivative,
                        store_component_stats,
                        &request);
  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
  NnetComputer computer(config_.compute_config, *computation,
                        nnet_, deriv_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(nnet_, eg.io);
  computer.Run();  // forward pass
  this->ProcessOutputs(eg, &computer);  // also supplies output derivatives if compute_deriv
  if (config_.compute_deriv)
    computer.Run();  // backward pass
}

◆ GetDeriv()

const Nnet & GetDeriv ( ) const

Definition at line 58 of file nnet-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, and KALDI_ERR.

Referenced by main().

{
  if (!config_.compute_deriv)
    KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
  return *deriv_nnet_;
}

◆ GetObjective()

const SimpleObjectiveInfo * GetObjective ( const std::string &  output_name) const

Definition at line 299 of file nnet-diagnostics.cc.

References NnetComputeProb::objf_info_.

Referenced by main().

{
  unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
      iter = objf_info_.find(output_name);
  if (iter != objf_info_.end())
    return &(iter->second);
  else
    return NULL;
}

◆ GetTotalObjective()

double GetTotalObjective ( double *  tot_weight) const

Definition at line 309 of file nnet-diagnostics.cc.

References NnetComputeProb::objf_info_.

Referenced by kaldi::nnet3::ComputeObjf().

{
  // In the .cc definition the parameter is named total_weight.
  double tot_objectives = 0.0;
  double tot_weight = 0.0;
  unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(), end = objf_info_.end();
  for (; iter != end; ++iter) {
    tot_objectives += iter->second.tot_objective;
    tot_weight += iter->second.tot_weight;
  }

  if (total_weight) *total_weight = tot_weight;
  return tot_objectives;
}
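GetTotalObjective() sums the accumulated objective and weight over every output in `objf_info_` (accuracy totals are not included), leaving the per-frame average to the caller. A minimal standalone sketch of that aggregation, assuming a plain `std::unordered_map` keyed by output name; `ObjectiveTotals`, `TotalObjective`, and `PerFrame` are hypothetical names, not Kaldi API.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for SimpleObjectiveInfo.
struct ObjectiveTotals {
  double tot_weight = 0.0;
  double tot_objective = 0.0;
};

// Sums objective and weight over all outputs, as GetTotalObjective() does.
double TotalObjective(
    const std::unordered_map<std::string, ObjectiveTotals> &info,
    double *total_weight) {
  double tot_objective = 0.0, tot_weight = 0.0;
  for (const auto &kv : info) {
    tot_objective += kv.second.tot_objective;
    tot_weight += kv.second.tot_weight;
  }
  if (total_weight != nullptr) *total_weight = tot_weight;
  return tot_objective;
}

// Hypothetical helper: per-frame objective; guards against zero weight,
// which would otherwise produce NaN.
double PerFrame(double tot_objective, double tot_weight) {
  assert(tot_weight > 0.0);
  return tot_objective / tot_weight;
}
```

Summing across outputs assumes their objectives are on a comparable scale; when that is not the case, GetObjective() gives the per-output totals instead.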

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 149 of file nnet-diagnostics.cc.

References NnetComputeProb::accuracy_info_, Nnet::GetNode(), Nnet::GetNodeIndex(), KALDI_ASSERT, KALDI_LOG, kaldi::nnet3::kLinear, NnetComputeProb::nnet_, NetworkNode::objective_type, NnetComputeProb::objf_info_, SimpleObjectiveInfo::tot_objective, PerDimObjectiveInfo::tot_objective_vec, SimpleObjectiveInfo::tot_weight, PerDimObjectiveInfo::tot_weight_vec, and NetworkNode::u.

Referenced by main(), and kaldi::nnet3::RecomputeStats().

{
  bool ans = false;
  {  // First print regular objectives
    unordered_map<std::string, SimpleObjectiveInfo,
                  StringHasher>::const_iterator iter, end;
    iter = objf_info_.begin();
    end = objf_info_.end();
    for (; iter != end; ++iter) {
      const std::string &name = iter->first;
      int32 node_index = nnet_.GetNodeIndex(name);
      KALDI_ASSERT(node_index >= 0);
      ObjectiveType obj_type = nnet_.GetNode(node_index).u.objective_type;
      const SimpleObjectiveInfo &info = iter->second;
      KALDI_LOG << "Overall "
                << (obj_type == kLinear ? "log-likelihood" : "objective")
                << " for '" << name << "' is "
                << (info.tot_objective / info.tot_weight) << " per frame"
                << ", over " << info.tot_weight << " frames.";
      if (info.tot_weight > 0)
        ans = true;
    }
  }
  {
    unordered_map<std::string, PerDimObjectiveInfo,
                  StringHasher>::const_iterator iter, end;
    // now print accuracies.
    iter = accuracy_info_.begin();
    end = accuracy_info_.end();
    for (; iter != end; ++iter) {
      const std::string &name = iter->first;
      const PerDimObjectiveInfo &info = iter->second;
      KALDI_LOG << "Overall accuracy for '" << name << "' is "
                << (info.tot_objective / info.tot_weight) << " per frame"
                << ", over " << info.tot_weight << " frames.";

      if (info.tot_weight_vec.Dim() > 0) {
        Vector<BaseFloat> accuracy_vec(info.tot_weight_vec.Dim());
        for (size_t j = 0; j < info.tot_weight_vec.Dim(); j++) {
          if (info.tot_weight_vec(j) != 0) {
            accuracy_vec(j) = info.tot_objective_vec(j)
                              / info.tot_weight_vec(j);
          } else {
            accuracy_vec(j) = -1.0;
          }
        }

        KALDI_LOG << "Overall per-dim accuracy vector for '" << name
                  << "' is " << accuracy_vec << " per frame"
                  << ", over " << info.tot_weight << " frames.";
      }
      // don't bother changing ans; the loop over the regular objective should
      // already have set it to true if we got any data.
    }
  }
  return ans;
}
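The per-dimension accuracy vector printed by PrintTotalStats() is an element-wise ratio of the accumulated correct weight to the accumulated total weight, with -1.0 marking classes that never occurred as a target. A minimal standalone sketch of that loop using `std::vector<double>` in place of Kaldi's `Vector<BaseFloat>`; `PerDimAccuracy` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise accuracy per output dimension (class). Dimensions with zero
// accumulated weight get -1.0, matching the convention in PrintTotalStats().
std::vector<double> PerDimAccuracy(const std::vector<double> &tot_objective_vec,
                                   const std::vector<double> &tot_weight_vec) {
  assert(tot_objective_vec.size() == tot_weight_vec.size());
  std::vector<double> accuracy(tot_weight_vec.size());
  for (std::size_t j = 0; j < tot_weight_vec.size(); ++j)
    accuracy[j] = tot_weight_vec[j] != 0.0
                      ? tot_objective_vec[j] / tot_weight_vec[j]
                      : -1.0;
  return accuracy;
}
```

Using -1.0 as a sentinel keeps the vector the same length as the output dimension, so entries still line up with class indices in the log.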

◆ ProcessOutputs()

void ProcessOutputs ( const NnetExample & eg,
NnetComputer * computer 
)
private

Definition at line 97 of file nnet-diagnostics.cc.

References NnetComputeProb::accuracy_info_, NnetComputeProbOptions::compute_accuracy, NnetComputeProbOptions::compute_deriv, NnetComputeProbOptions::compute_per_dim_accuracy, kaldi::nnet3::ComputeAccuracy(), kaldi::nnet3::ComputeObjectiveFunction(), NnetComputeProb::config_, NnetIo::features, Nnet::GetNode(), Nnet::GetNodeIndex(), NnetComputer::GetOutput(), NnetExample::io, Nnet::IsOutputNode(), KALDI_ERR, NnetIo::name, NnetComputeProb::nnet_, NnetComputeProb::num_minibatches_processed_, CuMatrixBase< Real >::NumCols(), GeneralMatrix::NumCols(), NetworkNode::objective_type, NnetComputeProb::objf_info_, SimpleObjectiveInfo::tot_objective, PerDimObjectiveInfo::tot_objective_vec, SimpleObjectiveInfo::tot_weight, PerDimObjectiveInfo::tot_weight_vec, and NetworkNode::u.

Referenced by NnetComputeProb::Compute().

{
  std::vector<NnetIo>::const_iterator iter = eg.io.begin(),
      end = eg.io.end();
  for (; iter != end; ++iter) {
    const NnetIo &io = *iter;
    int32 node_index = nnet_.GetNodeIndex(io.name);
    if (node_index < 0)
      KALDI_ERR << "Network has no output named " << io.name;
    ObjectiveType obj_type = nnet_.GetNode(node_index).u.objective_type;
    if (nnet_.IsOutputNode(node_index)) {
      const CuMatrixBase<BaseFloat> &output = computer->GetOutput(io.name);
      if (output.NumCols() != io.features.NumCols()) {
        KALDI_ERR << "Nnet versus example output dimension (num-classes) "
                  << "mismatch for '" << io.name << "': " << output.NumCols()
                  << " (nnet) vs. " << io.features.NumCols() << " (egs)\n";
      }
      {
        BaseFloat tot_weight, tot_objf;
        bool supply_deriv = config_.compute_deriv;
        ComputeObjectiveFunction(io.features, obj_type, io.name,
                                 supply_deriv, computer,
                                 &tot_weight, &tot_objf);
        SimpleObjectiveInfo &totals = objf_info_[io.name];
        totals.tot_weight += tot_weight;
        totals.tot_objective += tot_objf;
      }
      // May not be meaningful in non-classification tasks
      if (config_.compute_accuracy) {
        BaseFloat tot_weight, tot_accuracy;
        PerDimObjectiveInfo &acc_totals = accuracy_info_[io.name];

        if (config_.compute_per_dim_accuracy &&
            acc_totals.tot_objective_vec.Dim() == 0) {
          acc_totals.tot_objective_vec.Resize(output.NumCols());
          acc_totals.tot_weight_vec.Resize(output.NumCols());
        }

        ComputeAccuracy(io.features, output,
                        &tot_weight, &tot_accuracy,
                        config_.compute_per_dim_accuracy ?
                            &acc_totals.tot_weight_vec : NULL,
                        config_.compute_per_dim_accuracy ?
                            &acc_totals.tot_objective_vec : NULL);
        acc_totals.tot_weight += tot_weight;
        acc_totals.tot_objective += tot_accuracy;
      }
    }
  }
  num_minibatches_processed_++;
}
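ProcessOutputs() delegates the accuracy computation to ComputeAccuracy(), whose brief description here is only "computes the frame accuracy for this minibatch". The following is a simplified standalone sketch of that idea, not Kaldi's implementation, under these assumptions: supervision and output are dense row-per-frame matrices, a frame's weight is the sum of its supervision row, and a frame counts as correct when the argmax of the output equals the argmax of the supervision. As ProcessOutputs appears to rely on, the scalar totals are overwritten while the optional per-dimension vectors are accumulated into. `AccumulateAccuracy`, `ArgMax`, and `Matrix` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // one row per frame

// Index of the largest element in a non-empty row (first one on ties).
static std::size_t ArgMax(const std::vector<double> &row) {
  return static_cast<std::size_t>(
      std::distance(row.begin(), std::max_element(row.begin(), row.end())));
}

// Assumed semantics (see lead-in): frame weight = sum of supervision row;
// correct when argmax(output) == argmax(supervision).
void AccumulateAccuracy(const Matrix &supervision, const Matrix &output,
                        double *tot_weight, double *tot_accuracy,
                        std::vector<double> *weight_vec,     // may be null
                        std::vector<double> *accuracy_vec) {  // may be null
  assert(supervision.size() == output.size());
  *tot_weight = 0.0;  // scalars are outputs for this minibatch only
  *tot_accuracy = 0.0;
  for (std::size_t r = 0; r < supervision.size(); ++r) {
    double row_weight = 0.0;
    for (double v : supervision[r]) row_weight += v;
    std::size_t ref = ArgMax(supervision[r]);
    std::size_t hyp = ArgMax(output[r]);
    *tot_weight += row_weight;
    if (weight_vec != nullptr) (*weight_vec)[ref] += row_weight;  // accumulate
    if (ref == hyp) {
      *tot_accuracy += row_weight;
      if (accuracy_vec != nullptr) (*accuracy_vec)[ref] += row_weight;
    }
  }
}
```

As in ProcessOutputs, the per-dimension vectors must be sized to the output dimension before the first call; here that is the caller's responsibility.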

◆ Reset()

void Reset ( )

Definition at line 69 of file nnet-diagnostics.cc.

References NnetComputeProb::accuracy_info_, NnetComputeProb::deriv_nnet_, NnetComputeProb::num_minibatches_processed_, NnetComputeProb::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().

Referenced by kaldi::nnet3::ComputeObjf().

{
  num_minibatches_processed_ = 0;
  objf_info_.clear();
  accuracy_info_.clear();
  if (deriv_nnet_) {
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_);
  }
}

Member Data Documentation

◆ accuracy_info_

unordered_map<std::string, PerDimObjectiveInfo, StringHasher> accuracy_info_
private

◆ compiler_

CachingOptimizingCompiler compiler_
private

Definition at line 154 of file nnet-diagnostics.h.

Referenced by NnetComputeProb::Compute().

◆ config_

NnetComputeProbOptions config_
private

◆ deriv_nnet_

Nnet* deriv_nnet_
private

◆ deriv_nnet_owned_

bool deriv_nnet_owned_
private

Definition at line 152 of file nnet-diagnostics.h.

Referenced by NnetComputeProb::~NnetComputeProb().

◆ nnet_

const Nnet& nnet_
private

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

Definition at line 157 of file nnet-diagnostics.h.

Referenced by NnetComputeProb::ProcessOutputs(), and NnetComputeProb::Reset().

◆ objf_info_

unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_
private


The documentation for this class was generated from the following files:

nnet-diagnostics.h
nnet-diagnostics.cc