广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Pythonkeras.metrics源代码分析
  • 617
分享到

Pythonkeras.metrics源代码分析

Pythonkeras.metricsPythonkeras.metrics方法Pythonkeras.metrics示例 2022-11-13 19:11:00 617人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

目录前言metrics原理解析(以metrics.Mean为例)创建自定义metrics创建无状态 metrics通过继承Metric创建有状态metricsadd_metric()

前言

metrics用于判断模型性能。度量函数类似于损失函数,只是度量的结果不用于训练模型。可以使用任何损失函数作为度量(如logloss等)。在训练期间监控metrics的最佳方式是通过Tensorboard。

官方提供的metrics最重要的概念就是有状态(stateful)变量,通过更新状态变量,可以不断累积统计数据,并可以随时输出状态变量的计算结果。这是区别于losses的重要特性,losses是无状态的(stateless)。

本文部分内容参考了:

Keras-Metrics官方文档

代码运行环境为:tf.__version__==2.6.2 。

metrics原理解析(以metrics.Mean为例)

metrics是有状态的(stateful),即Metric 实例会存储、记录和返回已经累积的结果,有助于未来事务的信息。下面以tf.keras.metrics.Mean()为例进行解释:

创建tf.keras.metrics.Mean的实例:

m = tf.keras.metrics.Mean()

通过help(m) 可以看到MRO为:

Mean
Reduce
Metric
keras.engine.base_layer.Layer
...

可见Metric和Mean是 keras.layers.Layer 的子类。相比于类Layer,其子类Mean多出了几个方法:

  • result: 计算并返回标量度量值(tensor形式)或标量字典,即状态变量简单地计算度量值。例如,m.result(),就是计算均值并返回。
  • total: 状态变量m目前累积的数字总和
  • count: 状态变量m目前累积的数字个数(m.total/m.count就是m.result()的返回值)
  • update_state: 累积统计数字用于计算指标。每次调用m.update_state都会更新m.totalm.count
  • reset_state: 将状态变量重置到初始化状态;
  • reset_states: 等价于reset_state,参见keras源代码metrics.py L355
  • reduction: 目前来看,没什么用。

这也决定了Mean的特殊性质。其使用参见如下代码:

# 创建状态变量m,由于m未刚初始化,
# 所以total,count和result()均为0
m = tf.keras.metrics.Mean()
print("m.total:",m.total)
print("m.count:",m.count)
print("m.result():",m.result())

"""
# 输出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=0.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=0.0>
m.result(): tf.Tensor(0.0, shape=(), dtype=float32)
"""

# 更新状态变量,可以看到total累加了总和,
# count累积了个数,result()返回total/count
m.update_state([1,2,3])
print("m.total:",m.total)
print("m.count:",m.count)
print("m.result():",m.result())

"""
# 输出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=6.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=3.0>
m.result(): tf.Tensor(2.0, shape=(), dtype=float32)
"""

# 重置状态变量, 重置到初始化状态
m.reset_state()
print("m.total:",m.total)
print("m.count:",m.count)
print("m.result():",m.result())

"""
# 输出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=0.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=0.0>
m.result(): tf.Tensor(0.0, shape=(), dtype=float32)
"""

创建自定义metrics

创建无状态 metrics

与损失函数类似,任何带有类似于metric_fn(y_true, y_pred)、返回损失数组(如输入一个batch的数据,会返回一个batch的损失标量)的函数,都可以作为metric传递给compile()

import Tensorflow as tf
import numpy as np
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(1, activation=tf.nn.softmax)(x)
model1 = tf.keras.Model(inputs=inputs, outputs=outputs)
def my_metric_fn(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_difference, axis=-1) # shape=(None,)
model1.compile(optimizer='adam', loss='mse', metrics=[my_metric_fn])
x = np.random.random((100, 3))
y = np.random.random((100, 1))
model1.fit(x, y, epochs=3)

输出:

