Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >线程池管理的pipeline设计模式(用了“精进C++”里的内容)

线程池管理的pipeline设计模式(用了“精进C++”里的内容)

作者头像
用户9831583
发布于 2022-12-04 08:29:30
发布于 2022-12-04 08:29:30
1.3K00
代码可运行
举报
文章被收录于专栏:码出名企路码出名企路
运行总次数:0
代码可运行

记录最近算法工程里开发的pipeline设计模式。优化了上一版本:

1,增加了线程池管理,每个node可以异步处理;

2,增加了callback,将最后一个node的结果callback到主程序,避免的参数传递的冗余实现;

3,去掉了模板类设计,避免只能在头文件中去实现的弊端;

4,去掉了前node的输出就是后node的输入,避免函数返回值带来复制的开销的应用;

/** @ 带有线程池的pipeline pipeline里的Node可以异步执行,加快处理速度 */

task_queue.h

/** @ 线程池的任务队列 @ 入队和出队 */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
template<class T>
class TaskQueue
{
    public:
        TaskQueue() = default;
        ~TaskQueue() = default;

        //任务入队
        void enqueue(T& t)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            if(m_pNextQueue)
            {
                m_pNextQueue->enqueue(t);
                return;
            }
            m_queue.push(t);
        }

        //任务出队
        bool dequeue(T& t)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            if(m_queue.empty())
                return false;

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        int32_t size()
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            return m_queue.szie();
        }

        bool empty()
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            return m_queue.empty();
        }

        //出队等待
        bool dequeue_wait(T& t,uint32_t timeout)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            if(m_queue.empty())
                m_cond.wait_for(lock,std::chrono::milliseconds(timeout));

            if(m_queue.empty())
                return false;

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        //取出taskQueue对象
        void connect(TaskQueue<T>* pQueue)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            m_pNextQueue = pQueue;
        }

    private:

        std::queue<T> m_queue;
        std::mutex m_mutex;
        std::condition_variable m_cond;
        TaskQueue<T>* m_pNextQueue;
};

thread_manager.h

/** @ 线程管理 */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
static const uint32_t MaxThreadNums = 8;
class ThreadManager
{
    public:
        ThreadManager(const int m_threads = MaxThreadNums ):m_threads(std::vector<std::thread>(m_threads)),m_shutdown(false){

        }

        ~ThreadManager(){
            this->shutdown();
        }

        ThreadManager(ThreadManager &&)=delete;
        ThreadManager(const ThreadManager &)=delete;
        ThreadManager &operator=(ThreadManager &&)=delete;
        ThreadManager &operator=(const ThreadManager &) =delete;

        void init()
        {
            for(uint32_t i =0; i < m_threads.size();++i)
            {
                m_threads.at(i) = std::thread(ThreadWorker(this,i));
            }
        }

        void shutdown()
        {
            m_shutdown = true;
            m_cond.notify_all();
            for(uint32_t i =0; i < m_threads.size(); ++i)
            {
                if(m_threads.at(i).joinable())
                {
                    m_threads.at(i).join();
                }
            }
        }

        template<typename F,typename... Args>
        auto postJobs(F&& f, Args &&...args)->std::future<decltype(f(args...))>
        {
            std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f),std::forward<Args>...);
            auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);

            std::function<void()> warpper_func = [task_ptr]()
            {
                (*task_ptr);
            };

            m_task_queue.push(warpper_func);
            m_cond.notify_one();

            return task_ptr->get_future();
        }

    private:

        class ThreadWorker
        {
            public:
                ThreadWorker(ThreadManager *pThreadManager,const int32_t tid):m_pThreadManager(pThreadManager),m_tid(tid){

                };

                void operator()()
                {
                    std::function<void()> task;

                    bool dequeued = false;
                    while(!m_pThreadManager->m_shutdown)
                    {
                        std::unique_lock<std::mutex> lock(m_pThreadManager->m_mutex);
                        m_pThreadManager->m_cond.wait(lock,[&](){
                            return !m_pThreadManager->m_task_queue.empty();
                        });

                        m_pThreadManager->m_task_queue.pop();
                        lock.unlock();

                        task();
                    }
                }

            private:
                int32_t m_tid;
                ThreadManager *m_pThreadManager;
        };

    private:
        bool m_shutdown;
        std::mutex m_mutex;
        std::condition_variable m_cond;
        std::vector<std::thread> m_threads;
        std::queue<std::function<void()>> m_task_queue;

};

