#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 18-8-7 5:47 PM
# @File : fast_roi.py
# @Software: PyCharm
# @Author : wxw
# @Contact : xwwei@lighten.ai
# @Desc : roi tf net
from collections import OrderedDict

import numpy as np
import tensorflow as tf

from utils import config


class RoiNets:
    """Crop a region of interest out of several feature maps.

    The region is centred horizontally on the arg-max of a score map and is
    cut out of every endpoint listed in ``net_name`` except the last one,
    which is passed through unchanged.

    Attributes set by ``__init__``:
        positions: ``[batch, 4]`` float32 boxes produced by ``get_position``.
        box_ind: int32 constant ``[0, 1, ..., config.batch_size - 1]``.
        cut_maps: ``OrderedDict`` produced by ``get_roi_maps``.
        chanels: list of channel counts, one per entry of ``cut_maps``
            (attribute name kept as spelled for backward compatibility).
        page_info: result of ``self.get_page_info()``.
    """

    # Side length used to turn a flat arg-max index into (row, col) and to
    # normalise both to [0, 1).
    # NOTE(review): assumes the score map is 224 x 224 - confirm with caller.
    SCORE_MAP_SIZE = 224

    def __init__(self, scores, endpoints, height, width, net_name,
                 is_training=True):
        """Build the ROI boxes and the cropped feature maps.

        Args:
            scores: 4-D tensor whose axis 3 is split into two equal halves;
                presumably 2-class logits - TODO confirm.
            endpoints: mapping from endpoint name to a 4-D feature-map tensor.
            height: per-endpoint crop heights, indexed like ``net_name``.
            width: per-endpoint crop widths, indexed like ``net_name``.
            net_name: ordered endpoint names; the last one is not cropped.
            is_training: stored and forwarded into ``self.batch_norm``.
        """
        self.is_training = is_training
        self.scores = scores
        self.endpoints = endpoints
        # "witdh" is a typo in the original attribute name; it is kept because
        # code outside this view may read it.
        self.witdh = width
        self.height = height
        self.net_name = net_name
        self.batch_norm = {
            "is_training": is_training,
            "center": True,
            "scale": True,
            "decay": 0.9,
            "epsilon": 0.001,
        }
        self.positions = self.get_position(self.scores)
        # One box per batch element: box i is cut from image i.
        self.box_ind = tf.constant(np.arange(config.batch_size), tf.int32)
        self.cut_maps = self.get_roi_maps()
        # NOTE(review): get_page_info is not defined in the visible part of
        # this file - confirm it exists elsewhere (e.g. the class is truncated
        # here or the method is supplied by a subclass).
        self.page_info = self.get_page_info()

    def get_position(self, scores):
        """Return one normalised crop box per batch element.

        The second half of ``softmax(scores)`` along axis 3 is flattened per
        batch element and its arg-max is converted to a normalised
        ``(row, col)`` position.

        Args:
            scores: 4-D score tensor (see ``__init__``).

        Returns:
            float32 tensor of shape ``[config.batch_size, 4]`` laid out as
            ``[b_h, b_w, e_h, e_w]``.
        """
        size = self.SCORE_MAP_SIZE

        probs = tf.nn.softmax(scores)
        probs = tf.split(probs, num_or_size_splits=2, axis=3)[1]
        flat = tf.reshape(tf.identity(probs), [config.batch_size, -1])
        peak = tf.argmax(flat, axis=1)

        # Cast to float *before* normalising so the division is always a
        # floating-point one, whatever division semantics apply to integer
        # tensors in the running interpreter.
        row = tf.expand_dims(tf.cast(peak // size, tf.float32), 1) / float(size)
        col = tf.expand_dims(tf.cast(peak % size, tf.float32), 1) / float(size)

        # Half-width of 0.25 around the peak, clamped to the unit square.
        b_h = tf.maximum(0.0, row - 0.25)
        b_w = tf.maximum(0.0, col - 0.25)
        e_w = tf.minimum(1.0, col + 0.25)
        # NOTE(review): the box ends at the peak row itself (not row + 0.25),
        # so the crop lies entirely above the peak - confirm this is intended.
        e_h = row
        return tf.concat([b_h, b_w, e_h, e_w], axis=1)

    def get_roi_maps(self):
        """Crop every endpoint but the last to ``self.positions``.

        Endpoint ``i`` (for ``i < len(net_name) - 1``) is cropped and resized
        to ``[self.height[i], self.witdh[i]]``; the final endpoint is stored
        untouched. The channel count of each map is appended to
        ``self.chanels`` and every resulting map is printed.

        Returns:
            ``OrderedDict`` mapping ``"cut_map_<i>"`` to a tensor, in
            ``net_name`` order.
        """
        last = len(self.net_name) - 1
        cut_maps = OrderedDict()
        # "chanels" is a typo in the original attribute name; kept because
        # code outside this view may read it.
        self.chanels = []

        for i in range(last):
            net = self.endpoints[self.net_name[i]]
            self.chanels.append(net.get_shape()[3].value)
            cut_maps["cut_map_%d" % i] = tf.image.crop_and_resize(
                image=net,
                boxes=self.positions,
                box_ind=self.box_ind,
                crop_size=[self.height[i], self.witdh[i]],
            )

        # The last endpoint is used as-is, without cropping.
        net = self.endpoints[self.net_name[last]]
        cut_maps["cut_map_%d" % last] = net
        self.chanels.append(net.get_shape()[3].value)

        for idx, cut in enumerate(cut_maps.values()):
            print('[%d]:' % idx, cut)
        return cut_maps