kivantium活動日記

プログラムを使っていろいろやります

C++によるCARTアルゴリズムの実装

注意:この記事は書きかけです!!

決定木アルゴリズムの中で基本的なCARTアルゴリズムを実装しました

決定木アルゴリズムとは

決定木アルゴリズムは条件分岐で分類や回帰を行うようなアルゴリズムです。フローチャートのようなイメージでいいでしょう。
f:id:kivantium:20150325180354p:plain:w400
これをうまいことやるアルゴリズムの一つがCARTアルゴリズムです。

CARTアルゴリズム

次の3クラス分類問題を考えます。
f:id:kivantium:20150325180450p:plain:w400
この3クラスを一本の直線で一番きれいに分けるとしたらおそらくこうなります。
f:id:kivantium:20150325180655p:plain:w400
分けた結果をさらに分けるとしたらこうなります。
f:id:kivantium:20150325180814p:plain:w400

この「きれいに分ける直線の向き」と「直線の位置」を決めるためにCARTアルゴリズムを使います。

(以下元気があるときに追記します)

ソースコード

2次元での3クラス分類まではうまくいくことを試しましたが、3次元以上や4クラス以上の分類はなんだか微妙な感じなのでどこか間違っている可能性もありますがご容赦ください

#include<iostream>

class Node {
private:
    // number of class
    int n_class;
    // dimension of data
    int dimension;
    // size of data
    int size;
    // pointer to data
    float *data;
    // pointer to label
    int *label;
    // Gini index
    float G;

public:
    // whether all data belong to the same class
    bool isSame;
    // whether the node is leaf or not
    bool isLeaf;

    // best attribute to split node
    int spl_attr;
    // best threshold to split node
    float spl_theta;
    // index of left child node
    int child;
    // most probable class in this node
    int predict_class;

    /* caluculate impurity by Gini funtion
     * label: label of each data 
     * size: the size of label */
    float impurity(int* label, int size){
        if(size == 0) return 0;
        float result = 1.0;
        for(int i=0; i<n_class; i++){
            int count = 0;
            for(int j=0; j<size; j++){
                if(label[j] == i) count++;
            }
            result -= (float)(count/size)*(count/size);
        }
        return result;
    }

    // constructor
    Node(float *arg_data, int *arg_label, int arg_size, int arg_n_class, int arg_dimension){
        // set basic information of the node
        isLeaf = true;
        data = arg_data;
        size = arg_size;
        label = arg_label;
        n_class = arg_n_class;
        dimension = arg_dimension;

        // check whether all data belong to the same class
        isSame = true;
        for(int i=1; i<size; i++){
            if(label[i-1] != label[i]){
                isSame = false;
                G = 0;
            }
        }
        // calculate G and criteria
        if(isSame == false){
            // label of left node
            int *left = new int[size];
            int left_size = 0;
            // label of right node
            int *right = new int[size];
            int right_size = 0;
            // decide best criteria to split the node
            int best_attr;
            float best_theta;
            float maxG = -99999;
            for(int j=0; j<dimension; j++){
                for(int i=0; i<size; i++){
                    float theta = data[i*dimension+j];
                    left_size = 0;
                    right_size = 0;
                    for(int k=0; k<size; k++){
                        if(data[k*dimension+j] < theta){
                            left[left_size] = label[k];
                            left_size++;
                        } else{
                            right[right_size] = label[k];
                            right_size++;
                        }
                    }
                    float tmpG = impurity(label, size)
                        - (float)(left_size*impurity(left, left_size)/size)
                        - (float)(right_size*impurity(right, right_size)/size); 
                    if(tmpG > maxG){
                        best_attr = j;
                        best_theta = theta;
                        maxG = tmpG;
                    }
                }
            }
            G = maxG;
            spl_attr = best_attr;
            spl_theta = best_theta;
            delete[] right;
            delete[] left;
        }
        // decide most probable class
        int max = -1;
        for(int i=0; i<n_class; i++){
            int tmp = 0;
            for(int j=0; j<size; j++){
                if(label[j] == i) tmp++;
            }
            if(tmp > max){
                max = tmp;
                predict_class = i;
            }
        }
    }

    float getF(int parent_size){
        return (float)(G*size/parent_size);
    }
};



class DecisionTree {
    private:
        // number of class
        int n_class;
        // dimension of data
        int dimension;
        // max number of node
        int max_node;
        // current size of node
        int node_size;
        // array of node
        Node **node_tree;
    public:
        // constructor
        DecisionTree(int num_of_class, int dim_of_vec, int max_num_of_node = 10){
            n_class = num_of_class;
            dimension = dim_of_vec;
            max_node = max_num_of_node;
            node_size = 0;
            node_tree = new Node*[max_node];
        }