common_struct.h

/** @ pipeline的入参结构体 */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
enum NodeType
{
    Source,
    Channel,
    Sink
};
struct NodeNeedInfo
{
    std::string name;
    NodeType type;
};
struct InputRequestInfo
{
    bool isOK;
    uint32_t requestId;
    
    //nodeInput Info

    NodeNeedInfo nodeInfo[8];
};
using NodeNeedInfoPtr = std::shared_ptr<NodeNeedInfo>;
using InputRequestInfoPtr = std::shared_ptr<InputRequestInfo>;
using ResultCallback = std::function<void(const InputRequestInfoPtr&)>;

struct PipelineDescriptor
{
    uint32_t nums;
    std::string name;
    
    //NodeInfo
    NodeNeedInfo nodes[8];
    ResultCallback callback;
};
using PipelineDescriptorPtr = std::shared_ptr<PipelineDescriptor>;

node.h

//node.h : base Node /*** @ 1, 去掉了类模板 @ 2, 不需要上一级的输出是下一级的输入 @ 3, 通过callback的方式将最后一级的结果输出给前一级 */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class Node
{
    public:
        Node(): m_stop(false),m_is_sink(false){};
        virtual ~Node() = default;

        virtual int32_t initialize(const std::string& conf) = 0;
        virtual int32_t process(InputRequestInfoPtr pRequestInfo) = 0;

        virtual  std::string getNodeName() const= 0;

        virtual NodeType Type()const =0;
    

    public:

        void start()
        {
            //起线程处理
            m_thread = std::thread([this](){
                executeRequest();
            });
        }

        void stop()
        {
            m_stop = true;
            if(m_thread.joinable())
            {
                m_thread.join();
            }
        }

        // inline std::string getNodeName() const
        // {
        //     return m_node_name;
        // }

        void executeRequest()
        {   
            int count = 0;
            while(!m_stop)
            {
                InputRequestInfoPtr pRequest;
                if(m_input_queue.dequeue(pRequest))
                {
                    int32_t ret = process(pRequest);

                    if(ret != 0)
                    {
                        ///////////
                    }

                    //set request for next node
                    if(m_type != NodeType::Sink)//bug to do
                    {
                        count++;
                        m_output_queue.enqueue(pRequest);
                    }
                    else
                    {
                        m_result_callback(pRequest);//回到main: publishResult
                    }
                 
                    
                }
                else
                {
                    ////////////
                }
            }
        }

        TaskQueue<InputRequestInfoPtr> &input_queue()
        {
            return m_input_queue;
        }

        TaskQueue<InputRequestInfoPtr> &output_queue()
        {
            return m_output_queue;
        }

        // inline NodeType Type()const
        // {
        //     return m_type;
        // }

        void callbackRegister(ResultCallback callback)
        {
            m_result_callback = std::move(callback);
        }

    private:
        bool m_stop;
        bool m_is_sink;
        bool m_source;
        TaskQueue<InputRequestInfoPtr> m_input_queue;
        TaskQueue<InputRequestInfoPtr> m_output_queue;
        std::thread  m_thread;
        std::string m_node_name;
        ResultCallback m_result_callback;
        NodeType m_type;

};

nodeA/B

/** NodeA -> NodeB -> NodeC */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class Node_A :public Node
{
    public:
        Node_A() = default;
        ~Node_A() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeA initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeA process"<<std::endl;
            pRequestInfo->requestId = 100;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_A";
        }

        NodeType Type()const override
        {
            return NodeType::Source;
        }

};

//NodeB
class Node_B :public Node
{
    public:
        Node_B() = default;
        ~Node_B() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeB initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeB process"<<std::endl;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_B";
        }

        NodeType Type()const override
        {
            return NodeType::Sink;
        }

};

perceptionPipeline.h

