将TensorFlow模型转换为pytorch模型代码:
xxx表示模型的路径(当前路径一定要记得是./xxxxxxxx不是/xxxxx)
export BERT_BASE_DIR=xxx
transformers-cli convert --model_type bert \
--tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
--config $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin
下面是对几个文件的理解:
├── bert_config.json <- 模型配置文件
├── bert_model.ckpt.data-00000-of-00001 <- 保存断点文件列表,可以用来迅速查找最近一次的断点文件
├── bert_model.ckpt.index <- 为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。
├── bert_model.ckpt.meta <- 是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等
└── vocab.txt <- 模型词汇表文件
参考:
1、https://huggingface.co/transformers/converting_tensorflow_models.html
2、https://blog.csdn.net/sunyueqinghit/article/details/103458365/