首页 技术 正文
技术 2022年11月13日
0 收藏 752 点赞 4,807 浏览 5113 个字

如何在Caffe中增加一层新的Layer呢?主要分为四步:

(1)在./src/caffe/proto/caffe.proto 中增加对应layer的paramter message;

(2)在./include/caffe/***layers.hpp中增加该layer的类的声明,***表示有common_layers.hpp,

data_layers.hpp, neuron_layers.hpp, vision_layers.hpp 和loss_layers.hpp等;

(3)在./src/caffe/layers/目录下新建.cpp和.cu(GPU)文件,进行类实现。

(4)在./src/caffe/gtest/中增加layer的测试代码,对所写的layer前传和反传进行测试,测试还包括速度。(可省略,但建议加上)

这位博主添加了一个计算梯度的网络层,简介明了:

http://blog.csdn.net/shuzfan/article/details/51322976

这几位博主增加了自定义的loss层,可供参考:

http://blog.csdn.net/langb2014/article/details/50489305

http://blog.csdn.net/tangwei2014/article/details/46815231

我以添加precision_recall_loss层来学习代码,主要是precision_recall_loss_layer.cpp的实现

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>
#include <opencv2/opencv.hpp> #include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp" namespace caffe { //初始化,调用父类进行相应的初始化
template <typename Dtype>
void PrecisionRecallLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
LossLayer<Dtype>::LayerSetUp(bottom, top);
}
//进行维度变换
template <typename Dtype>
void PrecisionRecallLossLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*> &bottom,
const vector<Blob<Dtype>*> &top) {
//同样先调用父类的Reshape,通过成员变量loss_来改变输入维度
LossLayer<Dtype>::Reshape(bottom, top);
loss_.Reshape(bottom[]->num(), bottom[]->channels(),
bottom[]->height(), bottom[]->width()); // Check the shapes of data and label 检查两个输入的维度是否想等
CHECK_EQ(bottom[]->num(), bottom[]->num())
<< "The number of num of data and label should be same.";
CHECK_EQ(bottom[]->channels(), bottom[]->channels())
<< "The number of channels of data and label should be same.";
CHECK_EQ(bottom[]->height(), bottom[]->height())
<< "The heights of data and label should be same.";
CHECK_EQ(bottom[]->width(), bottom[]->width())
<< "The width of data and label should be same.";
}
//前向传导
template <typename Dtype>
void PrecisionRecallLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
const Dtype *data = bottom[]->cpu_data();
const Dtype *label = bottom[]->cpu_data();
const int num = bottom[]->num(); //num和count什么区别
const int dim = bottom[]->count() / num;
const int channels = bottom[]->channels();
const int spatial_dim = bottom[]->height() * bottom[]->width();
//存疑?
const int pnum =
this->layer_param_.precision_recall_loss_param().point_num();
top[]->mutable_cpu_data()[] = ;
//对于每个通道
for (int c = ; c < channels; ++c) {
Dtype breakeven = 0.0;
Dtype prec_diff = 1.0;
for (int p = ; p <= pnum; ++p) {
int true_positive = ; //统计每类的个数
int false_positive = ;
int false_negative = ;
int true_negative = ;

for (int i = ; i < num; ++i) {
const Dtype thresh = 1.0 / pnum * p; //计算阈值?
for (int j = ; j < spatial_dim; ++j) {
//取得相应的值和标签
const Dtype data_value = data[i * dim + c * spatial_dim + j];
const int label_value = (int)label[i * dim + c * spatial_dim + j];
//统计
if (label_value == && data_value >= thresh) {
++true_positive;
}
if (label_value == && data_value >= thresh) {
++false_positive;
}
if (label_value == && data_value < thresh) {
++false_negative;
}
if (label_value == && data_value < thresh) {
++true_negative;
}
}
}
//计算precision和recall
Dtype precision = 0.0;
Dtype recall = 0.0;
if (true_positive + false_positive > ) {
precision =
(Dtype)true_positive / (Dtype)(true_positive + false_positive);
} else if (true_positive == ) { //都是负类?
precision = 1.0;
}
if (true_positive + false_negative > ) {
recall =
(Dtype)true_positive / (Dtype)(true_positive + false_negative);
} else if (true_positive == ) {
recall = 1.0;
}
if (prec_diff > fabs(precision - recall) //如果二c者相差小
&& precision > && precision <
&& recall > && recall < ) {
breakeven = precision; //保留
prec_diff = fabs(precision - recall);
}
}
top[]->mutable_cpu_data()[] += 1.0 - breakeven; //计算误差
}
top[]->mutable_cpu_data()[] /= channels; //???
}
//反向
template <typename Dtype>
void PrecisionRecallLossLayer<Dtype>::Backward_cpu(
const vector<Blob<Dtype>*> &top,
const vector<bool> &propagate_down,
const vector<Blob<Dtype>*> &bottom) {
for (int i = ; i < propagate_down.size(); ++i) {
if (propagate_down[i]) { NOT_IMPLEMENTED; }
}
}
#ifdef CPU_ONLY
STUB_GPU(PrecisionRecallLossLayer);
#endif //注册该层
INSTANTIATE_CLASS(PrecisionRecallLossLayer);
REGISTER_LAYER_CLASS(PrecisionRecallLoss); } // namespace caffe
  1. template <typename Dtype>
  2. void PrecisionRecallLossLayer<Dtype>::Forward_cpu(
  3. const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
  4. const Dtype *data = bottom[0]->cpu_data();
  5. const Dtype *label = bottom[1]->cpu_data();
  6. const int num = bottom[0]->num();
  7. const int dim = bottom[0]->count() / num;
  8. const int channels = bottom[0]->channels();
  9. const int spatial_dim = bottom[0]->height() * bottom[0]->width();
  10. const int pnum =
  11. this->layer_param_.precision_recall_loss_param().point_num();
  12. top[0]->mutable_cpu_data()[0] = 0;
  13. for (int c = 0; c < channels; ++c) {
  14. Dtype breakeven = 0.0;
  15. Dtype prec_diff = 1.0;
  16. for (int p = 0; p <= pnum; ++p) {
  17. int true_positive = 0;
  18. int false_positive = 0;
  19. int false_negative = 0;
  20. int true_negative = 0;
  21. for (int i = 0; i < num; ++i) {
  22. const Dtype thresh = 1.0 / pnum * p;
  23. for (int j = 0; j < spatial_dim; ++j) {
  24. const Dtype data_value = data[i * dim + c * spatial_dim + j];
  25. const int label_value = (int)label[i * dim + c * spatial_dim + j];
  26. if (label_value == 1 && data_value >= thresh) {
  27. ++true_positive;
  28. }
  29. if (label_value == 0 && data_value >= thresh) {
  30. ++false_positive;
  31. }
  32. if (label_value == 1 && data_value < thresh) {
  33. ++false_negative;
  34. }
  35. if (label_value == 0 && data_value < thresh) {
  36. ++true_negative;
  37. }
  38. }
  39. }
  40. Dtype precision = 0.0;
相关推荐
python开发_常用的python模块及安装方法
adodb:我们领导推荐的数据库连接组件bsddb3:BerkeleyDB的连接组件Cheetah-1.0:我比较喜欢这个版本的cheeta…
日期:2022-11-24 点赞:878 阅读:8,903
Educational Codeforces Round 11 C. Hard Process 二分
C. Hard Process题目连接:http://www.codeforces.com/contest/660/problem/CDes…
日期:2022-11-24 点赞:807 阅读:5,427
下载Ubuntn 17.04 内核源代码
zengkefu@server1:/usr/src$ uname -aLinux server1 4.10.0-19-generic #21…
日期:2022-11-24 点赞:569 阅读:6,244
可用Active Desktop Calendar V7.86 注册码序列号
可用Active Desktop Calendar V7.86 注册码序列号Name: www.greendown.cn Code: &nb…
日期:2022-11-24 点赞:733 阅读:6,057
Android调用系统相机、自定义相机、处理大图片
Android调用系统相机和自定义相机实例本博文主要是介绍了android上使用相机进行拍照并显示的两种方式,并且由于涉及到要把拍到的照片显…
日期:2022-11-24 点赞:512 阅读:7,687
Struts的使用
一、Struts2的获取  Struts的官方网站为:http://struts.apache.org/  下载完Struts2的jar包,…
日期:2022-11-24 点赞:671 阅读:4,726