Epoch 1/3
4/4 [==============================] - 0s 667us/step - loss: 0.0971 - my_metric_fn: 0.0971
Epoch 2/3
4/4 [==============================] - 0s 667us/step - loss: 0.0958 - my_metric_fn: 0.0958
Epoch 3/3
4/4 [==============================] - 0s 1ms/step - loss: 0.0946 - my_metric_fn: 0.0946

注意,因为本例创建的是无状态的度量,所以上面跟踪的度量值(my_metric_fn后面的值)是每个batch的平均度量值,并不是一个epoch(完整数据集)的累积值。(这一点需要理解,这也是为什么要使用有状态度量的原因!)

值得一提的是,如果上述代码使用

model1.compile(optimizer='adam', loss='mse', metrics=["mse"])

进行compile,则输出的结果是累积的,在每个epoch结束时的结果就是整个数据集的结果,因为metrics=["mse"]是直接调用了标准库的有状态度量。

通过继承Metric创建有状态metrics

如果想查看整个数据集的指标,就需要传入有状态的metrics,这样就会在一个epoch内累加,并在epoch结束时输出整个数据集的度量值。

创建有状态度量指标,需要创建Metric的子类,它可以跨batch维护状态,步骤如下:

  • __init__中创建状态变量(state variables)
  • 更新update_state()y_truey_pred的变量
  • result()中返回标量度量结果
  • reset_states()中清除状态
class BinaryTruePositives(tf.keras.metrics.Metric):
    def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)
        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))
    def result(self):
        return self.true_positives
    def reset_states(self):
        self.true_positives.assign(0)
m = BinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))
m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))

add_metric()方法

add_metric 方法是 tf.keras.layers.Layer类添加的方法,Layer的父类tf.Module并没有这个方法,因此在编写Layer子类如包括自定义层、官方提供的层(Dense)或模型(tf.keras.Model也是Layer的子类)时,可以使用add_metric()来与层相关的统计量。比如,将类似Dense的自定义层的激活平均值记录为metric。可以执行以下操作:

class DenseLike(Layer):
    """y = w.x + b"""
    ...
    def call(self, inputs):
        output = tf.matmul(inputs, self.w) + self.b
        self.add_metric(tf.reduce_mean(output), aggregation='mean', name='activation_mean')
        return output

将在名称为activation_mean的度量下跟踪output,跟踪的值为每个批次度量值的平均值。

更详细的信息,参阅官方文档The base Layer class - add_metric method。

参考

Keras-Metrics官方文档

到此这篇关于python keras.metrics源代码分析的文章就介绍到这了,更多相关Python keras.metrics内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: Pythonkeras.metrics源代码分析

