300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > SVM分类器C++语言实现

SVM分类器C++语言实现

时间:2020-10-21 18:55:15

相关推荐

SVM分类器C++语言实现

以下为部分代码,仅供参考。文中很多变量类型是自定义的数据结构(见文末“代码数据结构参考”)。遗憾的是,纯C实现的SVM代码找不到了,有空再写一个吧。

头文件(Svm_c.h):

#ifndef SVM_C_H
#define SVM_C_H

#include"Process.h"

// ---- dlib-based SVM ----

// Builds the global labeled training set: positive samples get label +1,
// negative samples label -1.
extern void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);

// Cross-validates the dlib trainer over a (gamma, C) grid, trains the final
// decision function and serializes it to "saved_function.dat".
extern void Train();

// Sets fixed hyper-parameters (RBF gamma, C) on the dlib trainer.
extern void Test();

// Loads "saved_function.dat" and prints the model output for every test sample.
extern void Classify();

// Estimates an RBF gamma and prints dlib's feature ranking for the current
// training set.
extern void ParamsSelection(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);

// ---- OpenCV CvSVM-based SVM ----

// Trains with fixed parameters and saves the model to "svm_ori.xml".
extern void Train_opencv_ori();

// Trains with CvSVM::train_auto (10-fold cross validation) and saves the model.
extern void Train_opencv_opt();

// Classifies the prediction data with the model selected by `mode`
// (1 = the "ori" model, 2 = the "opt" model).
extern void Classify_opencv(int mode);

// Fill the global SVM_params for Train_opencv_ori() / Train_opencv_opt().
void SVM_Params_ori();

void SVM_Params_opt();

// Note: no trailing ';' after #endif. Extra tokens after the directive are not
// allowed by the preprocessor grammar, and compilers diagnose them.
#endif

cpp文件:

#include "StdAfx.h"

#include"Svm_c.h"

// ---- dlib-based classifier state ----

// Trainer shared by Train() and Test(); both set its kernel (gamma) and C.
dlib::svm_c_trainer<kernel_type>trainer;

// Combined training samples (normalized in Label()) and their +1/-1 labels.
std::vector<sample_type>AllSamples;

std::vector<double>All_labels;

// Normalizer + decision function pair; Train() serializes it to
// "saved_function.dat" and Classify() deserializes it from there.
funct_type learned_function;

// Feature normalizer fitted on AllSamples in Label().
dlib::vector_normalizer<sample_type>normalizer;

// NOTE(review): not used by any function visible in this file.
dlib::rand rnd;

// The OpenCV training labels `Classes` are declared extern in the data-structure
// header; this commented-out definition suggests they are defined in another file.
//cv::Mat Classes;

// ---- OpenCV-based classifier state ----

// Parameters filled by SVM_Params_ori()/SVM_Params_opt() and passed to
// svm.train()/svm.train_auto().
CvSVMParams SVM_params;

CvSVM svm;

// Last label returned by svm.predict() in Classify_opencv() (sic: "responses";
// the name is kept because it is declared extern elsewhere).
int respones;

// Number of rows classified as positive / negative by Classify_opencv().
int PrePoNum=0;

int PreNgNum=0;

std::vector<int>PSIndex;// Indices of rows classified as positive.

std::vector<int>NGIndex;// Indices of rows classified as negative.

// nu / coef0 / degree grids passed to svm.train_auto() in Train_opencv_opt().
CvParamGrid nuGrid;

CvParamGrid coeffGrid;

CvParamGrid degreeGrid;

void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample)

{

std::cout<<"labeling..."<<std::endl;

AllSamples.clear();

int PSnum=PSample.size();

int NGnum=NSample.size();

int Num=0;

Num=PSnum+NGnum;

if(PSnum>0&&NGnum>0)

{

for(int i=0;i<Num;++i)

{

if(i<PSnum)

{

AllSamples.push_back(PSample[i]);

All_labels.push_back(1);

}

else

{

AllSamples.push_back(NSample[i-PSnum]);

All_labels.push_back(-1);

}

}

normalizer.train(AllSamples);

for(unsigned long i=0;i<AllSamples.size();++i)

{

AllSamples[i]=normalizer(AllSamples[i]);

std::cout<<AllSamples[i](0)<<" "<<AllSamples[i](1)<<" "<<AllSamples[i](2)<<" "<<AllSamples[i](3)<<std::endl;

}

dlib::randomize_samples(AllSamples,All_labels);//可要可不要;

}

else

std::cout<<"训练样本无效!"<<std::endl;

}

