00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #ifndef _KDTREE_H
00016 #define _KDTREE_H
00017
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <list>
#include <vector>
00022
/**
 * @brief Bucketed kd-tree over float vectors with an optional user value per vector.
 *
 * The tree stores the caller's `float*` pointers (it never copies or frees the
 * vectors they point to) and, optionally, one VALUE per vector. Interior nodes
 * hold a split value for dimension `level % kdim`; leaves hold a bucket of
 * pointers (and the matching user values).
 *
 * All query functions return a freshly allocated list that the caller must
 * `delete`. Queries return whole leaf buckets: elements of a bucket are not
 * filtered individually against the query.
 *
 * Errors are reported by throwing a `const char*` message.
 *
 * @note The class owns `root` through a raw pointer and defines no copy
 *       operations; copying a kdtree (or a KDNode) would lead to a double
 *       delete. Treat both as non-copyable.
 *
 * @tparam VALUE type of the user value associated with every data vector.
 */
template <class VALUE>
class kdtree {
public:
    struct KDNode;

    /** Number of components of every data vector. */
    int kdim;

    /** Root of the tree; owned by this object. */
    KDNode* root;

public:

    /**
     * @brief Builds a tree over @p dataset.
     *
     * The lists passed in stay owned by the caller and are not modified; the
     * tree keeps its own lists of the contained pointers / values.
     *
     * @param kdim    number of components per vector, at least 1.
     * @param dataset non-NULL, non-empty list of vectors with @p kdim floats each.
     * @param userset optional list of user values, parallel to @p dataset
     *                (same length, same order); NULL if not needed.
     * @throws const char* on invalid arguments or when a NaN component is
     *         encountered while building.
     */
    kdtree(unsigned int kdim, std::list<float*>* dataset, std::list<VALUE>* userset = NULL) {
        if (kdim < 1) throw "kdim has to be 1 at least.";
        if (dataset == NULL) throw "dataset may not be NULL.";
        // An empty dataset would reach log(0) and front() on an empty list in build().
        if (dataset->empty()) throw "dataset may not be empty.";
        // build() advances both lists in lock-step, so they must be equally long.
        if (userset != NULL && userset->size() != dataset->size()) {
            throw "userset and dataset have to be of equal size.";
        }
        this->kdim = kdim;
        try {
            root = build(kdim, 0, 0, dataset, userset);
        } catch (KDNode* unsplittable) {
            // build() signals "no dimension could split these points" by throwing
            // the finished leaf. At the top level that leaf simply is the tree.
            root = unsplittable;
        }
        if (root == NULL) throw "kd-build returned NULL.";
    }

    /** @brief Frees all nodes and the tree's own bucket lists (not the vectors). */
    ~kdtree() {
        delete root;
    }

    /** @return number of data vectors stored in the leaves. */
    int size() const {
        return root->size();
    }

    /** @return number of components per data vector. */
    int dim() const {
        return kdim;
    }

    /** @return always false; a tree cannot be constructed without data. */
    bool empty() const {
        return false;
    }

    /**
     * @brief Returns the bucket of the leaf that @p datavec descends to.
     * @param datavec query vector with dim() components.
     * @return new list (caller deletes) with the pointers stored in that leaf.
     */
    std::list<float*>* findDataset(float* datavec) {
        const std::list<float*>* bucket = root->findDataset(kdim, 0, datavec);
        if (bucket == NULL) return new std::list<float*>;
        return new std::list<float*>(bucket->begin(), bucket->end());
    }

    /**
     * @brief Collects the buckets of all leaves reached by the box query.
     * @param min_datavec lower corner of the box, dim() components.
     * @param max_datavec upper corner of the box, dim() components.
     * @return new list (caller deletes); contains complete leaf buckets.
     */
    std::list<float*>* findDatasetInterval(float* min_datavec, float* max_datavec) {
        return root->findDatasetInterval(kdim, 0, min_datavec, max_datavec);
    }

