Skip to main content

tf2pb功能: tf transformer系列模型ckpt格式转换pb模型 tf-serving pb fastertransformer pb

Project description

tf2pb功能: tf transformer系列模型ckpt格式转换pb模型 tf-serving pb fastertransformer pb

# -*- coding: utf-8 -*-
'''
简介:
        tf2pb tf transformer模型转换pb
        支持普通pb和fastertransformer pb转换
        convert_ckpt.py: 将tf transformer系列模型ckpt格式转换pb模型 tf-serving pb fastertransformer pb
        convert_ckpt_dtype.py:  精度转换 , 将tf模型ckpt 32精度转换ckpt 16精度
        convert_keras.py: 将keras h5py模型转换pb
        convert_ckpt.py 转换 fastertransformer pb 可提高1.9x - 3.x加速, fastertransformer 目前只支持官方bert transformer系列
        建议pb模型均可以通过nn-sdk推理
        fastertransformer pb 当前只支持linux tensorflow 1.15 cuda11.3 cuda10.2 , 其他pb模型则不依赖。
        推荐 tensorflow 链接如下,建议使用cuda11.3.1 环境tensorflow 1.15
        tensorflow链接: https://pan.baidu.com/s/1PXelYOJ2yqWfWfY7qAL4wA 提取码: rpxv 复制这段内容后打开百度网盘手机App,操作更方便哦
        链接的tf经过测试 , bert 加速3.x
'''

convert_ckpt_dtype.py转换精度

# -*- coding: utf-8 -*-
'''
    convert_ckpt_dtype.py: down-cast a float32 checkpoint to float16
'''
import os
import tensorflow as tf
import tf2pb

# Source checkpoint (float32 weights) and destination checkpoint (float16).
source_ckpt = r'/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704'
target_ckpt = r'/root/model_16fp.ckpt'
# Convert every float32 variable in the checkpoint to float16.
tf2pb.convert_ckpt_dtype(source_ckpt, target_ckpt)

convert_ckpt.py ckpt转换pb

# -*- coding: utf-8 -*-
'''
    convert_ckpt.py: convert a tf bert/transformer model checkpoint to a pb model
    (plain pb, tf-serving pb, or fastertransformer pb)
'''
import os
import tensorflow as tf
import shutil
import tf2pb

# If not using fastertransformer, changing these values is not advised.
ready_config = {
    "floatx": "float32",  # float16 or float32: precision of the trained ckpt_filename. Usually float32; to get float16, first convert the checkpoint with convert_ckpt_dtype.py, then export the pb.
    "fastertransformer": {
        "use": 0,  # 0 = plain model conversion, 1 = enable fastertransformer
        "cuda_version": "11.3",  # currently supported: 10.2, 11.3
        "remove_padding": False,
        "int8_mode": 0,  # requires GPU support; changing this is not advised
    }
}