void Train()

{

if(AllSamples.size()>0)

{

std::cout<<"Doing cross calidation"<<std::endl;

for(double gamma=0.00001;gamma<=1;gamma*=5)

{

for(double C=1;C<100000;C*=5)

{

trainer.set_kernel(kernel_type(gamma));

trainer.set_c(C);

std::cout<<"gamma: "<<gamma<<" C: "<<C;

std::cout<<" cross validation accuracy: "<<dlib::cross_validate_trainer(trainer,AllSamples,All_labels,10);

}

}

learned_function.normalizer=normalizer;

learned_function.function=trainer.train(AllSamples,All_labels);

std::cout<<"\nnumber of support vector in our learned_functions are "<<learned_function.function.basis_vectors.size()<<std::endl;

dlib::serialize("saved_function.dat")<<learned_function;

}

else

std::cout<<"无法训练!"<<std::endl;

}

// Applies hand-picked hyper-parameters to the global dlib trainer:
// RBF kernel with gamma = 0.00625, and C = 5.
// Despite its name, this function neither trains nor evaluates anything.
void Test()
{
    const double chosen_gamma=0.00625;
    const double chosen_C=5;

    trainer.set_kernel(kernel_type(chosen_gamma));
    trainer.set_c(chosen_C);
}

// Loads the model saved by Train() from "saved_function.dat" into the global
// learned_function. It then prints the model output for every sample in the
// global CSample list (presumably a signed decision value; verify against dlib
// normalized_function). If CSample is empty, a notice is printed instead.
void Classify()
{
    dlib::deserialize("saved_function.dat")>>learned_function;

    if(CSample.empty())
    {
        std::cout<<"无测试样本。"<<std::endl;
    }
    else
    {
        // size_t index: the original `int i < size()` mixed signed and unsigned.
        for(std::size_t i=0;i<CSample.size();++i)
            std::cout<<"分类结果: "<<learned_function(CSample[i])<<std::endl;
    }

    // Qualified std::endl; the original relied on a using-directive for `endl`.
    std::cout<<"分类结束!"<<std::endl;
}

void ParamsSelection(cv::vector<sample_type> &Features_All_p,cv::vector<sample_type> &Features_All_n)

{

std::cout<<"开始选择参数..."<<std::endl;

//double gamma=1.0/(dlib::compute_mean_squared_distance(dlib::randomly_subsample(All_samples,20)));

const double gamma=dlib::verbose_find_gamma_with_big_centroid_gap(AllSamples,All_labels);

dlib::kcentroid<kernel_type> kc(kernel_type(gamma),0.001,40);//最后一个参数可以调整;

std::cout<<dlib::rank_features(kc,AllSamples,All_labels)<<std::endl;

}

// Trains the global CvSVM on the global `trainingdatas` / `Classes` using the
// fixed parameters from SVM_Params_ori(), then saves the model to "svm_ori.xml".
void Train_opencv_ori()
{
    SVM_Params_ori();
    std::cout<<"开始训练"<<std::endl;

    // CvSVM::train reports success through its bool result. Previously it was
    // ignored, and an untrained model was still announced and saved.
    if(!svm.train(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params))
    {
        std::cout<<"SVM训练失败。"<<std::endl;
        return;
    }
    std::cout<<"SVM分类器训练完毕。"<<std::endl;

    svm.save("svm_ori.xml");
    std::cout<<"模型保存完毕。"<<std::endl;
}

void Train_opencv_opt()

{

SVM_Params_opt();

std::cout<<"开始训练"<<std::endl;

svm.train_auto(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params,10,svm.get_default_grid(CvSVM::C),svm.get_default_grid(CvSVM::GAMMA),svm.get_default_grid(CvSVM::P),nuGrid,coeffGrid,degreeGrid);

CvSVMParams SVM_params_return=svm.get_params();

std::cout<<"SVM分类器训练完毕。"<<std::endl;

svm.save("svm_opt_15_0925_16features.xml");

std::cout<<"模型保存完毕。"<<std::endl;

}

void SVM_Params_opt()

{

SVM_params.svm_type=CvSVM::C_SVC;

SVM_params.kernel_type=CvSVM::RBF;

SVM_params.C=1;

SVM_params.gamma=0.0001;

SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,15000,0.001);

CvParamGrid nuGrid=CvParamGrid(1,1,0.0);

CvParamGrid coeffGrid=CvParamGrid(1,1,0.0);

CvParamGrid degreeGrid=CvParamGrid(1,1,0.0);

}

