Skip to main content

tf2pb tensorflow 转换pb工具,支持pb tf-serving pb 和 fastertransformer pb转换

Project description

tf2pb tensorflow ckpt转换pb工具,支持pb tf-serving pb 和 fastertransformer pb转换

# -*- coding: utf-8 -*-
'''
简介:
        tf2pb tensorflow模型转换pb
        convert_ckpt.py: 将tf transformer系列模型ckpt格式转换pb模型 tf-serving pb fastertransformer pb
        convert_ckpt_dtype.py:  精度转换 , 将tf模型ckpt 32精度转换ckpt 16精度
        convert_keras.py: 将keras h5py模型转换pb
        convert_ckpt.py 转换 fastertransformer pb 可提高1.9x加速.
        建议所有转换得到的pb模型均通过nn-sdk进行推理
        fastertransformer pb 当前只支持linux tensorflow 1.15 cuda11.3 cuda10.0 , 其他pb不依赖。
        推荐 tensorflow 链接如下,建议使用cuda11.3.1 环境tensorflow 1.15
        tensorflow链接: https://pan.baidu.com/s/1PXelYOJ2yqWfWfY7qAL4wA 提取码: rpxv 复制这段内容后打开百度网盘手机App,操作更方便哦
'''

convert_ckpt_dtype.py转换精度

# -*- coding: utf-8 -*-
'''
    convert_ckpt_dtype.py:  ckpt 32精度 转换16精度
'''
import os
import tensorflow as tf
import tf2pb
# Initialise tf2pb with an empty configuration before calling the converter.
tf2pb.ready({})
# Source checkpoint (float32) and destination path for the converted checkpoint.
src_ckpt = r'/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704'
dst_ckpt = r'/root/model_16fp.ckpt'
# Convert the checkpoint from 32-bit to 16-bit precision.
tf2pb.convert_ckpt_dtype(src_ckpt,dst_ckpt)

convert_ckpt.py ckpt转换pb

# -*- coding: utf-8 -*-
'''
    convert_ckpt.py: 将tf bert transformer 等模型ckpt转换pb模型 tf-serving pb和 fastertransformer pb
'''
import os
import sys
import tensorflow as tf
import shutil
import tf2pb

# Configuration passed to tf2pb.ready().
ready_config = {
    "floatx": "float32",  # float16 or float32: precision of the trained model (ckpt_filename). Normally float32; for float16, first convert the checkpoint with convert_ckpt_dtype.py and then export the pb.
    "fastertransformer":{
        "use":  0,# 0: plain model conversion, 1: enable fastertransformer
        "cuda_version": "11.3", # currently supported: 10.2, 11.3
        "remove_padding": False,
        "int8_mode": 0,# requires a GPU with compute capability 7.5 or higher; changing this is not recommended
    }
}
# Settings for exporting a plain (frozen) pb file.
freeze_pb_config = {
    "ckpt_filename": r"/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704",  # trained checkpoint weights (input)
    "save_pb_file": r"/data/finalmodel/2021/bert_ner_2021_09/bert_ner.pb",# path of the pb file to write (output)
}
# Settings for exporting a TF-Serving format pb.
freeze_pb_serving_config = {
    'use':False,# the serving export is disabled by default
    "ckpt_filename": r"/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704",  # trained checkpoint weights (input)
    "save_pb_path_serving": r'/data/finalmodel/2021/bert_ner_2021_09/serving',  # directory where the tf_serving model is saved (output)
    'serve_option': {
        'method_name': 'tensorflow/serving/predict',
        'tags': ['serve'],
    }
}
# Remove outputs left over from a previous run so the export starts clean.
# Only OUTPUT artefacts are deleted here: the frozen pb file and, when the
# serving export is enabled, the serving directory. The training checkpoint
# ('ckpt_filename') is an input and must never be removed.
if freeze_pb_config['save_pb_file'] and os.path.exists(freeze_pb_config['save_pb_file']):
    os.remove(freeze_pb_config['save_pb_file'])

if freeze_pb_serving_config['use'] and freeze_pb_serving_config['save_pb_path_serving'] and os.path.exists(freeze_pb_serving_config['save_pb_path_serving']):
    shutil.rmtree(freeze_pb_serving_config['save_pb_path_serving'])

# Parameters of the trained model to load.
max_seq_len = 340
num_labels = 16 * 4 + 1
bert_dir=r'/data/nlp/pre_models/tf/bert/chinese_L-12_H-768_A-12'
config_file = os.path.join(bert_dir, 'bert_config.json')
# Fail early when the BERT configuration file is missing.
if not os.path.exists(config_file):
    raise Exception("bert_config does not exist")

# Initialise tf2pb; the returned module provides BertConfig and BertModel.
BertModel_module = tf2pb.ready(ready_config)
if BertModel_module is None:
    raise Exception('tf2pb ready failed')