    /**
     * @brief Returns the user values of the leaf that @p datavec descends to.
     * @param datavec query vector with dim() components.
     * @return new list (caller deletes); empty if the tree was built without
     *         a userset.
     */
    std::list<VALUE>* findUserset(float* datavec) {
        const std::list<VALUE>* bucket = root->findUserset(kdim, 0, datavec);
        if (bucket == NULL) return new std::list<VALUE>;
        return new std::list<VALUE>(bucket->begin(), bucket->end());
    }

    /**
     * @brief Collects the user values of all leaves reached by the box query.
     * @param min_datavec lower corner of the box, dim() components.
     * @param max_datavec upper corner of the box, dim() components.
     * @return new list (caller deletes); empty if the tree was built without
     *         a userset.
     */
    std::list<VALUE>* findUsersetInterval(float* min_datavec, float* max_datavec) {
        return root->findUsersetInterval(kdim, 0, min_datavec, max_datavec);
    }


public:

    /**
     * @brief Recursively builds the (sub)tree for @p dataset.
     *
     * The split value for the current dimension is estimated from a small
     * histogram of the component values. Points below the split value go to
     * the left child, all others to the right child.
     *
     * Ownership: for @p curlvl == 0 the lists belong to the caller and are
     * only read. For @p curlvl > 0 the lists are owned by this call; they are
     * either stored in a leaf or deleted.
     *
     * @param kdim      number of components per vector.
     * @param curlvl    depth of the node to build; the split dimension is
     *                  `curlvl % kdim`.
     * @param equaldims number of consecutive levels above this one at which
     *                  the split left one side empty.
     * @param dataset   non-empty list of vectors to distribute.
     * @param userset   optional values parallel to @p dataset, or NULL.
     * @return the node built for this level.
     * @throws KDNode*     a finished leaf, when @p equaldims reached @p kdim
     *                     (no dimension separates the points). The nearest
     *                     ancestor whose split succeeded adopts it as a child;
     *                     if there is none, it leaves the outermost call.
     * @throws const char* if a NaN component is encountered.
     */
    KDNode* build(unsigned int kdim, unsigned int curlvl, unsigned int equaldims, std::list<float*>* dataset, std::list<VALUE>* userset = NULL) {

        const unsigned int index = curlvl % kdim;

        // Leaf: a single point, or kdim levels in a row failed to separate anything.
        if (dataset->size() == 1 || equaldims == kdim) {
            KDNode* leaf = new KDNode();
            leaf->comparevalue = dataset->front()[index];
            if (curlvl == 0) {
                // Top-level lists belong to the caller; the node needs its own.
                leaf->dataset = new std::list<float*>(dataset->begin(), dataset->end());
                if (userset != NULL) {
                    leaf->userset = new std::list<VALUE>(userset->begin(), userset->end());
                }
            } else {
                leaf->dataset = dataset;
                leaf->userset = userset;
            }
            if (equaldims == kdim) {
                // Tell the ancestors that their unsuccessful splits are pointless.
                throw leaf;
            }
            return leaf;
        }

        // Extent of the data along the current dimension.
        float min = std::numeric_limits<float>::max();
        float max = -std::numeric_limits<float>::max();
        for (std::list<float*>::const_iterator i = dataset->begin(); i != dataset->end(); ++i) {
            const float value = (*i)[index];
            if (value < min) min = value;
            if (value > max) max = value;
        }
        const float extends_inverse = (min < max) ? (1.0f / (max - min)) : (1.0f);

        // Histogram with roughly 2 + log2(n) bins (1.443 ~ 1 / ln 2).
        const unsigned int nbins = 2 + (int) (std::log(static_cast<double>(dataset->size())) * 1.443);
        std::vector<unsigned int> bins(nbins, 0u);

        const float factor = (nbins - 1) * extends_inverse;
        const bool randsamp = dataset->size() > 16;
        int cnt = 0;
        for (std::list<float*>::const_iterator i = dataset->begin(); i != dataset->end(); ++i) {
            cnt++;
            // For larger sets every 8th point is left out of the histogram.
            if (randsamp && ((cnt & 7) == 7)) continue;
            const float value = (*i)[index];
            if (value != value) {
                // Below the top level the input lists are ours; release them before leaving.
                if (curlvl > 0) {
                    delete dataset;
                    delete userset;
                }
                throw "Not-A-Number discovered while building kd-tree!";
            }
            // Clamp so that rounding or non-finite input can never index outside the histogram.
            const float position = (value - min) * factor;
            unsigned int bin = 0;
            if (position >= static_cast<float>(nbins - 1)) bin = nbins - 1;
            else if (position > 0.0f) bin = static_cast<unsigned int>(position);
            bins[bin]++;
        }

        // First bin boundary at which more than half of the points are below it,
        // provided at least one earlier bin boundary already had points below.
        int pre = 0;
        int count = 0;
        const int half = dataset->size() / 2;
        unsigned int split = 0;
        for (split = 0; split < nbins; split++) {
            count += bins[split];
            if (count > half && pre > 0) {
                break;
            }
            pre = count;
        }
        const float splitvalue = split / factor + min;

        // Distribute the points (and their user values) to both sides.
        typename std::list<VALUE>::iterator j;
        std::list<VALUE>* left_userset = NULL;
        std::list<VALUE>* right_userset = NULL;
        if (userset != NULL) {
            left_userset = new std::list<VALUE>;
            right_userset = new std::list<VALUE>;
            j = userset->begin();
        }

        std::list<float*>* left_list = new std::list<float*>;
        std::list<float*>* right_list = new std::list<float*>;
        for (std::list<float*>::iterator i = dataset->begin(); i != dataset->end(); ++i) {
            const float value = (*i)[index];
            if (value < splitvalue) {
                left_list->push_back(*i);
                if (left_userset != NULL) left_userset->push_back(*j);
            } else {
                right_list->push_back(*i);
                if (right_userset != NULL) right_userset->push_back(*j);
            }
            if (userset != NULL) ++j;
        }

        // The input lists of inner levels are no longer needed.
        if (curlvl > 0) {
            delete dataset;
            delete userset;
        }

        KDNode* node = new KDNode();
        node->comparevalue = splitvalue;

        // Count consecutive levels whose split did not separate anything.
        if (left_list->empty() || right_list->empty()) equaldims++;
        else equaldims = 0;

        // Left side. A non-empty partition is handed over to the recursive call.
        try {
            if (left_list->empty()) {
                delete left_list;
                delete left_userset;
            } else {
                try {
                    node->left = build(kdim, curlvl + 1, equaldims, left_list, left_userset);
                } catch (KDNode* nd) {
                    // Our own split failed as well: pass the leaf further up.
                    if (equaldims != 0) throw;
                    node->left = nd;
                }
            }
        } catch (...) {
            // The right partition has not been handed over yet.
            delete right_list;
            delete right_userset;
            delete node;
            throw;
        }

        // Right side.
        try {
            if (right_list->empty()) {
                delete right_list;
                delete right_userset;
            } else {
                try {
                    node->right = build(kdim, curlvl + 1, equaldims, right_list, right_userset);
                } catch (KDNode* nd) {
                    if (equaldims != 0) throw;
                    node->right = nd;
                }
            }
        } catch (...) {
            delete node; // also releases an already built left subtree
            throw;
        }

        return node;
    }