void SVM_Params_ori()//设置SVM参数;

{

SVM_params.svm_type=CvSVM::C_SVC;

SVM_params.kernel_type=CvSVM::RBF;

SVM_params.degree=0;

SVM_params.gamma=1;

SVM_params.coef0=0;

SVM_params.C=1;

SVM_params.nu=0;

SVM_params.p=0;

SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,100000,0.01);

}

void Classify_opencv(int mode)

{

string modelpath;

if(mode==1)

modelpath="svm_ori.xml";

if(mode==2)

modelpath="svm_opt_15_0925_16features.xml";

FileStorage svm_fs(modelpath,FileStorage::READ);

if(svm_fs.isOpened())

{

respones=0;

svm.load(modelpath.c_str());

std::cout<<std::endl;

PSIndex.clear();

NGIndex.clear();

std::cout<<"开始分类"<<std::endl;

//std::cout<<PredictingDatas<<std::endl;

for(int i=0;i<PredictingDatas.rows;i++)

{

Mat classMat=PredictingDatas.rowRange(i,i+1);

classMat=classMat.reshape(1,1);

respones=(int)svm.predict(classMat);

if(respones==1)

{

PrePoNum++;

PSIndex.push_back(i);

}

else

{

PreNgNum++;

NGIndex.push_back(i);

}

}

//std::cout<<"正样本路径: "<<std::endl;

PrintFileName(PSIndex,1);

// std::cout<<"负样本路径: "<<std::endl;

PrintFileName(NGIndex,2);

std::cout<<"分类结束!正样本数: "<<PrePoNum<<" 负样本数: "<<PreNgNum<<std::endl;

}

}

代码数据结构参考:

#ifndef DATASTRUCT_H

#define DATASTRUCT_H

#include<iostream>

#include<dlib/svm.h>

#include<vector>

#include<dlib/rand.h>

#include<opencv2\highgui\highgui.hpp>

#include<opencv2\imgproc\imgproc.hpp>

#include<opencv2\core\core.hpp>

#include<opencv2\ml\ml.hpp>

#include<opencv\cv.hpp>

#include<stdio.h>

#include<stdlib.h>

#include<math.h>

#include<fstream>

#include<io.h>

#include<cassert>

#include<iterator>

#include<functional>

#include<algorithm>

#include<opencv2/opencv.hpp>

// Presumably the number of features per sample.
// NOTE(review): not referenced in the visible code, while the dlib sample_type
// below is hard-coded as a 4x1 matrix. Confirm whether the two must stay in sync.
#define FEATURESNUM 4

// 2-D integer table stored as a vector of rows. Used below for the
// per-direction GLCM objects Vec_hor / Vec_ver / Vec_45 / Vec_135.
typedef std::vector<std::vector<int> >Vec2D;

// Four texture features, presumably computed from one gray-level
// co-occurrence matrix (GLCM). All values start at 0.0.
struct _GLMCFeatures
{
    _GLMCFeatures()
        : energy(0.0)
        , entropy(0.0)
        , contrast(0.0)
        , idmoment(0.0)
    {
    }

    double energy;
    double entropy;
    double contrast;
    double idmoment;   // presumably the inverse difference moment
};
typedef _GLMCFeatures GLCMFeatures;

// Training-set statistics for the four GLCM features. For each feature there is
// a mean_train_* value and a sigma_train_* value (sigma is presumably the
// standard deviation). Everything starts at 0.0.
struct _StandValue
{
    _StandValue()
        : mean_train_energy(0.0)
        , mean_train_entropy(0.0)
        , mean_train_contrast(0.0)
        , mean_train_idmoment(0.0)
        , sigma_train_energy(0.0)
        , sigma_train_entropy(0.0)
        , sigma_train_contrast(0.0)
        , sigma_train_idmoment(0.0)
    {
    }

    // Means.
    double mean_train_energy;
    double mean_train_entropy;
    double mean_train_contrast;
    double mean_train_idmoment;

    // Spreads (sigma).
    double sigma_train_energy;
    double sigma_train_entropy;
    double sigma_train_contrast;
    double sigma_train_idmoment;
};
typedef _StandValue StandValue;

