博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Roi Net
阅读量:6949 次
发布时间:2019-06-27

本文共 2609 字,大约阅读时间需要 8 分钟。

#!/usr/bin/env python# -*- coding: utf-8 -*-# @Time    : 18-8-7 下午5:47# @File    : fast_roi.py# @Software: PyCharm# @Author  : wxw# @Contact : xwwei@lighten.ai# @Desc    : roi tf netimport tensorflow as tffrom collections import OrderedDictfrom utils import configimport numpy as npclass RoiNets:    def __init__(self, scores, endpoints, height, width, net_name, is_training=True):        self.is_training = is_training        self.scores = scores        self.endpoints = endpoints        self.witdh = width        self.height = height        self.net_name = net_name        self.batch_norm = {            "is_training": is_training,            "center": True,            "scale": True,            "decay": 0.9,            "epsilon": 0.001,        }        self.positions = self.get_position(self.scores)        self.box_ind = tf.constant(np.arange(config.batch_size), tf.int32)        self.cut_maps = self.get_roi_maps()        self.page_info = self.get_page_info()    def get_position(self, scores):        scores = tf.nn.softmax(scores)        scores = tf.split(scores, num_or_size_splits=2, axis=3)[1]        map = tf.identity(scores)        map = tf.reshape(map, [config.batch_size, -1])        max_idx = tf.argmax(map, axis=1)        heigth = tf.cast(tf.expand_dims(max_idx // 224, 1) / 224, tf.float32)        width = tf.cast(tf.expand_dims(max_idx % 224, 1) / 224, tf.float32)        b_h = tf.maximum(0.0, (heigth - 0.25))        b_w = tf.maximum(0.0, (width - 0.25))        e_w = tf.minimum(1.0, (width + 0.25))        e_h = heigth        return tf.concat([b_h, b_w, e_h, e_w], axis=1)    def get_roi_maps(self):        number = len(self.net_name) - 1        cut_maps = OrderedDict()        self.chanels = []        for i in range(number):            cut_name = "cut_map_%d" % i            net = self.endpoints[self.net_name[i]]            self.chanels.append(net.get_shape()[3].value)            cut_maps[cut_name] = tf.image.crop_and_resize(image=net,                                                          boxes=self.positions,                                                          box_ind=self.box_ind,                                                          crop_size=[self.height[i],                                                                     self.witdh[i]])        net = self.endpoints[self.net_name[number]]        cut_maps["cut_map_%d" % number] = net        self.chanels.append(net.get_shape()[3].value)        for idx, amap in enumerate(cut_maps):            print('[%d]:' % idx, cut_maps[amap])        return cut_maps复制代码

转载于:https://juejin.im/post/5b6c10315188251b39510b9c

你可能感兴趣的文章
反编译.o到.cpp
查看>>
[LeetCode]Remove Duplicates from Sorted Array
查看>>
qtp试用期30天已经过了就无法使用,解决办法
查看>>
困惑好久 删除配置文件中的一行 怎么办?
查看>>
winform文本框怎么实现html的placeholder效果
查看>>
认识CSS样式
查看>>
excel表格数据信息传递老出错,还有没有更好用数据处理工具?
查看>>
[转]SQLITE3 C语言接口 API 函数简介
查看>>
Delphi XE5中使用jar包
查看>>
org.apache.felix.framework-5.6.12源码解析——org.apache.felix.framework文件夹最后的部分...
查看>>
Python3的tcp socket接收不定长数据包接收到的数据不全。
查看>>
b2b
查看>>
第三周Java学习总结
查看>>
OGRE的安装和编译【转+改】
查看>>
获取管理员组用户
查看>>
Mysql—(2)—
查看>>
简历的分布式
查看>>
[转]string和stringstream用法总结
查看>>
LeetCode:Rotate Array
查看>>
jquery pagination.js 分页
查看>>