    /**
     * @brief Node of the kd-tree.
     *
     * A node without children is a leaf and owns a bucket (`dataset`, and
     * `userset` if user values were supplied). Interior nodes have at least
     * one child and no bucket. A node owns its children and its lists.
     */
    struct KDNode {
        /** Split value (interior node) or a component of the first bucket entry (leaf). */
        float comparevalue;
        /** Subtree with values below comparevalue, or NULL. */
        KDNode *left;
        /** Subtree with values at or above comparevalue, or NULL. */
        KDNode *right;
        /** Bucket of data vector pointers (leaves only). */
        std::list<float*>* dataset;
        /** User values parallel to dataset, or NULL. */
        std::list<VALUE>* userset;

        /**
         * @brief Binary predicate that always answers true.
         *
         * Kept for source compatibility only. It is not a strict weak
         * ordering and therefore must not be handed to ordering algorithms
         * such as std::list::merge.
         */
        template <typename T>
        struct no_specific_order {
            typedef T first_argument_type;
            typedef T second_argument_type;
            typedef bool result_type;

            bool operator()(const T&, const T&) const {
                return true;
            }
        };

        /** @brief Creates a node without children and without bucket. */
        KDNode() {
            comparevalue = 0;
            left = NULL;
            right = NULL;
            dataset = NULL;
            userset = NULL;
        }

