iis服务器助手广告
返回顶部
首页 > 资讯 > 后端开发 > 其他教程 >C++ TensorflowLite模型验证的过程详解
  • 139
分享到

C++ TensorflowLite模型验证的过程详解

2024-04-02 19:04:59 139人浏览 安东尼
摘要

故事是这样的: 有一个手掌检测的tflite模型,需要在开发板上跑起来。手机版本的已成熟,要移植到开发板上。现在要验证tflite模型文件在板子上的运行结果和手机上一致。 前提:为

故事是这样的:

有一个手掌检测的tflite模型,需要在开发板上跑起来。手机版本的已成熟,要移植到开发板上。现在要验证tflite模型文件在板子上的运行结果和手机上一致。

前提:为了多次重复测试,在Android端使用了同一帧数据(从一个录制的mp4中固定取一张图)测试代码如下图

下面是测试过程 

记录下Android版api运行推理前的图片数据文件(经过了归一化处理,所以都是-1~1之间的float数据)

这一步卡在了写float数据到二进制文件中,c++读出来有问题

换了个方案,直接存储float字符串


/**
 * Dumps the (already normalised) image data to the text file
 * {@code <public Downloads dir>/tfimg}, one value per line, so the exact same
 * input can be replayed on another device.
 *
 * Each value is written with 4 decimal places followed by CRLF. The number is
 * formatted with Locale.US so the decimal separator is always '.', regardless
 * of the phone's language settings; a ',' separator would be misread by the
 * C-style parser on the reading side.
 *
 * Any failure is logged and swallowed (best effort), as before.
 *
 * @param pfImageData float values to store, in tensor order
 */
private void saveFile(float[] pfImageData) {
        try {
            File file = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS).getAbsolutePath() + "/tfimg");

            StringBuilder sb = new StringBuilder();
            for (float val : pfImageData) {
                // Keep 4 decimal places; change the precision here if needed.
                sb.append(String.format(java.util.Locale.US, "%.4f", val));
                sb.append("\r\n");
            }

            // try-with-resources closes the writer even when write() throws.
            try (FileWriter out = new FileWriter(file)) {
                out.write(sb.toString());
            }
        } catch (Exception e) {
            e.printStackTrace();
            Log.e("Melon", "存储文件异常," + e.getMessage());
        }
    }

拿着这个文件在板子上输入到Tflite模型中

测试代码,主要是RunInference()和read_file()



 
#include "tensorflow/lite/examples/label_image/label_image.h"
 
#include <fcntl.h>     // NOLINT(build/include_order)
#include <getopt.h>    // NOLINT(build/include_order)
#include <sys/time.h>  // NOLINT(build/include_order)
#include <sys/types.h> // NOLINT(build/include_order)
#include <sys/uio.h>   // NOLINT(build/include_order)
#include <unistd.h>    // NOLINT(build/include_order)
 
#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
 
#include "absl/memory/memory.h"
#include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/examples/label_image/log.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
 
namespace tflite
{
  namespace label_image
  {
 
    // Converts a timeval into a microsecond count.
    // The arithmetic is done in double: multiplying tv_sec by 1000000 in its
    // native integer type overflows on targets where that type is 32 bits wide.
    double get_us(struct timeval t)
    {
      return static_cast<double>(t.tv_sec) * 1000000.0 +
             static_cast<double>(t.tv_usec);
    }
 
    // Local alias for the delegate pointer type declared inside tflite::Interpreter.
    using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
    // Local alias for the delegate-list helper from tflite::tools; used by
    // DelegateProviders below.
    using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
 
    class DelegateProviders
    {
    public:
      DelegateProviders() : delegate_list_util_(&params_)
      {
        delegate_list_util_.AddAllDelegateParams();
      }
 
      // Initialize delegate-related parameters from parsing command line arguments,
      // and remove the matching arguments from (*arGC, argv). Returns true if all
      // recognized arg values are parsed correctly.
      bool InitFromCmdlineArgs(int *argc, const char **argv)
      {
        std::vector<tflite::Flag> flags;
        // delegate_list_util_.AppendCmdlineFlags(&flags);
 
        const bool parse_result = Flags::Parse(argc, argv, flags);
        if (!parse_result)
        {
          std::string usage = Flags::Usage(argv[0], flags);
          LOG(ERROR) << usage;
        }
        return parse_result;
      }
 
      // According to passed-in settings `s`, this function sets corresponding
      // parameters that are defined by various delegate execution providers. See
      // lite/tools/delegates/README.md for the full list of parameters defined.
      void MergeSettingsIntoParams(const Settings &s)
      {
        // Parse settings related to GPU delegate.
        // Note that GPU delegate does support OpenCL. 'gl_backend' was introduced
        // when the GPU delegate only supports OpenGL. Therefore, we consider
        // setting 'gl_backend' to true means using the GPU delegate.
        if (s.gl_backend)
        {
          if (!params_.HasParam("use_gpu"))
          {
            LOG(WARN) << "GPU deleate execution provider isn't linked or GPU "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_gpu", true);
            // The parameter "gpu_inference_for_sustained_speed" isn't available for
            // iOS devices.
            if (params_.HasParam("gpu_inference_for_sustained_speed"))
            {
              params_.Set<bool>("gpu_inference_for_sustained_speed", true);
            }
            params_.Set<bool>("gpu_precision_loss_allowed", s.allow_fp16);
          }
        }
 
        // Parse settings related to NNAPI delegate.
        if (s.accel)
        {
          if (!params_.HasParam("use_nnapi"))
          {
            LOG(WARN) << "NNAPI deleate execution provider isn't linked or NNAPI "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_nnapi", true);
            params_.Set<bool>("nnapi_allow_fp16", s.allow_fp16);
          }
        }
 
        // Parse settings related to HexaGon delegate.
        if (s.hexagon_delegate)
        {
          if (!params_.HasParam("use_hexagon"))
          {
            LOG(WARN) << "Hexagon deleate execution provider isn't linked or "
                         "Hexagon delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_hexagon", true);
            params_.Set<bool>("hexagon_profiling", s.profiling);
          }
        }
 
        // Parse settings related to XNNPACK delegate.
        if (s.xnnpack_delegate)
        {
          if (!params_.HasParam("use_xnnpack"))
          {
            LOG(WARN) << "XNNPACK deleate execution provider isn't linked or "
                         "XNNPACK delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_xnnpack", true);
            params_.Set<bool>("num_threads", s.number_of_threads);
          }
        }
      }
 
      // Create a list of TfLite delegates based on what have been initialized (i.e.
      // 'params_').
      std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates()
          const
      {
        return delegate_list_util_.CreateAllRankedDelegates();
      }
 
    private:
      // Contain delegate-related parameters that are initialized from command-line
      // flags.
      tflite::tools::ToolParams params_;
 
      // A helper to create TfLite delegates.
      ProvidedDelegateList delegate_list_util_;
    };
 
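    // NOTE(review): the description below appears to be left over from a
    // label-loading helper that is no longer in this file; it does not describe
    // the read_file() variants that follow.
    // Takes a file name, and loads a list of labels from it, one per line, and
    // returns a vector of the strings. It pads with empty strings so the length
    // of the result is a multiple of 16, because our model expects that.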
 
    // std::vector<uint8_t> read_file(const std::string &input_bmp_name)
    // {
    //   int begin, end;
 
    //   std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
    //   if (!file)
    //   {
    //     LOG(FATAL) << "input file " << input_bmp_name << " not found";
    //     exit(-1);
    //   }
 
    //   begin = file.tellg();
    //   file.seekg(0, std::ios::end);
    //   end = file.tellg();
    //   size_t len = end - begin;
 
    //   LOG(INFO) << "len: " << len;
    //   std::vector<uint8_t> img_bytes(len);
 
    //   file.seekg(0, std::ios::beg);
    //   file.read(reinterpret_cast<char *>(img_bytes.data()), len);
 
    //   return img_bytes;
    // }
 
    
    // Reads a text file that holds one floating-point number per line and
    // returns the numbers in file order.
    //
    // input_bmp_name: path of the text file (the name is historical; the file is
    //                 not a BMP).
    //
    // Behaviour:
    //  - If the file cannot be opened, logs FATAL and exits with -1.
    //  - Lines may end in "\n" or "\r\n"; trailing characters after the number
    //    are ignored.
    //  - Blank lines are skipped instead of being stored as 0.0.
    //  - A non-blank line that does not start with a number logs FATAL and exits
    //    with -1; silently substituting 0.0 would corrupt the model input
    //    without any visible symptom.
    std::vector<float> read_file(const std::string &input_bmp_name)
    {
      std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
      if (!file)
      {
        LOG(FATAL) << "input file " << input_bmp_name << " not found";
        exit(-1);
      }

      // Measure the file size for the log line only. Stream offsets are kept in
      // the stream's own offset type rather than being truncated to int.
      const std::streampos begin = file.tellg();
      file.seekg(0, std::ios::end);
      const std::streamoff len = file.tellg() - begin;
      file.seekg(0, std::ios::beg);

      LOG(INFO) << "len: " << len;
      std::vector<float> img_bytes;

      std::string strLine;
      while (std::getline(file, strLine))
      {
        const char *text = strLine.c_str();
        char *parse_end = nullptr;
        // strtod + narrowing yields the same value the previous atof() call did,
        // but also reports whether anything was parsed.
        const float value = static_cast<float>(std::strtod(text, &parse_end));
        if (parse_end == text)
        {
          if (strLine.find_first_not_of(" \t\r\n") == std::string::npos)
          {
            continue; // blank line
          }
          LOG(FATAL) << "input file " << input_bmp_name
                     << " contains a non-numeric line: " << strLine;
          exit(-1);
        }
        img_bytes.push_back(value);
      }

      LOG(INFO) << "文件读取完成:" << input_bmp_name;
      return img_bytes;
    }
 
    
    // Runs the model-verification flow end to end:
    //   1. loads the .tflite file named by settings->model_name and builds an
    //      interpreter with the builtin op resolver;
    //   2. reads the float values stored in settings->input_bmp_name through
    //      read_file() and copies them into input tensor 0;
    //   3. invokes the interpreter settings->loop_count times and logs the
    //      average wall-clock time per run;
    //   4. prints every float of output tensor 1 to stdout, one row per line,
    //      where a row is the tensor's last dimension.
    //
    // Requirements checked here: the model has at least one input and two
    // outputs, input 0 and output 1 are float tensors, and the file does not
    // hold more values than input 0 can take. Any failure is logged and ends
    // the process with exit(-1).
    //
    // Side effects on *settings: `model` and `input_type` are overwritten.
    void RunInference(Settings *settings)
    {
      // std::string::c_str() never returns null, so a null test can never fire;
      // test for an empty name instead.
      if (settings->model_name.empty())
      {
        LOG(ERROR) << "no model file name";
        exit(-1);
      }

      std::unique_ptr<tflite::FlatBufferModel> model;
      std::unique_ptr<tflite::Interpreter> interpreter;
      model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
      if (!model)
      {
        LOG(ERROR) << "Failed to mmap model " << settings->model_name;
        exit(-1);
      }
      // NOTE: `model` is owned by this function, so the pointer stored here is
      // only valid until RunInference returns.
      settings->model = model.get();
      LOG(INFO) << "Loaded model " << settings->model_name;
      model->error_reporter();
      LOG(INFO) << "resolved reporter";

      tflite::ops::builtin::BuiltinOpResolver resolver;

      // Build the interpreter for the loaded model.
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);
      if (!interpreter)
      {
        LOG(ERROR) << "Failed to construct interpreter";
        exit(-1);
      }

      // Everything below indexes inputs()[0] and outputs()[1]; make sure they
      // exist before touching them.
      const std::vector<int> inputs = interpreter->inputs();
      const std::vector<int> outputs = interpreter->outputs();
      if (inputs.empty() || outputs.size() < 2)
      {
        LOG(ERROR) << "model needs at least 1 input and 2 outputs, got "
                   << inputs.size() << " input(s) and " << outputs.size()
                   << " output(s)";
        exit(-1);
      }

      interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);

      if (settings->verbose)
      {
        LOG(INFO) << "tensors size: " << interpreter->tensors_size();
        LOG(INFO) << "nodes size: " << interpreter->nodes_size();
        LOG(INFO) << "inputs: " << interpreter->inputs().size();
        LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);

        int t_size = interpreter->tensors_size();
        for (int i = 0; i < t_size; i++)
        {
          if (interpreter->tensor(i)->name)
            LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                      << interpreter->tensor(i)->bytes << ", "
                      << interpreter->tensor(i)->type << ", "
                      << interpreter->tensor(i)->params.scale << ", "
                      << interpreter->tensor(i)->params.zero_point;
        }
      }

      if (settings->number_of_threads != -1)
      {
        interpreter->SetNumThreads(settings->number_of_threads);
      }

      // The input file already contains the normalised float values, so no BMP
      // decoding or resizing happens here.
      std::vector<float> file_bytes = read_file(settings->input_bmp_name);

      // Log the leading values (at most 100) so they can be compared with the
      // input logged on Android.
      const size_t preview_count =
          file_bytes.size() < 100 ? file_bytes.size() : 100;
      for (size_t i = 0; i < preview_count; i++)
      {
        LOG(INFO) << i << ": " << file_bytes[i];
      }

      const int input = inputs[0];
      LOG(INFO) << "input: " << input;

      LOG(INFO) << "number of inputs: " << inputs.size();
      LOG(INFO) << "input index: " << inputs[0];
      LOG(INFO) << "number of outputs: " << outputs.size();
      LOG(INFO) << "outputs index1: " << outputs[0] << ",outputs index2: " << outputs[1];

      // Allocate memory for all tensors.
      if (interpreter->AllocateTensors() != kTfLiteOk)
      {
        LOG(ERROR) << "Failed to allocate tensors!";
        exit(-1);
      }

      if (settings->verbose)
        PrintInterpreterState(interpreter.get());

      TfLiteTensor *input_tensor = interpreter->tensor(input);
      settings->input_type = input_tensor->type;

      // Copy the file contents into input tensor 0.
      // A null pointer from typed_input_tensor<float> is treated as "input 0
      // is not a float tensor" (see the TFLite Interpreter API reference).
      float *inputP = interpreter->typed_input_tensor<float>(0);
      if (inputP == nullptr)
      {
        LOG(ERROR) << "input tensor " << input << " is not a float tensor";
        exit(-1);
      }

      const size_t input_capacity = input_tensor->bytes / sizeof(float);
      LOG(INFO) << "file_bytes size: " << file_bytes.size();
      if (file_bytes.size() > input_capacity)
      {
        // Writing past the tensor buffer would corrupt memory.
        LOG(ERROR) << "input file holds " << file_bytes.size()
                   << " values but the input tensor only takes "
                   << input_capacity;
        exit(-1);
      }
      if (file_bytes.size() < input_capacity)
      {
        LOG(WARN) << "input file holds " << file_bytes.size()
                  << " values but the input tensor expects " << input_capacity
                  << "; the remaining elements are left untouched";
      }
      for (size_t i = 0; i < file_bytes.size(); i++)
      {
        inputP[i] = file_bytes[i];
      }

      struct timeval start_time, stop_time;
      gettimeofday(&start_time, nullptr);
      for (int i = 0; i < settings->loop_count; i++)
      { // Run the model.
        if (interpreter->Invoke() != kTfLiteOk)
        {
          LOG(ERROR) << "Failed to invoke tflite!";
          exit(-1);
        }
      }
      gettimeofday(&stop_time, nullptr);
      LOG(INFO) << "invoked";
      LOG(INFO) << "average time: "
                << (get_us(stop_time) - get_us(start_time)) /
                       (settings->loop_count * 1000)
                << " ms";

      const int output = outputs[1];
      LOG(INFO) << "output: " << output;
      LOG(INFO) << "interpreter->tensors_size: " << interpreter->tensors_size();

      TfLiteTensor *tensor = interpreter->tensor(output);

      TfLiteIntArray *output_dims = tensor->dims;
      if (output_dims == nullptr || output_dims->size < 1)
      {
        LOG(ERROR) << "output tensor " << output << " has no dimensions";
        exit(-1);
      }
      // assume output dims to be something like (1, 1, ... ,size)
      auto output_size = output_dims->data[output_dims->size - 1];
      LOG(INFO) << "索引为" << output << "的输出张量的-"
                << "output_size: " << output_size;

      for (int i = 0; i < output_dims->size; i++)
      {
        LOG(INFO) << "元数据有:" << output_dims->data[i];
      }

      float *prediction = interpreter->typed_output_tensor<float>(1);
      if (prediction == nullptr)
      {
        LOG(ERROR) << "output tensor " << output << " is not a float tensor";
        exit(-1);
      }

      // Print the whole output tensor, one row (= last dimension) per line.
      // The element count comes from the tensor itself rather than a fixed
      // 896, so a tensor of any size is printed completely and never read past
      // its end. For a (1, 896, 1) tensor the text is identical to what the
      // fixed-size version printed.
      const size_t element_count = tensor->bytes / sizeof(float);
      const size_t row_width =
          output_size > 0 ? static_cast<size_t>(output_size) : 1;
      for (size_t i = 0; i < element_count; i++)
      {
        std::cout << prediction[i] << ' ';
        if ((i + 1) % row_width == 0)
        {
          std::cout << '\n';
        }
      }
      std::cout << std::endl;
    }
 
    // Entry point of the tool inside the tflite::label_image namespace.
    //
    // Lets DelegateProviders consume its arguments first, then parses the
    // remaining options with getopt_long into a Settings object, merges the
    // settings into the delegate parameters and runs the inference.
    //
    // Returns EXIT_FAILURE when the delegate argument parsing fails and 0 after
    // a completed run. An unknown option, or an option without a handler
    // below, ends the process with exit(-1).
    int Main(int argc, char **argv)
    {
      DelegateProviders delegate_providers;
      bool parse_result = delegate_providers.InitFromCmdlineArgs(
          &argc, const_cast<const char **>(argv));
      if (!parse_result)
      {
        return EXIT_FAILURE;
      }

      Settings s;

      int c;
      while (true)
      {
        // Every long option takes an argument and maps onto a short option.
        static struct option long_options[] = {
            {"accelerated", required_argument, nullptr, 'a'},
            {"allow_fp16", required_argument, nullptr, 'f'},
            {"count", required_argument, nullptr, 'c'},
            {"verbose", required_argument, nullptr, 'v'},
            {"image", required_argument, nullptr, 'i'},
            {"labels", required_argument, nullptr, 'l'},
            {"tflite_model", required_argument, nullptr, 'm'},
            {"profiling", required_argument, nullptr, 'p'},
            {"threads", required_argument, nullptr, 't'},
            {"input_mean", required_argument, nullptr, 'b'},
            {"input_std", required_argument, nullptr, 's'},
            {"num_results", required_argument, nullptr, 'r'},
            {"max_profiling_buffer_entries", required_argument, nullptr, 'e'},
            {"warmup_runs", required_argument, nullptr, 'w'},
            {"gl_backend", required_argument, nullptr, 'g'},
            {"hexagon_delegate", required_argument, nullptr, 'j'},
            {"xnnpack_delegate", required_argument, nullptr, 'x'},
            {nullptr, 0, nullptr, 0}};

        int option_index = 0;

        c = getopt_long(argc, argv,
                        "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options,
                        &option_index);

        // -1 means every option has been consumed.
        if (c == -1)
          break;

        switch (c)
        {
        case 'a':
          s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'b':
          s.input_mean = strtod(optarg, nullptr);
          break;
        case 'c':
          s.loop_count =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'e':
          s.max_profiling_buffer_entries =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'f':
          s.allow_fp16 =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'g':
          s.gl_backend =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'i':
          s.input_bmp_name = optarg;
          break;
        case 'j':
          // Parse the value like the other on/off options. Assigning `optarg`
          // itself converts a non-null pointer to "true", which turned the
          // delegate on even for "-j 0".
          s.hexagon_delegate =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'l':
          s.labels_file_name = optarg;
          break;
        case 'm':
          s.model_name = optarg;
          break;
        case 'p':
          s.profiling =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'r':
          s.number_of_results =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 's':
          s.input_std = strtod(optarg, nullptr);
          break;
        case 't':
          s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
              optarg, nullptr, 10);
          break;
        case 'v':
          s.verbose =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'w':
          s.number_of_warmup_runs =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'x':
          s.xnnpack_delegate =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'h':
        case '?':
          // Help request or unrecognised option: no usage text is printed here.
          exit(-1);
        default:
          // Reached for options accepted by the option string but not handled
          // above (e.g. 'd').
          exit(-1);
        }
      }

      delegate_providers.MergeSettingsIntoParams(s);
      RunInference(&s);
      return 0;
    }
 
  } // namespace label_image
} // namespace tflite
 
// Process entry point: hands the command line to tflite::label_image::Main
// unchanged and returns its result as the process exit status.
int main(int argc, char **argv)
{
  const int exit_status = tflite::label_image::Main(argc, argv);
  return exit_status;
}

运行指令: ./ws_app --tflite_model libnewpalm_detection.tflite --image tfimg

对比推理前的输入一致

Android端

开发板上

对比推理后的输出一致 Android端

开发板端

到此这篇关于C++ TensorflowLite模型验证的文章就介绍到这了,更多相关C++ TensorflowLite模型验证内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: C++ TensorflowLite模型验证的过程详解

本文链接: https://www.lsjlt.com/news/133731.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作