// Accumulators used to normalize the four GLCM features (energy, entropy,
// contrast, idmoment) in each of the four directions (hor, ver, 45, 135).
// Every feature/direction pair has a mean_*, sum_*, pow_* and spow_* field,
// all initialized to 0.0.
// NOTE(review): the exact meaning of pow_* / spow_* (presumably squared values
// and their running sum) is not visible here. Confirm with the code that fills
// this struct.
typedef struct _NormaData
{
    // Members are always initialized in DECLARATION order. The previous list was
    // grouped by statistic rather than by direction (and put sum_energy before
    // sum_entropy), which triggered -Wreorder warnings. This list follows the
    // declaration order below; every value is still 0.0.
    _NormaData():
        // horizontal
        mean_energy_hor(0.0),mean_entropy_hor(0.0),mean_contrast_hor(0.0),mean_idmoment_hor(0.0),
        sum_entropy_hor(0.0),sum_energy_hor(0.0),sum_contrast_hor(0.0),sum_idmoment_hor(0.0),
        pow_entropy_hor(0.0),pow_energy_hor(0.0),pow_contrast_hor(0.0),pow_idmoment_hor(0.0),
        spow_entropy_hor(0.0),spow_energy_hor(0.0),spow_contrast_hor(0.0),spow_idmoment_hor(0.0),
        // vertical
        mean_energy_ver(0.0),mean_entropy_ver(0.0),mean_contrast_ver(0.0),mean_idmoment_ver(0.0),
        sum_entropy_ver(0.0),sum_energy_ver(0.0),sum_contrast_ver(0.0),sum_idmoment_ver(0.0),
        pow_entropy_ver(0.0),pow_energy_ver(0.0),pow_contrast_ver(0.0),pow_idmoment_ver(0.0),
        spow_entropy_ver(0.0),spow_energy_ver(0.0),spow_contrast_ver(0.0),spow_idmoment_ver(0.0),
        // 45 degrees
        mean_energy_45(0.0),mean_entropy_45(0.0),mean_contrast_45(0.0),mean_idmoment_45(0.0),
        sum_entropy_45(0.0),sum_energy_45(0.0),sum_contrast_45(0.0),sum_idmoment_45(0.0),
        pow_entropy_45(0.0),pow_energy_45(0.0),pow_contrast_45(0.0),pow_idmoment_45(0.0),
        spow_entropy_45(0.0),spow_energy_45(0.0),spow_contrast_45(0.0),spow_idmoment_45(0.0),
        // 135 degrees
        mean_energy_135(0.0),mean_entropy_135(0.0),mean_contrast_135(0.0),mean_idmoment_135(0.0),
        sum_entropy_135(0.0),sum_energy_135(0.0),sum_contrast_135(0.0),sum_idmoment_135(0.0),
        pow_entropy_135(0.0),pow_energy_135(0.0),pow_contrast_135(0.0),pow_idmoment_135(0.0),
        spow_entropy_135(0.0),spow_energy_135(0.0),spow_contrast_135(0.0),spow_idmoment_135(0.0)
    {}

    // ---- horizontal ----
    double mean_energy_hor;
    double mean_entropy_hor;
    double mean_contrast_hor;
    double mean_idmoment_hor;
    double sum_entropy_hor;
    double sum_energy_hor;
    double sum_contrast_hor;
    double sum_idmoment_hor;
    double pow_entropy_hor;
    double pow_energy_hor;
    double pow_contrast_hor;
    double pow_idmoment_hor;
    double spow_entropy_hor;
    double spow_energy_hor;
    double spow_contrast_hor;
    double spow_idmoment_hor;

    // ---- vertical ----
    double mean_energy_ver;
    double mean_entropy_ver;
    double mean_contrast_ver;
    double mean_idmoment_ver;
    double sum_entropy_ver;
    double sum_energy_ver;
    double sum_contrast_ver;
    double sum_idmoment_ver;
    double pow_entropy_ver;
    double pow_energy_ver;
    double pow_contrast_ver;
    double pow_idmoment_ver;
    double spow_entropy_ver;
    double spow_energy_ver;
    double spow_contrast_ver;
    double spow_idmoment_ver;

    // ---- 45 degrees ----
    double mean_energy_45;
    double mean_entropy_45;
    double mean_contrast_45;
    double mean_idmoment_45;
    double sum_entropy_45;
    double sum_energy_45;
    double sum_contrast_45;
    double sum_idmoment_45;
    double pow_entropy_45;
    double pow_energy_45;
    double pow_contrast_45;
    double pow_idmoment_45;
    double spow_entropy_45;
    double spow_energy_45;
    double spow_contrast_45;
    double spow_idmoment_45;

    // ---- 135 degrees ----
    double mean_energy_135;
    double mean_entropy_135;
    double mean_contrast_135;
    double mean_idmoment_135;
    double sum_entropy_135;
    double sum_energy_135;
    double sum_contrast_135;
    double sum_idmoment_135;
    double pow_entropy_135;
    double pow_energy_135;
    double pow_contrast_135;
    double pow_idmoment_135;
    double spow_entropy_135;
    double spow_energy_135;
    double spow_contrast_135;
    double spow_idmoment_135;
}NormaData;