/** @ 一个具体的pipeline */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class PerceptionPipeline
{

    public:
        PerceptionPipeline()=default;
        ~PerceptionPipeline()=default;

        /**
        @ submit request to source node
        */
        void submit(InputRequestInfoPtr& pRequest)
        {
            m_pNodes[0]->input_queue().enqueue(pRequest);
        }

        /**
        @ initialize an pipline
        */
        int32_t initialize(const PipelineDescriptorPtr& pPipelineDesc)
        {
            int32_t result = 0;
            result = createNodes(pPipelineDesc);

            return result;
        }
        
        int32_t createNodes(const PipelineDescriptorPtr& pPipelineInfo)
        {
            int32_t result = 0;
            for(uint32_t i=0 ; i < pPipelineInfo->nums; i++)
            {
                //todo factory create nodes
                std::shared_ptr<Node> pNode = std::move(CreateNode(pPipelineInfo->nodes[i]));

                result = pNode->initialize("lxk");

                if(0!=result)
                {
                    //////////////
                    break;
                }

                if(pNode->Type() == NodeType::Sink)
                {   
                    std::cout<<"------------callbackRegister-----------"<<std::endl;
                    pNode->callbackRegister(pPipelineInfo->callback);
                }

                this->addNode(pNode);
            }

            return result;
        }
        static std::shared_ptr<Node> CreateNode(const NodeNeedInfo& node_desc)
        {
            if(node_desc.name == "NodeA")
                return (std::make_shared<Node_A>());
            if(node_desc.name == "NodeB")
                return (std::make_shared<Node_B>());

            return nullptr;
        }

        void start()
        {
            for(auto i:m_pNodes)
            {
                i->start();
            }
        }

        void stop()
        {
            for(auto i:m_pNodes)
            {
                i->stop();
            }
        }

        std::string PipelineInfo()
        {
            std::stringstream sstr;
            sstr<<"\n";
            sstr<<"-------Pipeline info  start----------\n";
            sstr<<"number of nodes: "<<m_pNodes.size()<<"\n";
            for(uint32_t i =0; i <m_pNodes.size(); i++)
            {
                if(i == m_pNodes.size() -1)
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"\n";
                }
                else
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"->";
                }
            }

            sstr<<"----------Pipeline info end----------\n";

            return sstr.str();
        }

    private:

        void addNode(std::shared_ptr<Node>& pNode)
        {
            std::shared_ptr<Node> pTail = nullptr;
            if(!m_pNodes.empty())
            {
                pTail = m_pNodes.back();
            }
            m_pNodes.push_back(pNode);

            //connect output queue node with input queue of next node
            if(pTail)
            {
                pTail->output_queue().connect(&pNode->input_queue());
            }
        }

    private:
        std::vector<std::shared_ptr<Node>> m_pNodes;
};

CameraPerception

/** @ 实际测试案例 */

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class CameraPerception
{
    public:
        CameraPerception();
        ~CameraPerception();

        bool init();

    private:

        void cameraPerceptionCallback();

        void publishResult(const InputRequestInfoPtr& pInferResult);

        void MarkObstacleOnImage(uint64_t request_id);

        std::unique_ptr<PerceptionPipeline> m_perception_pipeline;
        std::unique_ptr<ThreadManager> m_thread_manager;

};

CameraPerception::CameraPerception()
{

}

CameraPerception::~CameraPerception()
{
    m_perception_pipeline->stop();
}

bool CameraPerception::init()
{
    m_thread_manager.reset(new ThreadManager());
    m_thread_manager->init();

    m_perception_pipeline.reset(new PerceptionPipeline);
    
    PipelineDescriptorPtr pPipeline(new PipelineDescriptor);
    pPipeline->name = "perception pipeline";

    int count = 2;

    pPipeline->nodes[0].name = "NodeA";
    pPipeline->nodes[0].type = NodeType::Source;

    pPipeline->nodes[1].name = "NodeB";
    pPipeline->nodes[1].type = NodeType::Sink;
    pPipeline->nums = count;
    
    pPipeline->callback = std::bind(&CameraPerception::publishResult, this, std::placeholders::_1);
    int32_t ret = m_perception_pipeline->initialize(pPipeline);

    if(ret != 0)
    {
        std::cout<<"pipeline init error";
        return false;
    }
  
    m_perception_pipeline->start();

    std::cout<<"pipeline info: "<<m_perception_pipeline->PipelineInfo()<<std::endl;

    cameraPerceptionCallback();


    return ret;
}