        /** @brief Deletes the bucket lists and both subtrees. */
        ~KDNode() {
            delete dataset;
            delete userset;
            delete left;
            delete right;
        }

        /** @brief Prints the bucket size of every leaf below this node, one per line. */
        void printLeavesize() {
            if (this->left == NULL && this->right == NULL) {
                // size() is a size_t; convert explicitly to match the %i conversion.
                std::printf("%i\n", static_cast<int>(this->dataset->size()));
            }
            if (this->left != NULL) this->left->printLeavesize();
            if (this->right != NULL) this->right->printLeavesize();
        }

        /** @return number of data vectors stored in the leaves below this node. */
        int size() const {
            if (this->left == NULL && this->right == NULL) return this->dataset->size();
            int sum = 0;
            if (this->left != NULL) sum += this->left->size();
            if (this->right != NULL) sum += this->right->size();
            return sum;
        }

        /**
         * @brief Descends from this node to the leaf selected by @p datavec.
         *
         * At every interior node the component `level % kdim` is compared
         * with the node's split value; if the preferred child is missing the
         * other one is taken.
         *
         * @param kdim    number of components per vector.
         * @param curlvl  level of this node.
         * @param datavec query vector.
         * @return the leaf reached (never NULL).
         */
        KDNode* findLeaf(unsigned int kdim, unsigned int curlvl, const float* datavec) {
            KDNode* node = this;
            while (node->left != NULL || node->right != NULL) {
                const unsigned int index = curlvl % kdim;
                if (datavec[index] < node->comparevalue) {
                    node = (node->left != NULL) ? node->left : node->right;
                } else {
                    node = (node->right != NULL) ? node->right : node->left;
                }
                ++curlvl;
            }
            return node;
        }

        /**
         * @brief Appends the buckets of all leaves reached by the box query to @p out.
         *
         * Where the box straddles a split value both children are visited,
         * the right one first, each with the box cut at the split value.
         *
         * @param bucket which leaf list to collect (&KDNode::dataset or &KDNode::userset).
         * @param kdim   number of components per vector.
         * @param curlvl level of this node.
         * @param lower  working copy of the lower corner; modified during the
         *               call and restored before returning.
         * @param upper  working copy of the upper corner; same treatment.
         * @param out    receives the collected elements.
         */
        template <typename ITEM>
        void collectInterval(std::list<ITEM>* KDNode::* bucket, unsigned int kdim, unsigned int curlvl,
                             float* lower, float* upper, std::list<ITEM>& out) {
            const float epsilon = 0.00001f;

            if (this->left == NULL && this->right == NULL) {
                const std::list<ITEM>* items = this->*bucket;
                // The userset is absent when the tree was built without user values.
                if (items != NULL) out.insert(out.end(), items->begin(), items->end());
                return;
            }

            const unsigned int index = curlvl % kdim;

            if (upper[index] < this->comparevalue) {
                // Box lies completely below the split value.
                KDNode* next = (this->left != NULL) ? this->left : this->right;
                next->collectInterval(bucket, kdim, curlvl + 1, lower, upper, out);
            } else if (lower[index] >= this->comparevalue) {
                // Box lies completely at or above the split value.
                KDNode* next = (this->right != NULL) ? this->right : this->left;
                next->collectInterval(bucket, kdim, curlvl + 1, lower, upper, out);
            } else {
                // Box straddles the split value: visit both sides with a narrowed box.
                if (this->right != NULL) {
                    const float saved = lower[index];
                    lower[index] = this->comparevalue;
                    this->right->collectInterval(bucket, kdim, curlvl + 1, lower, upper, out);
                    lower[index] = saved;
                }
                if (this->left != NULL) {
                    const float saved = upper[index];
                    upper[index] = this->comparevalue - epsilon;
                    this->left->collectInterval(bucket, kdim, curlvl + 1, lower, upper, out);
                    upper[index] = saved;
                }
            }
        }

