Kernel Quantum Probability Library
The KQP library aims at providing tools for working with quantum probabilities
divide_and_conquer.hpp
1 /*
2  This file is part of the Kernel Quantum Probability library (KQP).
3 
4  KQP is free software: you can redistribute it and/or modify
5  it under the terms of the GNU General Public License as published by
6  the Free Software Foundation, either version 3 of the License, or
7  (at your option) any later version.
8 
9  KQP is distributed in the hope that it will be useful,
10  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  GNU General Public License for more details.
13 
14  You should have received a copy of the GNU General Public License
15  along with KQP. If not, see <http://www.gnu.org/licenses/>.
16  */
17 
18 #ifndef __KQP_DIVIDE_AND_CONQUER_BUILDER_H__
19 #define __KQP_DIVIDE_AND_CONQUER_BUILDER_H__
20 
21 #include <boost/type_traits/is_complex.hpp>
22 #include <kqp/cleanup.hpp>
23 #include <kqp/kernel_evd.hpp>
24 
25 namespace kqp {
26 
27 # include <kqp/define_header_logger.hpp>
28  DEFINE_KQP_HLOGGER("kqp.kevd.dc");
29 
34  template <typename Scalar> class DivideAndConquerBuilder : public KernelEVD<Scalar> {
35  public:
36  KQP_SCALAR_TYPEDEFS(Scalar);
37 
38  DivideAndConquerBuilder(const FSpaceCPtr &fs) : KernelEVD<Scalar>(fs), batchSize(100) {}
39 
40  virtual ~DivideAndConquerBuilder() {}
41 
43  void setBatchSize(Index batchSize) {
44  this->batchSize = batchSize;
45  }
46 
48  void setBuilder(const boost::shared_ptr< KernelEVD<Scalar> > &builder) {
49  this->builder = builder;
50  }
51 
53  void setBuilderCleaner(const boost::shared_ptr< Cleaner<Scalar> > &builderCleaner) {
54  this->builderCleaner = builderCleaner;
55  }
56 
58  void setMerger(const boost::shared_ptr<KernelEVD<Scalar> > &merger) {
59  this->merger = merger;
60  }
61 
63  void setMergerCleaner(const boost::shared_ptr<Cleaner<Scalar> > &mergerCleaner) {
64  this->mergerCleaner = mergerCleaner;
65  }
66 
67 
68  protected:
69  void reset() {
71 
72  decompositions.clear();
73  builder->reset();
74  }
75 
76  virtual Decomposition<Scalar> _getDecomposition() const override {
77  // Flush last merger
78  const_cast<DivideAndConquerBuilder&>(*this).flushBuilder();
79 
80  // Empty decomposition if we had no data
81  if (decompositions.size() == 0)
82  return Decomposition<Scalar>();
83 
84 
85  // Merge everything
86  const_cast<DivideAndConquerBuilder&>(*this).merge(true);
87  return decompositions[0];
88  }
89 
90  // Rank update
91  virtual void _add(Real alpha, const FMatrixCPtr &mU, const ScalarAltMatrix &mA) override {
92  // Prepare
93  if (builder->getUpdateCount() > batchSize) {
94  flushBuilder();
95  merge(false);
96  }
97  // Update the decomposition
98  builder->add(alpha, mU, mA);
99 
100  }
101 
102  private:
103  // Add the current decomposition to the stack
104  void flushBuilder() {
105  if (builder->getUpdateCount() == 0) return;
106 
107  // Get & clean
108  decompositions.push_back(builder->getDecomposition());
109  Decomposition<Scalar> &d = decompositions.back();
110  if (builderCleaner.get())
111  builderCleaner->cleanup(d);
112  assert(!kqp::isNaN(d.fs->k(d.mX, d.mY, d.mD).squaredNorm()));
113 
114  // Resets the builder
115  builder->reset();
116  }
117 
119  static void merge(KernelEVD<Scalar> &merger, const Decomposition<Scalar> &d) {
120  Index posCount = 0;
121 
122  for(Index j = 0; j < d.mD.size(); j++)
123  if (d.mD(j,0) >= 0) posCount++;
124  Index negCount = d.mD.size() - posCount;
125 
126  // FIXME: block expression for Alt expression
127  ScalarMatrix mY = d.mY * d.mD.cwiseAbs().cwiseSqrt().asDiagonal();
128 
129  Index jPos = 0;
130  Index jNeg = 0;
131 
132  ScalarMatrix mYPos(mY.rows(), posCount);
133  ScalarMatrix mYNeg(mY.rows(), negCount);
134 
135  for(Index j = 0; j < d.mD.size(); j++)
136  if (d.mD(j,0) >= 0)
137  mYPos.col(jPos++) = mY.col(j);
138  else
139  mYNeg.col(jNeg++) = mY.col(j);
140 
141  assert(jPos == posCount);
142  assert(jNeg == negCount);
143 
144  Real posNorm = d.fs->k(d.mX, mYPos).norm();
145  Real negNorm = d.fs->k(d.mX, mYNeg).norm();
146  KQP_HLOG_DEBUG_F("Adding a decomposition to a merger (pos=%d / %g, neg=%d / %g)",
147  %posCount %posNorm %negCount %negNorm);
148  if (posCount > 0 && posNorm / negNorm > Eigen::NumTraits<Scalar>::epsilon())
149  merger.add(1, d.mX, mYPos);
150  if (negCount > 0 && negNorm / posNorm > Eigen::NumTraits<Scalar>::epsilon())
151  merger.add(-1, d.mX, mYNeg);
152  }
153 
154 
159  void merge(bool force) {
160 
161  // Merge while the number of merged decompositions is the same for the two last decompositions
162  // (or less, to handle the case of previous unbalanced merges)
163  while (decompositions.size() >= 2 && (force || (decompositions.back().updateCount >= (decompositions.end()-2)->updateCount))) {
164 
165  const Decomposition<Scalar> &d1 = decompositions[decompositions.size() - 2];
166  const Decomposition<Scalar> &d2 = decompositions[decompositions.size() - 1];
167 
168  KQP_HLOG_DEBUG_F("Starting the merge [force=%d, level=%d] of two decompositions [%d/%d;%d] and [%d/%d;%d]",
169  %force %decompositions.size()
170  %d1.mD.rows() %d1.mX->size() %d1.updateCount
171  %d2.mD.rows() %d2.mX->size() %d2.updateCount);
172 
173  merger->reset();
174  merge(*merger, d1);
175  merge(*merger, d2);
176 
177  decompositions.pop_back();
178  decompositions.pop_back();
179 
180 
181  // Push back new decomposition
182  decompositions.push_back(merger->getDecomposition());
183  auto &d = decompositions.back();
184  assert(!kqp::isNaN(d.fs->k(d.mX, d.mY, d.mD).squaredNorm()));
185  if (mergerCleaner.get())
186  mergerCleaner->cleanup(d);
187  d.updateCount = d1.updateCount + d2.updateCount;
188  assert(!kqp::isNaN(d.fs->k(d.mX, d.mY, d.mD).squaredNorm()));
189 
190  KQP_HLOG_DEBUG_F("Merged [force=%d, level=%d] into [rank= %d, pre-images=%d; updates=%d]",
191  %force %decompositions.size()
192  %d.mD.rows() %d.mX->size() %d.updateCount);
193 
194 
195  }
196  }
197 
198 
199  private:
200 
202  std::vector<Decomposition<Scalar> > decompositions;
203 
205  Index batchSize;
206 
207  boost::shared_ptr<KernelEVD<Scalar> > builder;
208  boost::shared_ptr<Cleaner<Scalar> > builderCleaner;
209 
210  boost::shared_ptr<KernelEVD<Scalar> > merger;
211  boost::shared_ptr<Cleaner<Scalar> > mergerCleaner;
212  };
213 }
214 
215 #endif
216