void CameraPerception::cameraPerceptionCallback()
{
    InputRequestInfoPtr input_info(new InputRequestInfo);
    input_info->requestId = 1;
    input_info->isOK = true;

    for(size_t i =0; i < 3; i++)
    {
        input_info->nodeInfo[i].name = "lxkkk";
    }
    
    m_perception_pipeline->submit(input_info);

}

//callbacked by node
void CameraPerception::publishResult(const InputRequestInfoPtr& pInferResult)
{
    std::cout<<"publishResult:  ID: "<<pInferResult->requestId<<std::endl;

    m_thread_manager->postJobs(std::bind(&CameraPerception::MarkObstacleOnImage, this,pInferResult->requestId));
}

void CameraPerception::MarkObstacleOnImage(uint64_t request_id)
{
     std::cout<<"MarkObstacleOnImage:  ID: "<<request_id<<std::endl;
}
int main()
{
    std::unique_ptr<CameraPerception> pCameraPerceptionHandle(new CameraPerception());
    pCameraPerceptionHandle->init();
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-11-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 码出名企路 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
别让数据坑了你!用置信学习找出错误标注(附开源实现)
在实际工作中,你是否遇到过这样一个问题或痛点:无论是通过哪种方式获取的标注数据,数据标注质量可能不过关,存在一些错误?亦或者是数据标注的标准不统一、存在一些歧义?特别是badcase反馈回来,发现训练集标注的居然和badcase一样?如下图所示,QuickDraw、MNIST和Amazon Reviews数据集中就存在错误标注。
zenRRan
2020/07/03
5.5K0
【综述专栏】如何在标注存在错标的数据上训练模型
在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。
马上科普尚尚
2021/04/28
1.3K0
【综述专栏】如何在标注存在错标的数据上训练模型
领域前沿研究「无所不包」 ,走进标签噪声表征学习的过去、现在和未来
机器之心发布 机器之心编辑部 抗噪鲁棒性学习是机器学习中一个非常重要和热门的领域,各类方法也层出不穷。在本文中,来自香港浸会大学、清华大学等机构的研究者对标签噪声表征学习(LNRL)的方方面面进行了全方位的综述。 监督学习方法通常依赖精确的标注数据,然而在真实场景下数据误标注(标签噪声)问题不可避免。例如,对于数据本身存在不确定性的医疗任务,领域专家也无法给出完全可信的诊断结果(下图 1);基于用户反馈的垃圾邮件过滤程序,用户作为标注人员存在行为的不确定性(例如误点击)。不论是从理论还是从实验角度,人们均发
机器之心
2023/03/29
1.2K0
领域前沿研究「无所不包」 ,走进标签噪声表征学习的过去、现在和未来
样本混进了噪声怎么办?通过Loss分布把它们揪出来!
当训练样本中混有噪音,就很容易让模型过拟合,学习到错误的信息,因此必须加以干涉,来控制噪音带来的影响。这方面的研究,主要集中于“损失修正”方法,即loss correction。典型的方法有这些:
beyondGuo
2021/01/12
2.1K0
样本混进了噪声怎么办?通过Loss分布把它们揪出来!
CVPR 2022 | 应对噪声标签,西安大略大学、字节跳动等提出对比正则化方法
机器之心专栏 作者:西安大略大学、纽约大学、字节跳动 来自西安大略大学、纽约大学和字节跳动的研究者回答了一个重要的问题,即如何从带有噪声标签的数据集中学到可靠模型。 噪声标签(Noisy labels)随着深度学习研究的深入得到广泛的关注,因为在众多实际落地的场景模型的训练都离不开真实可靠的标签信息。由于人工标注误差(专业性不足等问题)、数据原始噪声,带噪声的数据不可避免,清洗数据的工作也是更加困难。 在有监督的图像分类问题中,经典的 cross-entropy (CE) 损失函数是最为广泛应用的函数之
机器之心
2022/06/13
1.1K0
CVPR 2022 | 应对噪声标签,西安大略大学、字节跳动等提出对比正则化方法
【干货】使用Pytorch实现卷积神经网络
【导读】图像识别是深度学习取得重要成功的领域,特别是卷积神经网络在图像识别和图像分类中取得了超过人类的好成绩。本文详细介绍了卷积神经网络(CNN)的基本结构,对卷积神经网络中的重要部分进行详细讲解,如卷积、非线性函数ReLU、Max-Pooling、全连接等。另外,本文通过对 CIFAR-10 的10类图像分类来加深读者对CNN的理解和Pytorch的使用,列举了如何使用Pytorch收集和加载数据集、设计神经网络、进行网络训练、调参和准确度量。总的来讲,这篇文章偏重概念理解和动手实现,相信对您的入门会有帮
WZEARW
2018/06/05
8.3K0
如何使用TensorFlow实现卷积神经网络
编者按:本文节选自图书《TensorFlow实战》第五章,本书将重点从实用的层面,为读者讲解如何使用TensorFlow实现全连接神经网络、卷积神经网络、循环神经网络,乃至Deep Q-Network。同时结合TensorFlow原理,以及深度学习的部分知识,尽可能让读者通过学习本书做出实际项目和成果。 卷积神经网络简介 卷积神经网络(Convolutional Neural Network,CNN)最初是为解决图像识别等问题设计的,当然其现在的应用不仅限于图像和视频,也可用于时间序列信号,比如音频信号
用户1737318
2018/07/20
6550
如何训练孪生神经网络
使用机器学习训练时,如果想训练出精确和健壮的模型需要大量的数据。但当训练模型用于需要自定义数据集的目的时,您通常需要在模型所看到的数据量级上做出妥协。
deephub
2021/05/18
1.6K0
如何训练孪生神经网络
入门 | 完全云端运行:使用谷歌CoLaboratory训练神经网络
选自Medium 作者:Sagar Howal 机器之心编译 参与:路雪 Colaboratory 是一个 Google 研究项目,旨在帮助传播机器学习培训和研究成果。它是一个 Jupyter 笔记本环境,不需要进行任何设置就可以使用,并且完全在云端运行。Colaboratory 笔记本存储在 Google 云端硬盘 (https://drive.google.com/) 中,并且可以共享,就如同您使用 Google 文档或表格一样。Colaboratory 可免费使用。本文介绍如何使用 Google Co
机器之心
2018/05/11
1.7K0
迁移学习之快速搭建【卷积神经网络】
卷积神经网络 概念认识:https://cloud.tencent.com/developer/article/1822928
一颗小树x
2021/05/14
2K0
迁移学习之快速搭建【卷积神经网络】
黑客视角:避免神经网络训练失败,需要注意什么?
确保网络正常运行的关键因素之一是网络的配置。正如机器学习大师 Jason Brownle 所说,「深度学习神经网络已经变得易于定义和拟合,但仍然难以配置。」
AI研习社
2019/10/08
9100
黑客视角:避免神经网络训练失败,需要注意什么?
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.3K0
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
使用自编码器进行图像去噪
正确理解图像信息在医学等领域是至关重要的。去噪可以集中在清理旧的扫描图像上,或者有助于癌症生物学中的特征选择。噪音的存在可能会混淆疾病的识别和分析,从而导致不必要的死亡。因此,医学图像去噪是一项必不可少的预处理技术。
deephub
2021/05/18
1.2K0
使用自编码器进行图像去噪
用 PyTorch 从零创建 CIFAR-10 的图像分类器神经网络,并将测试准确率达到 85%
一般,深度学习的教材或者是视频,作者都会通过 MNIST 这个数据集,讲解深度学习的效果,但这个数据集太小了,而且是单色图片,随便弄些模型就可以取得比较好的结果,但如果我们不满足于此,想要训练一个神经网络来对彩色图像进行分类,可以不可以呢?
Frank909
2019/01/14
10.1K0
机器之心GitHub项目:从零开始用TensorFlow搭建卷积神经网络
机器之心原创 参与:蒋思源 机器之心基于 Ahmet Taspinar 的博文使用 TensorFlow 手动搭建卷积神经网络,并提供所有代码和注释的 Jupyter Notebook 文档。我们将不仅描述训练情况,同时还将提供各种背景知识和分析。所有的代码和运行结果都已上传至 Github,机器之心希望通过我们的试验提供精确的代码和运行经验,我们将持续试验这一类高质量的教程和代码。 机器之心项目地址:https://github.com/jiqizhixin/ML-Tutorial-Experiment
机器之心
2018/05/08
1.5K0
机器之心GitHub项目:从零开始用TensorFlow搭建卷积神经网络
深度学习中的卷积神经网络:原理、结构与应用
文章链接:https://cloud.tencent.com/developer/article/2471829
小馒头学Python
2024/11/27
5750
深度学习中的卷积神经网络:原理、结构与应用
TensorFlow | 自己动手写深度学习模型之全连接神经网络
前半个多月总共写了三篇深度学习相关的理论介绍文章,另外两个月前,我们使用逻辑回归算法对sklearn里面的moons数据集进行了分类实验,最终准确率和召回率都达到了97.9%,详情参看这篇文章:一文打尽:线性回归和逻辑斯蒂线性回归(https://zhuanlan.zhihu.com/p/31075733),今天我们尝试使用神经网络来进行分类。全连接神经网络的搭建本身没什么难度,几句代码就够了,但是本文的真正目的是: 让大家了解Tensorflow 的基本使用方法; 使用 tensorboard 可视化你的
AI研习社
2018/03/16
1.5K0
TensorFlow | 自己动手写深度学习模型之全连接神经网络
【深度学习实验】卷积神经网络(八):使用深度残差神经网络ResNet完成图片多分类任务
本实验实现了实现深度残差神经网络ResNet,并基于此完成图像分类任务。
Qomolangma
2024/07/30
5690
【深度学习实验】卷积神经网络(八):使用深度残差神经网络ResNet完成图片多分类任务
Python从0到100(八十四):神经网络-卷积神经网络训练CIFAR-10数据集
CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
是Dream呀
2025/03/05
1630
Python从0到100(八十四):神经网络-卷积神经网络训练CIFAR-10数据集
NeurIPS 2019 | 一种对噪音标注鲁棒的基于信息论的损失函数
噪音标注(noisy label)是机器学习领域的一个热门话题,这是因为标注大规模的数据集往往费时费力,尽管在众包平台上获取数据更加快捷,但是获得的标注往往是有噪音的,直接在这样的数据集上训练会损害模型的性能。许多之前处理噪音标注的工作仅仅对特定的噪音模式(noise pattern)鲁棒,或者需要额外的先验信息,比如需要事先对噪音转移矩阵(noise transition matrix)有较好的估计。我们提出了一种新的损失函数,
机器之心
2019/11/22
4480
推荐阅读
别让数据坑了你!用置信学习找出错误标注(附开源实现)
5.5K0
【综述专栏】如何在标注存在错标的数据上训练模型
1.3K0
领域前沿研究「无所不包」 ,走进标签噪声表征学习的过去、现在和未来
1.2K0
样本混进了噪声怎么办?通过Loss分布把它们揪出来!
2.1K0
CVPR 2022 | 应对噪声标签,西安大略大学、字节跳动等提出对比正则化方法
1.1K0
【干货】使用Pytorch实现卷积神经网络
8.3K0
如何使用TensorFlow实现卷积神经网络
6550
如何训练孪生神经网络
1.6K0
入门 | 完全云端运行:使用谷歌CoLaboratory训练神经网络
1.7K0
迁移学习之快速搭建【卷积神经网络】
2K0
黑客视角:避免神经网络训练失败,需要注意什么?
9100
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
1.3K0
使用自编码器进行图像去噪
1.2K0
用 PyTorch 从零创建 CIFAR-10 的图像分类器神经网络,并将测试准确率达到 85%
10.1K0
机器之心GitHub项目:从零开始用TensorFlow搭建卷积神经网络
1.5K0
深度学习中的卷积神经网络:原理、结构与应用
5750
TensorFlow | 自己动手写深度学习模型之全连接神经网络
1.5K0
【深度学习实验】卷积神经网络(八):使用深度残差神经网络ResNet完成图片多分类任务
5690
Python从0到100(八十四):神经网络-卷积神经网络训练CIFAR-10数据集
1630
NeurIPS 2019 | 一种对噪音标注鲁棒的基于信息论的损失函数
4480
相关推荐
别让数据坑了你!用置信学习找出错误标注(附开源实现)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验