bert_config = BertModel_module.BertConfig.from_json_file(config_file)
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, num_labels, use_one_hot_embeddings):
    """Build a BERT classifier graph and return its class probabilities.

    Instantiates ``BertModel_module.BertModel`` on the given inputs, takes its
    pooled output and applies a linear layer (variables ``output_weights`` and
    ``output_bias``) followed by a softmax over the last axis.

    Args:
        bert_config: configuration object passed to ``BertModel`` as ``config``.
        is_training: forwarded to ``BertModel``.
        input_ids: forwarded to ``BertModel``.
        input_mask: forwarded to ``BertModel``.
        segment_ids: forwarded to ``BertModel`` as ``token_type_ids``.
        num_labels: number of classes; sets the first dimension of the weight
            matrix and the length of the bias vector.
        use_one_hot_embeddings: forwarded to ``BertModel``.

    Returns:
        The softmax of the logits; its last axis has ``num_labels`` entries.
    """
    encoder = BertModel_module.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
    )

    pooled = encoder.get_pooled_output()
    # Static size of the last dimension (TF1-style Dimension object).
    hidden_size = pooled.shape[-1].value

    # Parameters of the linear classification head.
    head_weights = tf.get_variable(
        "output_weights",
        [num_labels, hidden_size],
        dtype="float32",
        initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    head_bias = tf.get_variable(
        "output_bias",
        [num_labels],
        dtype="float32",
        initializer=tf.zeros_initializer(),
    )

    # The weights are stored as [num_labels, hidden_size], hence transpose_b.
    scores = tf.nn.bias_add(
        tf.matmul(pooled, head_weights, transpose_b=True),
        head_bias,
    )
    return tf.nn.softmax(scores, axis=-1)


def save(is_save_serving):
    """Freeze the model and, on success, display the exported result.

    Args:
        is_save_serving: when truthy, export with ``tf2pb.freeze_pb_serving``
            using ``freeze_pb_serving_config``; otherwise export with
            ``tf2pb.freeze_pb`` using ``freeze_pb_config``.

    The return code of the freeze call is printed. When it equals 0 the
    matching ``tf2pb`` show helper is called on the exported artefact.
    """
    def create_network_fn():
        # Builds the inference graph and returns the export configuration.
        ids_ph = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
        mask_ph = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
        # A simple classifier is used here (no segment ids are supplied);
        # adapt this part to your own model.
        probs = create_model(bert_config, False, ids_ph, mask_ph, None, num_labels, False)
        export_config = {
            "input_tensor": {'input_ids': ids_ph, 'input_mask': mask_ph},
            "output_tensor": {"pred_ids": probs},
        }
        if is_save_serving:
            export_config.update(freeze_pb_serving_config)
        else:
            export_config.update(freeze_pb_config)
        return export_config

    # Write this part according to your own model.
    if is_save_serving:
        ret = tf2pb.freeze_pb_serving(create_network_fn)
    else:
        ret = tf2pb.freeze_pb(create_network_fn)
    print(ret)

    # Only inspect the export when freezing reported success.
    if ret != 0:
        return
    if is_save_serving:
        tf2pb.pb_serving_show(
            freeze_pb_serving_config['save_pb_path_serving'],
            freeze_pb_serving_config['serve_option']['tags'],
        )
    else:
        tf2pb.pb_show(freeze_pb_config['save_pb_file'])
# Export the plain pb file.
save(is_save_serving = False)
# Export the TF-Serving format pb, only when enabled in the configuration.
if freeze_pb_serving_config['use']:
    save(is_save_serving = True)

convert_keras.py keras转换pb

# -*- coding: utf-8 -*-
'''
    convert_keras.py: keras h5py 权重 转换pb:
'''
import sys
import tensorflow as tf
import tf2pb
from keras.models import Model
# Initialise tf2pb with an empty configuration.
tf2pb.ready({})

# Template configuration: the None entries are placeholders to be filled in
# with your own model and tensors before running.
config = {
    'model': None,# the model built for training
    'weight_filename' : '/root/weight_filename.weights', # trained weights in h5py format
    'input_tensor' : {
        "input_ids": None, # the corresponding input Tensor, e.g. bert.Input[0]
        "input_mask": None, # the corresponding input Tensor, e.g. bert.Input[1]
    },
    'output_tensor' : {
        "pred_ids": None,
    },
    'save_pb_file': r'/root/save_pb_file.pb' # path of the pb file to write
}
# Convert directly.
tf2pb.freeze_keras_pb(config)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

tf2pb-0.0.5-cp38-cp38-manylinux2014_x86_64.whl (13.9 MB view hashes)

Uploaded CPython 3.8

tf2pb-0.0.5-cp37-cp37m-manylinux2014_x86_64.whl (13.9 MB view hashes)

Uploaded CPython 3.7m

tf2pb-0.0.5-cp36-cp36m-manylinux2014_x86_64.whl (13.9 MB view hashes)

Uploaded CPython 3.6m

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page