def load_model_tensor(bert_dir, max_seq_len, num_labels):
    """Build a BERT classification graph and describe its I/O tensors.

    Loads ``bert_config.json`` from *bert_dir*, constructs a non-training
    BERT model plus a softmax classification head with *num_labels* outputs,
    and returns the ``save_config`` dict (with ``input_tensor`` and
    ``output_tensor`` mappings) consumed by tf2pb's freeze helpers.

    Raises
    ------
    Exception
        If ``bert_config.json`` is missing or tf2pb.get_modeling fails.
    """
    config_file = os.path.join(bert_dir, 'bert_config.json')
    if not os.path.exists(config_file):
        raise Exception("bert_config does not exist")

    # tf2pb.get_modeling selects the official BERT module or the
    # fastertransformer-enabled variant according to ready_config;
    # replace with your own modeling module if needed.
    BertModel_module = tf2pb.get_modeling(ready_config)
    if BertModel_module is None:
        raise Exception('tf2pb get_modeling failed')
    bert_config = BertModel_module.BertConfig.from_json_file(config_file)

    def _build_classifier(cfg, training, ids, mask, seg, label_count, one_hot):
        """Attach a dense softmax head on top of the pooled BERT output."""
        bert = BertModel_module.BertModel(
            config=cfg,
            is_training=training,
            input_ids=ids,
            input_mask=mask,
            token_type_ids=seg,
            use_one_hot_embeddings=one_hot)

        pooled = bert.get_pooled_output()
        dim = pooled.shape[-1].value
        weights = tf.get_variable(
            "output_weights", [label_count, dim],
            dtype="float32",
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        bias = tf.get_variable(
            "output_bias", [label_count],
            dtype="float32",
            initializer=tf.zeros_initializer())
        scores = tf.matmul(pooled, weights, transpose_b=True)
        scores = tf.nn.bias_add(scores, bias)
        return tf.nn.softmax(scores, axis=-1)

    input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
    # Simple classification head used here as an example; adapt to your task.
    # token_type_ids is passed as None (single-segment input).
    probabilities = _build_classifier(bert_config, False, input_ids,
                                      input_mask, None, num_labels, False)
    return {
        "input_tensor": {
            'input_ids': input_ids,
            'input_mask': input_mask
        },
        "output_tensor": {
            "pred_ids": probabilities
        },
    }

if __name__ == '__main__':

    # Trained checkpoint weights.
    weight_file = r'/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704'
    output_dir = r'/home/tk/tk_nlp/script/ner/ner_output/bert'

    bert_dir = r'/data/nlp/pre_models/tf/bert/chinese_L-12_H-768_A-12'
    max_seq_len = 340
    num_labels = 16 * 4 + 1

    # Plain (frozen) pb output.
    pb_config = {
        "ckpt_filename": weight_file,  # trained checkpoint weights
        "save_pb_file": os.path.join(output_dir,'bert_inf.pb'),
    }
    # tf-serving SavedModel output.
    pb_serving_config = {
        'use':False,# serving export is disabled by default
        "ckpt_filename": weight_file,  # trained checkpoint weights
        "save_pb_path_serving": os.path.join(output_dir,'serving'),  # tf_serving model output path
        'serve_option': {
            'method_name': 'tensorflow/serving/predict',
            'tags': ['serve'],
        }
    }

    # Remove any previous exports so tf2pb writes fresh files.
    if pb_config['save_pb_file'] and os.path.exists(pb_config['save_pb_file']):
        os.remove(pb_config['save_pb_file'])

    if pb_serving_config['use'] and pb_serving_config['save_pb_path_serving'] and os.path.exists(pb_serving_config['save_pb_path_serving']):
        shutil.rmtree(pb_serving_config['save_pb_path_serving'])


    def convert2pb(is_save_serving):
        # Freeze either a plain pb or a tf-serving SavedModel, depending on flag.
        def create_network_fn():
            # Build the graph, then merge in the chosen export config.
            save_config = load_model_tensor(bert_dir=bert_dir,max_seq_len=max_seq_len,num_labels=num_labels)
            save_config.update(pb_serving_config if is_save_serving else pb_config)
            return save_config

        if not is_save_serving:
            ret = tf2pb.freeze_pb(create_network_fn)
            if ret ==0:
                tf2pb.pb_show(pb_config['save_pb_file'])  # inspect the exported pb
            else:
                print('tf2pb.freeze_pb failed ',ret)
        else:
            ret = tf2pb.freeze_pb_serving(create_network_fn)
            if ret ==0:
                tf2pb.pb_serving_show(pb_serving_config['save_pb_path_serving'],pb_serving_config['serve_option']['tags'])  # inspect the exported SavedModel
            else:
                print('tf2pb.freeze_pb_serving failed ',ret)

    convert2pb(is_save_serving = False)
    if pb_serving_config['use']:
        convert2pb(is_save_serving = True)

convert_keras.py keras转换pb

# -*- coding: utf-8 -*-
'''
    convert_keras.py: convert keras h5py weights to a pb model
'''
import sys
import tensorflow as tf
import tf2pb
import os
from keras.models import Model,load_model
# test pass at tensorflow 1.x


# NOTE(review): `bert_model` and `output_dir` are placeholders — they must be
# defined by your own model-construction code before this point.
weight_file = os.path.join(output_dir, 'best_model.h5')
bert_model.load_weights(weight_file , by_name=False)
# or bert_model = load_model(weight_file)


# Give the output tensor a stable, known name for the frozen graph.
pred_ids = tf.identity(bert_model.output, "pred_ids")

print(bert_model.inputs[0])
print(bert_model.inputs[1])

config = {
    'model': bert_model,# the model you trained
    'input_tensor' : {
        "Input-Token": bert_model.inputs[0], # input Tensor, e.g. bert.inputs[0]
        "Input-Segment": bert_model.inputs[1], # input Tensor, e.g. bert.inputs[1]
    },
    'output_tensor' : {
        "pred_ids": pred_ids, # output Tensor
    },
    'save_pb_file': r'/root/save_pb_file.pb', # pb filename
}

# Remove any stale export, then convert directly.
if os.path.exists(config['save_pb_file']):
    os.remove(config['save_pb_file'])
tf2pb.freeze_keras_pb(config)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distribution

tf2pb-0.1.10-py3-none-any.whl (13.8 MB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page