本文链接: https://www.lsjlt.com/news/171111.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pythonkeras.metrics源代码分析
    目录前言metrics原理解析(以metrics.Mean为例)创建自定义metrics创建无状态 metrics通过继承Metric创建有状态metricsadd_metric()...
    99+
    2022-11-13
    Python keras.metrics Python keras.metrics方法 Python keras.metrics示例
  • Android ArrayMap源代码分析
          分析源码之前先来介绍一下ArrayMap的存储结构,ArrayMap数据的存储不同于HashMap和SparseA...
    99+
    2022-06-06
    代码分析 Android
  • 如何用源代码分析FileZilla
    这期内容当中小编将会给大家带来有关如何用源代码分析FileZilla,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。FileZilla是一种快速、可信赖的FTP客户端以及服务器端开放源代码程式,具有多种特色...
    99+
    2023-06-16
  • 怎么用源代码分析FileZilla
    本篇文章给大家分享的是有关怎么用源代码分析FileZilla,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。FileZilla是一种快速、可信赖的FTP客户端以及服务器端开放源代...
    99+
    2023-06-16
  • 怎么进行FileZilla源代码分析
    怎么进行FileZilla源代码分析,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。FileZilla是一种快速、可信赖的FTP客户端以及服务器端开放源代码程式,具有多种特色...
    99+
    2023-06-16
  • Java线程通信源代码分析
    本篇内容介绍了“Java线程通信源代码分析”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!其实我们在源代码中就能发现其中的奥秘。因为Threa...
    99+
    2023-06-17
  • RateLimiter 源码分析
    俗话说得好,缓存,限流和降级是系统的三把利剑。刚好项目中每天早上导出数据时因调订单接口频率过高,订单系统担心会对用户侧的使用造成影响,让我们对调用限速一下,所以就正好用上了。 常用的限流算法有2种:漏桶算法和令牌桶算法。漏桶算法漏...
    99+
    2023-05-31
    ratelimiter 源码 mi
  • SocketServer 源码分析
    Creating network servers. contents SocketServer.py contents file head BaseServer BaseServer.serve_forever BaseServ...
    99+
    2023-01-31
    源码 SocketServer
  • CesiumJS源码分析
    这篇文章主要介绍“CesiumJS源码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“CesiumJS源码分析”文章能帮助大家解决问题。1. 有什么光CesiumJS 支持的光的类型比较少,默认场...
    99+
    2023-07-06
  • 如何进行FileZilla的源代码分析
    这篇文章将为大家详细讲解有关如何进行FileZilla的源代码分析,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。FileZilla是一种快速、可信赖的FTP客户端以及服务器端开放源代码程式,...
    99+
    2023-06-16
  • 怎样进行FileZilla的源代码分析
    本篇文章为大家展示了怎样进行FileZilla的源代码分析,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。FileZilla是一种快速、可信赖的FTP客户端以及服务器端开放源代码程式,具有多种特色、直...
    99+
    2023-06-16
  • 开源知识付费APP代码分析
    如今,传统的学校已经不能满足大众多元化的需求,各种教育培训机构落地生根。随着时间的推移,互联网与传统教育的结合也开拓了一种新的教育方式,这就是广为人知的知识付费。在线教育的突然崛起多半是因为疫情的“催...
    99+
    2023-09-06
    java 开发语言 源码软件 php
  • Java notify唤醒源代码的示例分析
    这期内容当中小编将会给大家带来有关Java notify唤醒源代码的示例分析,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。Java notify唤醒在此对象监视器上等待的单个线程。相关的问题需要...
    99+
    2023-06-17
  • Android LayoutInflater.inflate源码分析
    LayoutInflater.inflate源码详解 LayoutInflater的inflate方法相信大家都不陌生,在Fragment的onCreateView中或者在Ba...
    99+
    2022-06-06
    layoutinflater Android
  • Android AsyncTask源码分析
    Android中只能在主线程中进行UI操作,如果是其它子线程,需要借助异步消息处理机制Handler。除此之外,还有个非常方便的AsyncTask类,这个类内部封装了Handl...
    99+
    2022-06-06
    asynctask Android
  • Nebula Graph源码分析
    本篇内容介绍了“Nebula Graph源码分析”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!导读对于一些...
    99+
    2022-10-19
  • Kafka源码分析(一)
    Apache Kafka® 是 一个分布式流处理平台. 这到底意味着什么呢 我们知道流处理平台有以下三种特性: 可以让你发布和订阅流式的记录。这一方面与消息队列或者企业消息系统类似。 可以储存流式的记录,并且有较好的容错性。 可...
    99+
    2019-10-17
    Kafka源码分析(一)
  • django源码分析 LazySetti
    一、django中通过LazySetting对象来获取项目的配置,LazySetting对象有什么特性?为什么使用这个对象? LazySetting顾名思义,就是延迟获取配置内容。比如,我们定义了一个对象A,并对其添加了一些属性,对A初始...
    99+
    2023-01-31
    源码 django LazySetti
  • 分析Android Choreographer源码
    目录一、前言二、主线程运行机制的本质三、Choreographer 简介3.1、Choreographer 的工作流程四、Choreographer 源码分析4.1、Choreogr...
    99+
    2022-11-12
  • Spring cache源码分析
    今天小编给大家分享一下Spring cache源码分析的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。题外话:如何阅...
    99+
    2023-06-29
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作