kgrams  0.1.0
Sampler.h
1 
5 #ifndef SAMPLER_H
6 #define SAMPLER_H
7 
8 #include "special_tokens.h"
9 
13 
14 template<class Smoother>
15 class Sampler {
16  //--------Private variables--------//
17 
19  Smoother prob_;
20 
21  //--------Private methods--------//
22 
25  std::string sample_word(std::string context, double T = 1.0) {
26  std::string res;
27  double best = 0, tmp;
28  std::string word;
29  // Sample word from P(word|context) using Gumbel-Max trick
30  for (size_t i = 1; i <= prob_.V_; ++i) {
31  word = prob_.f_.dictionary()->word(std::to_string(i));
32  tmp = std::pow(prob_(word, context), 1 / T);
33  tmp /= R::rexp(1.);
34  if (tmp > best) {
35  best = tmp;
36  res = word;
37  }
38  }
39  // Separate iteration for EOS token
40  tmp = std::pow(prob_(EOS_TOK, context), 1 / T) / R::rexp(1.);
41  if (tmp > best)
42  res = EOS_TOK;
43  // N.B.: we forbid sampling the UNK token
44  return res;
45  }
46 
47  // Not yet implemented. Sampling using simple rejection method.
48  // N.B.: requires normalized probabilities.
49  //
50  // std::string sample_word_rej(std::string context) {
51  // std::string res;
52  // double best = 0, tmp;
53  // std::string word;
54  // while (true) {
55  // size_t n = R::runif(0, prob_.V_);
56  // word = prob_.f_.dictionary()->word(std::to_string(n));
57  // if (prob_(word, context) / R::runif(0, 1) > 1)
58  // return word;
59  // }
60  // }
61 public:
62  //--------Constructor--------//
63 
67  Sampler (Smoother prob) : prob_(prob) {}
68 
75  std::string sample_sentence(size_t max_length, double T = 1.0) {
76  std::string res = "", context = "";
77  for (size_t i = 1; i < prob_.f_.N(); ++i) {
78  context += BOS_TOK + " ";
79  }
80  size_t n_words = 0;
81  std::string new_word; size_t start = 0;
82  while (n_words < max_length) {
83  n_words++;
84  new_word = sample_word(context, T);
85  if (new_word == EOS_TOK)
86  return res + "<eos>";
87  res += new_word + " ";
88  context += " " + new_word;
89  start = context.find_first_not_of(" ");
90  start = context.find_first_of(" ", start);
91  context = context.substr(start + 1);
92  }
93  return res + "[...] (truncated output)";
94  }
95 }; // class Sampler<Smoother>
96 
97 #endif // SAMPLER_H
Sampler::Sampler
Sampler(Smoother prob)
Initialize a Sampler from a given smoother object.
Definition: Sampler.h:67
Sampler::sample_sentence
std::string sample_sentence(size_t max_length, double T=1.0)
Sample a sentence from the probability distribution specified by the smoother.
Definition: Sampler.h:75
Sampler
Sample sequences from a k-gram language model.
Definition: Sampler.h:15