        /* generate decisiontree
         * data: train data
         * label: label of train data
         * size: size of train data */
        void generate(float* data, int* label, int size){
            // array to contain data array
            float **node_data = new float*[max_node];
            // array to contain label array
            int **node_label = new int*[max_node];
            // array to contain data size
            int *node_data_size = new int[max_node];
            // create root node
            node_tree[0] = new Node(data, label, size, n_class, dimension);
            // copy data and label
            node_data[0] = new float[size*dimension];
            node_label[0] = new int[size];
            for(int i=0; i<size; i++){
                for(int j=0; j<dimension; j++)node_data[0][i*dimension+j] = data[i*dimension+j];
                node_label[0][i] = label[i];
            }
            node_data_size[0] = size;
            for(node_size=0; node_size<max_node-2; node_size+=2){
                // search argmax F
                int spl_node = -1;
                for(int i=0; i<=node_size; i++){
                    float maxF = -99999;
                    if(node_tree[i]->isLeaf == true && node_tree[i]->isSame == false){
                        float F = node_tree[i]->getF(size);
                        if(F > maxF){
                            maxF = F;
                            spl_node = i;
                        }
                    }
                }
                //if classified correctly, finish generation
                if(spl_node == -1) break;
                // if not, split the node
                std::cout << "split node" << spl_node << std::endl;
                std::cout << "attr: " << node_tree[spl_node]->spl_attr 
                    << ", theta: " << node_tree[spl_node]->spl_theta<< std::endl;
                // create new data
                std::cout << "node data: " << node_data_size[spl_node] << std::endl;
                for(int i=0; i<node_data_size[spl_node]; i++){
                    std::cout << i << ": ";
                    for(int j=0; j<dimension; j++){
                        std::cout << node_data[spl_node][i*dimension+j] << " ";
                    }
                    std::cout << std::endl;
                }

                node_data[node_size+1] = new float[size*dimension];
                node_data[node_size+2] = new float[size*dimension];
                node_label[node_size+1] = new int[size];
                node_label[node_size+2] = new int[size];
                int left_size = 0;
                int right_size = 0;
                //std::cout << "repeat: " << node_data_size[spl_node] << std::endl;
                for(int i=0; i<node_data_size[spl_node]; i++){
                    int attribute = node_tree[spl_node]->spl_attr;
                    float threshold = node_tree[spl_node]->spl_theta;
                    if(node_data[spl_node][i*dimension+attribute] < threshold){
                        for(int j=0; j<dimension; j++){
                            node_data[node_size+1][left_size*dimension+j] = node_data[spl_node][i*dimension+j];
                        }
                        node_label[node_size+1][left_size] = node_label[spl_node][i];
                        left_size++;
                    } else {
                        for(int j=0; j<dimension; j++){
                            node_data[node_size+2][right_size*dimension+j] = node_data[spl_node][i*dimension+j];
                        }
                        node_label[node_size+2][right_size] = node_label[spl_node][i];
                        right_size++;
                    }
                }
                node_data_size[node_size+1] = left_size;
                node_data_size[node_size+2] = right_size;
                //create new nodes
                node_tree[node_size+1]
                    = new Node(node_data[node_size+1], node_label[node_size+1], left_size, n_class, dimension);
                node_tree[node_size+2]
                    = new Node(node_data[node_size+2], node_label[node_size+2], right_size, n_class, dimension);
                node_tree[spl_node]->isLeaf = false;
                node_tree[spl_node]->child = node_size+1;
                std::cout << "left node size: " << left_size<< std::endl;
                for(int i=0; i<left_size; i++){
                    for(int j=0; j<dimension; j++){
                        std::cout << node_data[node_size+1][i*dimension+j] << " ";
                    }
                    std::cout << std::endl;
                }
                std::cout << "right node size: " << right_size << std::endl;
                for(int i=0; i<right_size; i++){
                    for(int j=0; j<dimension; j++){
                        std::cout << node_data[node_size+2][i*dimension+j] << " ";
                    }
                    std::cout << std::endl;
                }
                std::cout << std::endl;
            }
            for(int i=0; i<node_size; i++){
                delete[] node_data[i];
                delete[] node_label[i];
            }
        }
        int predict(float *data){
            int current_node = 0;
            while(current_node <= max_node+1){
                if(node_tree[current_node]->isLeaf){
                    return node_tree[current_node]->predict_class;
                }else{
                    if(data[node_tree[current_node]->spl_attr] < node_tree[current_node]->spl_theta){
                        current_node = node_tree[current_node]->child;
                    } else{
                        current_node = node_tree[current_node]->child+1;
                    }
                }
            }
            return -1;
        }
};

int main(void){
    const int num_of_class = 3;
    const int dimension = 2;
    const int max_node = 5;
    const int size = 7;
    DecisionTree mytree(num_of_class, dimension, max_node);
    float data[dimension*size]
        ={1,1.5,
          1.5,0.5,
          2,1,
          2.5,1.5,
          3,0.3,
          2.5,3.0,
          1.2,4};
    int label[size] = {0,0,1,1,1,2,2};
    mytree.generate(data, label, size);
    std::cout << "result of prediction" << std::endl;
    std::cout << mytree.predict(data+0) << std::endl;
    std::cout << mytree.predict(data+2) << std::endl;
    std::cout << mytree.predict(data+4) << std::endl;
    std::cout << mytree.predict(data+6) << std::endl;
    std::cout << mytree.predict(data+8) << std::endl;
    std::cout << mytree.predict(data+10) << std::endl;
    std::cout << mytree.predict(data+12) << std::endl;
    return 0;
}

この7データに関してはうまく分類できました。他のも分類したいですがデバッグをどうやればいいのか分からないので止まっています。実装大変だった……