// dlib-related type definitions.

// One sample: a 4x1 column vector of doubles.
typedef dlib::matrix<double,4,1>sample_type;

// RBF kernel over samples. (A second, identical typedef of kernel_type that
// followed this line was redundant and has been removed.)
typedef dlib::radial_basis_kernel<sample_type>kernel_type;

// Decision function produced by the SVM trainer for this kernel.
typedef dlib::decision_function<kernel_type>dec_funct_type;

// Decision function bundled with the feature normalizer applied before it.
typedef dlib::normalized_function<dec_funct_type>funct_type;

// Gray-level co-occurrence matrix (GLCM) related declarations.
// Everything below is only declared here; the definitions live in other
// translation units.

// Presumably one co-occurrence matrix per direction (horizontal, vertical,
// 45 and 135 degrees) — TODO confirm.
extern Vec2D Vec_hor;
extern Vec2D Vec_ver;
extern Vec2D Vec_45;
extern Vec2D Vec_135;

// Per-direction feature statistics (see StandValue).
extern StandValue standValue_hor;
extern StandValue standValue_ver;
extern StandValue standValue_45;
extern StandValue standValue_135;

// Per-direction GLCM features (see GLCMFeatures).
extern GLCMFeatures features_hor;
extern GLCMFeatures features_ver;
extern GLCMFeatures features_45;
extern GLCMFeatures features_135;

// dlib SVM state. PSample / NSample presumably hold the positive / negative
// training samples and CSample the samples to classify — TODO confirm.
extern dlib::svm_c_trainer<kernel_type>trainer;
extern std::vector<sample_type>PSample;
extern std::vector<sample_type>NSample;
extern std::vector<sample_type>CSample;
extern std::vector<sample_type>AllSamples;
extern std::vector<double>All_labels;
extern funct_type learned_function;
extern dlib::rand rnd;

// Root paths, presumably for the positive (P), negative (N) and to-classify (C)
// inputs and their outputs — naming-based guess, TODO confirm.
extern std::string RootFileName_P;
extern std::string RootFileName_N;
extern std::string RootFileName_C;
extern std::string RootSavePath_PS;
extern std::string RootSavePath_NG;
extern std::string RootPath_glcm;
extern std::string RootProjectPath;
//extern std::string SavePath_glcm_n;

// File lists, loaded images and regions of interest per sample set
// (_p / _n / _c) — naming-based guess, TODO confirm.
extern std::vector<std::string>Vec_ImageFiles_p;
extern std::vector<std::string>Vec_ImageFiles_n;
extern std::vector<std::string>Vec_ImageFiles_c;
extern std::vector<std::string>Vec_RoiFiles_p;
extern std::vector<std::string>Vec_RoiFiles_n;
extern std::vector<std::string>Vec_RoiFiles_c;
extern std::vector<std::string>Vec_RoiFiles_A;
extern std::vector<cv::Mat>BMPImages_p;
extern std::vector<cv::Mat>BMPImages_n;
extern std::vector<cv::Mat>BMPImages_c;
extern std::vector<cv::Mat>BMPclass_p;
extern std::vector<cv::Mat>BMPclass_n;
extern std::vector<cv::Mat>ROI_p;
extern std::vector<cv::Mat>ROI_n;
extern std::vector<cv::Mat>ROI_c;
extern std::string tempSave;

// OpenCV SVM data. Classes holds the training responses and trainingdatas the
// training matrix passed to CvSVM::train. PredictingDatas holds the rows that
// Classify_opencv() predicts.
extern cv::Mat Classes;
extern std::vector<int>trainLabels;
extern cv::Mat trainAlldatas;
extern cv::Mat trainingdatas;
extern cv::Mat PreDatas;
extern cv::Mat PredictingDatas;
extern CvSVMParams SVM_params;
extern CvSVM svm;

// Results of Classify_opencv(): last prediction, per-class counts and row indices.
extern int respones;
extern int PrePoNum;
extern int PreNgNum;
extern std::vector<int>PSIndex;// Indices of samples classified as positive.
extern std::vector<int>NGIndex;// Indices of samples classified as negative.
// NOTE(review): not used by the visible code.
extern int savenum;

#endif

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。