将TensorFlow模型转换为pytorch模型代码:
xxx表示模型的路径(当前路径一定要记得是./xxxxxxxx不是/xxxxx)

export BERT_BASE_DIR=xxx

transformers-cli convert --model_type bert \
  --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
  --config $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin

下面是对几个文件的理解:

├── bert_config.json                     <- 模型配置文件
├── bert_model.ckpt.data-00000-of-00001  <- 保存断点文件列表,可以用来迅速查找最近一次的断点文件
├── bert_model.ckpt.index                <- 为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。
├── bert_model.ckpt.meta                 <- 是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等
└── vocab.txt                            <- 模型词汇表文件

参考:
1、https://huggingface.co/transformers/converting_tensorflow_models.html
2、https://blog.csdn.net/sunyueqinghit/article/details/103458365/

Last modification:March 22nd, 2021 at 08:50 pm
如果觉得我的文章对你有用,请随意赞赏