        /**
         * @brief Box query returning data vector pointers.
         * @param kdim        number of components per vector.
         * @param curlvl      level of this node.
         * @param min_datavec lower corner of the box (not modified).
         * @param max_datavec upper corner of the box (not modified).
         * @return new list (caller deletes) with the buckets of all leaves reached.
         */
        std::list<float*>* findDatasetInterval(unsigned int kdim, unsigned int curlvl, float* min_datavec, float* max_datavec) {
            // Work on private copies so the box can be narrowed while descending.
            std::vector<float> lower(min_datavec, min_datavec + kdim);
            std::vector<float> upper(max_datavec, max_datavec + kdim);
            std::list<float*>* ret = new std::list<float*>;
            collectInterval(&KDNode::dataset, kdim, curlvl,
                            lower.empty() ? NULL : &lower[0],
                            upper.empty() ? NULL : &upper[0], *ret);
            return ret;
        }

        /**
         * @brief Box query returning user values.
         * @param kdim        number of components per vector.
         * @param curlvl      level of this node.
         * @param min_datavec lower corner of the box (not modified).
         * @param max_datavec upper corner of the box (not modified).
         * @return new list (caller deletes) with the user values of all leaves
         *         reached; empty if no user values are stored.
         */
        std::list<VALUE>* findUsersetInterval(unsigned int kdim, unsigned int curlvl, float* min_datavec, float* max_datavec) {
            std::vector<float> lower(min_datavec, min_datavec + kdim);
            std::vector<float> upper(max_datavec, max_datavec + kdim);
            std::list<VALUE>* ret = new std::list<VALUE>;
            collectInterval(&KDNode::userset, kdim, curlvl,
                            lower.empty() ? NULL : &lower[0],
                            upper.empty() ? NULL : &upper[0], *ret);
            return ret;
        }

        /**
         * @brief Finds the bucket of the leaf that @p datavec descends to.
         * @param kdim    number of components per vector.
         * @param curlvl  level of this node.
         * @param datavec query vector.
         * @return the leaf's own list (still owned by the tree; do not delete).
         */
        std::list<float*>* findDataset(unsigned int kdim, unsigned int curlvl, float* datavec) {
            return findLeaf(kdim, curlvl, datavec)->dataset;
        }

        /**
         * @brief Finds the user values of the leaf that @p datavec descends to.
         * @param kdim    number of components per vector.
         * @param curlvl  level of this node.
         * @param datavec query vector.
         * @return the leaf's own list (still owned by the tree; do not delete),
         *         or NULL if no user values are stored.
         */
        std::list<VALUE>* findUserset(unsigned int kdim, unsigned int curlvl, float* datavec) {
            return findLeaf(kdim, curlvl, datavec)->userset;
        }

    };

};
00631
00632
00633 #endif
00634
00635