covid-19-detector/flask_app/detect/predict_sevxray.py
2024-07-29 11:43:52 +08:00

33 lines
1.2 KiB
Python

import tensorflow as tf
import os
import numpy as np
from .processImg import process_image_file
def detectsev_xray(imagepath):
    """Grade the severity of a chest X-ray with the COVIDNet-CXR-S model.

    The TF1 meta graph and checkpoint are loaded from
    ``models/COVIDNet-CXR-S``. This is a relative path, so it is resolved
    against the current working directory. The image is preprocessed by
    ``process_image_file``, divided by 255 and run through the network as a
    batch of one.

    NOTE: the model is rebuilt and restored on every call; cache it if this
    becomes a latency problem.

    Args:
        imagepath: Path of the X-ray image, passed straight to
            ``process_image_file``.

    Returns:
        A tuple ``(pred_type, pred_mild, pred_severe)``:

        - ``pred_type`` is the label of the highest-scoring class.
          '轻微' means mild and '严重' means severe.
        - ``pred_mild`` and ``pred_severe`` are the corresponding values of
          the ``norm_dense_2/Softmax`` output, rounded to 3 decimals.
    """
    weightspath = "models/COVIDNet-CXR-S"
    metaname = "model.meta"
    ckptname = "model"
    in_tensorname = "input_1:0"
    out_tensorname = "norm_dense_2/Softmax:0"
    input_size = 480
    top_percent = 0.08
    # Class label -> output column index, and the reverse lookup.
    mapping = {'轻微': 0, '严重': 1}
    inv_mapping = {0: '轻微', 1: '严重'}

    # Preprocess first so that a bad image fails before the model is loaded.
    x = process_image_file(imagepath, input_size, top_percent=top_percent)
    x = x.astype('float32') / 255.0
    batch = np.expand_dims(x, axis=0)  # add a leading batch axis of size 1

    # Import into a private graph rather than the process-wide default graph.
    # Re-importing into the default graph on each call makes TF uniquify the
    # new node names (input_1_1, ...). get_tensor_by_name would then return
    # the first imported copy, or another model's same-named tensor, whose
    # variables were never restored in this call's session.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(os.path.join(weightspath, metaname))
        image_tensor = graph.get_tensor_by_name(in_tensorname)
        pred_tensor = graph.get_tensor_by_name(out_tensorname)

    # The context manager closes the session, releasing its resources, even
    # if restore/run raises.
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, os.path.join(weightspath, ckptname))
        pred = sess.run(pred_tensor, feed_dict={image_tensor: batch})

    pred_type = inv_mapping[pred.argmax(axis=1)[0]]
    pred_mild = round(pred[0][mapping['轻微']], 3)
    pred_severe = round(pred[0][mapping['严重']], 3)
    return pred_type, pred_mild, pred_severe