koth/kcws

Hi, how do I use the model after training it?

wangwisdom opened this issue · 9 comments

Hi, how do I use the model after training it? That is, how do I feed in a sentence to test it?

The project is built with bazel, which is rather cumbersome. Modify the file seg_backend_api.cc as follows:
/*
 * Copyright 2016-2018 Koth. All Rights Reserved.
 * =====================================================================================
 * Filename:     seg_backend_api.cc
 * Author:       Koth
 * Create Time:  2016-11-20 20:43:26
 * Description:
 */
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "base/base.h"
#include "jsonxx.h"
#include "basic_string_util.h"
#include "tf_seg_model.h"
#include "pos_tagger.h"
#include "third_party/crow/include/crow.h"
#include "tensorflow/core/platform/init_main.h"

DEFINE_int32(port, 9090, "the api serving binding port");
DEFINE_string(model_path, "tensor_decode/models/seg_model.pbtxt", "the model path");
DEFINE_string(vocab_path, "tensor_decode/models/basic_vocab.txt", "char vocab path");
DEFINE_string(pos_model_path, "tensor_decode/models/pos_model.pbtxt", "the pos tagging model path");
DEFINE_string(word_vocab_path, "tensor_decode/models/word_vocab.txt", "word vocab path");
DEFINE_string(pos_vocab_path, "tensor_decode/models/pos_vocab.txt", "pos vocab path");
DEFINE_int32(max_sentence_len, 80, "max sentence len ");
DEFINE_string(user_dict_path, "", "user dict path");
DEFINE_int32(max_word_num, 50, "max num of word per sentence ");

class SegMiddleware {
 public:
  struct context {};
  SegMiddleware() {}
  ~SegMiddleware() {}
  void before_handle(crow::request& req, crow::response& res, context& ctx) {}
  void after_handle(crow::request& req, crow::response& res, context& ctx) {}
 private:
};

int main(int argc, char* argv[]) {
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  google::ParseCommandLineFlags(&argc, &argv, true);
  crow::App<SegMiddleware> app;
  kcws::TfSegModel model;
  CHECK(model.LoadModel(FLAGS_model_path,
                        FLAGS_vocab_path,
                        FLAGS_max_sentence_len,
                        FLAGS_user_dict_path))
      << "Load model error";

  if (!FLAGS_pos_model_path.empty()) {
    kcws::PosTagger* tagger = new kcws::PosTagger;
    CHECK(tagger->LoadModel(FLAGS_pos_model_path,
                            FLAGS_word_vocab_path,
                            FLAGS_vocab_path,
                            FLAGS_pos_vocab_path,
                            FLAGS_max_word_num)) << "load pos model error";
    model.SetPosTagger(tagger);
  }

  // Read test sentences line by line from a local corpus file.
  std::ifstream in("/home/work/test/nlu-ner/train_and_dict/model/origin_corpus/origin_corpus");
  std::string sentence = "";
  while (std::getline(in, sentence)) {
    // Alternatively, read queries interactively from stdin:
    // while (1) {
    //   std::cout << "please input query:";
    //   std::cin >> sentence;
    //   if (sentence == "q") {
    //     return 0;
    //   }

    std::vector<std::string> result;
    std::vector<std::string> tags;
    std::string desc = "";
    // std::cout << "input sentence is:" << sentence << std::endl;
    if (model.Segment(sentence, &result, &tags)) {
      int status = 0;
      // std::cout << "result size:" << result.size() << std::endl;
      // std::cout << "tags size:" << tags.size() << std::endl;
      if (result.size() == tags.size()) {
        // Print each word together with its POS tag.
        int nl = result.size();
        for (int i = 0; i < nl; i++) {
          std::cout << result[i] << "/" << tags[i] << " ";
        }
        std::cout << std::endl;
      } else {
        // No POS tags available; print the segmented words only.
        for (std::string str : result) {
          std::cout << str << " ";
        }
        std::cout << std::endl;
      }
      // std::cout << "segments" << std::endl;
    } else {
      desc = "Parse request error";
    }

    // std::cout << "status" << std::endl;
    // std::cout << "msg" << desc << std::endl;
  }

  return 0;
}
Then write a compile command yourself to build it; after that you can run it directly. For example:
#!/bin/bash

set -e -x

g++ -std=c++11 -o seg_backend_api ./kcws/cc/seg_backend_api.cc \
    ./kcws/cc/pos_tagger.cc ./kcws/cc/sentence_breaker.cc ./kcws/cc/tf_seg_model.cc ./kcws/cc/viterbi_decode.cc \
    ./utils/basic_vocab.cc ./utils/jsonxx.cc ./utils/py_word2vec_vob.cc ./utils/word2vec_vob.cc \
    ./tfmodel/tfmodel.cc \
    -g -Wall -D_DEBUG -Wshadow -Wno-sign-compare -w -Xlinker -export-dynamic \
    -I../tensorflow/ \
    -I./kcws/cc/ \
    -I./utils/ \
    -I./tfmodel/ \
    -I./third_party/gflags/include/ \
    -I./third_party/glog/include/ \
    -I/home/soft/boost/include/ \
    -I/usr/include/python2.7/ \
    -I../tensorflow/tensorflow/contrib/makefile/gen/proto \
    -I../tensorflow/tensorflow/contrib/makefile/downloads/eigen \
    -I../tensorflow/tensorflow/contrib/makefile/gen/protobuf/include \
    -I../tensorflow/tensorflow/contrib/makefile/downloads/nsync/public/ \
    -L../tensorflow/bazel-bin/tensorflow -ltensorflow_cc \
    -L../tensorflow/bazel-bin/tensorflow -ltensorflow_framework \
    -L./third_party/gflags/lib -lgflags \
    -L./third_party/gflags/lib -lgflags_nothreads \
    -L./third_party/glog/lib -lglog \
    -L/home/soft/boost/lib -lboost_system \
    -L/usr/lib64 -lpython2.7 \
    -lm \
    -ldl \
    -lpthread
Note: change the include and library paths above to match your own code layout.
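
For reference, here is a minimal sketch of how the resulting binary can be invoked. The flag names come from the DEFINE_* declarations in seg_backend_api.cc; the file paths below are just the defaults from those declarations and are placeholders for wherever your exported models actually live:

./seg_backend_api \
    --model_path=tensor_decode/models/seg_model.pbtxt \
    --vocab_path=tensor_decode/models/basic_vocab.txt \
    --pos_model_path=tensor_decode/models/pos_model.pbtxt \
    --word_vocab_path=tensor_decode/models/word_vocab.txt \
    --pos_vocab_path=tensor_decode/models/pos_vocab.txt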

@forever1dream Could you explain the approach above in more detail? Why not just build with bazel? And I still don't understand how to actually use it.

@AlleyEli I find building with bazel fairly cumbersome, and it's also awkward to integrate into my own project. The approach above is mainly about working out the dependencies needed for decoding: the -I flags are the header include paths, and the -L flags point to the required .a or .so libraries. Based on these you can write a Makefile or whatever build you prefer. The code changes mainly strip the web service out of the original source and replace it with local stdin testing or file-based testing.
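
If the binary fails to find the shared libraries at runtime, a couple of standard commands help to check this (illustrative, not specific to this project; the path mirrors the -L entries in the build script above):

# Show which .so dependencies of the binary are resolved and which are "not found".
ldd ./seg_backend_api

# If libtensorflow_cc.so / libtensorflow_framework.so are not found, point the loader
# at the bazel output directory before running.
export LD_LIBRARY_PATH=../tensorflow/bazel-bin/tensorflow:$LD_LIBRARY_PATH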

@forever1dream Thanks a lot!

@forever1dream Could you share the environment you used for training and inference?
Mine is Ubuntu 16.04 + Python 2.7 + bazel 0.45 + TF 1.7.0.
I'd like to compare.

@forever1dream thanks
I wrapped up my own Python interface for training and inference (https://github.com/AlleyEli/kcws), and the tests pass.
I mainly want to compare environments; version-compatibility issues troubled me for a long time, so I'd like to see your runtime setup.

@AlleyEli I've also put the whole build process on GitHub (https://github.com/forever1dream/cplus-kcws). Change the TensorFlow install path to your own (placing this project and the TensorFlow install directory side by side at the same level is enough) and the Boost path, and it will run. Thanks for your Python training approach, I'll take a look. Thanks again.
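
For reference, the directory layout implied by the relative -I/-L paths (../tensorflow/...) in the build script above would look roughly like this (names are illustrative, not prescribed by the project):

# <workspace>/
#   tensorflow/      # TensorFlow source tree, built so that bazel-bin/tensorflow contains libtensorflow_cc.so
#   cplus-kcws/      # this project; run the build script from inside this directory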