Chainer 类型检查

Posted by 徐志平 on December 17, 2017

类型检查

在本节中,您将了解以下内容:

  • 类型检查的基本用法
  • 类型信息的细节
  • 类型检查的内部机制
  • 更复杂的情况
  • 函数调用
  • 典型的检查例子

阅读本节后,您将能够:

  • 编写一个代码来检查你自己函数的输入参数类型

类型检查的基本用法

当您使用无效类型的数组调用某个函数时,您有时不会收到错误,但会通过广播获得意外的结果。当您使用非法类型的CUDA数组时,会导致内存损坏,并且会出现严重错误。这些错误很难修复。 Chainer可以检查每个函数的先决条件,并有助于防止这些问题。这些条件可以帮助用户理解函数的设定。

每个函数的实现都有一个类型检查方法check_type_forward()。该函数在Function类的forward()方法之前被调用。您可以重写此方法来检查参数的类型和形状的条件。

check_type_forward() 有一个 in_types 参数:

def check_type_forward(self, in_types):
  ...

in_typesTypeInfoTuple的一个实例,它是元组的一个子类。要获取有关第一个参数的类型信息,请使用in_types [0]。如果函数获取多个参数,我们建议使用新的变量来提高可读性:

x_type, y_type = in_types

在这种情况下,x_type表示第一个参数的类型,y_type表示第二个参数。

我们用一个例子来描述in_types的用法。当你想检查x_type的维数是否等于2时,写下这个代码:

utils.type_check.expect(x_type.ndim == 2)

当这个条件成立时,没有任何反应。否则,这段代码会抛出一个异常,并且用户得到这样的消息:

Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].ndim == 2
Actual: 3 != 2

这个错误消息意味着“第一个参数的ndim预期为2,但实际上是3”。

类型信息的细节

您可以访问x_type的三方面的信息。

  • .shape 是一个int的元组。每个值都是每个维度的大小。
  • .ndim 是表示维数的int值。请注意,ndim == len(shape)
  • .dtype 是表示数据类型的 numpy.dtype 。

你可以检查所有成员。例如,第一维的大小必须是正值,你可以这样写:

utils.type_check.expect(x_type.shape[0] > 0)

您也可以使用.dtype检查数据类型:

utils.type_check.expect(x_type.dtype == np.float64)

而一个错误是这样的:

Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].dtype == <class 'numpy.float64'>
Actual: float32 != <class 'numpy.float64'>

你同样可以检查dtype的kind。下面代码检查type是否为浮点型

utils.type_check.expect(x_type.dtype.kind == 'f')

你可以在变量之间进行比较。例如,下面的代码检查第一个参数和第二个参数是否具有相同的长度:

utils.type_check.expect(x_type.shape[1] == y_type.shape[1])

类型检查的内部机制

它如何显示in_types[0].ndim == 2的错误信息?如果x_type是一个包含ndim成员变量的对象,我们不能显示这样的错误信息,因为Python解释器将此方程作为布尔值进行计算。

其实x_type是一个Expr对象,本身并没有一个ndim成员变量。Expr代表一个语法树。x_type.ndim使Expr对象表示为(getattr,x_type,'ndim')x_type.ndim == 2使对象执行像(eq,(getattr,x_type,'ndim'),2)这样的操作。 type_check.expect()获取Expr对象并对其进行估值运算。如果为真,则不会导致错误,也不会显示任何内容。否则,此方法显示可读的错误消息。

如果要估值Expr对象,请调用eval()方法:

actual_type = x_type.eval()

actual_typeTypeInfo的一个实例,而x_typeExpr的一个实例。以同样的方式, x_type.shape[0].eval()返回一个int值。

更强大的方法

Expr类更强大。它支持所有的数学运算符,如+*。你可以写出一个条件,即x_type的第一个维度是y_type的第一维度乘以四:

utils.type_check.expect(x_type.shape[0] == y_type.shape[0] * 4)

y_type.shape [0] == 1时,用户可以得到下面的错误信息:

Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == in_types[1].shape[0] * 4
Actual: 3 != 4

要比较函数的成员变量,请用Variable包装一个值以显示可读的错误消息:

x_type.shape[0] == utils.type_check.Variable(self.in_size, "in_size")

这段代码可以检查下面的等价条件:

x_type.shape[0] == self.in_size

但是,后一种情况不知道这个值的意思。当这个条件不满意时,后面的代码显示不可读的错误信息:

chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == 4  # what does '4' mean?
Actual: 3 != 4

请注意,utils.type_check.Variable的第二个参数仅用于可读性。

前者显示这个消息:

chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == in_size  # OK, `in_size` is a value that is given to the constructor
Actual: 3 != 4  # You can also check actual value here

调用函数

如何检查所有shape值的总和? Expr也支持函数调用:

sum = utils.type_check.Variable(np.sum, 'sum')
utils.type_check.expect(sum(x_type.shape) == 10)

为什么我们需要用utils.type_check.Variable包装函数numpy.sumx_type.shape不是一个元组,而是Expr的一个对象,就像我们之前看到的那样。 因此,numpy.sum(x_type.shape)会失效。

上面的例子产生这样的错误信息:

Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: sum(in_types[0].shape) == 10
Actual: 7 != 10

更加复杂的例子

在不能用这些操作符的情况下如何写一个更复杂的条件?您可以使用eval()方法评估Expr并获取其结果值。然后检查情况并手动显示警告消息:

x_shape = x_type.shape.eval()  # get actual shape (int tuple)
if not more_complicated_condition(x_shape):
    expect_msg = 'Shape is expected to be ...'
    actual_msg = 'Shape is ...'
    raise utils.type_check.InvalidType(expect_msg, actual_msg)

请写一个可读的错误信息。此代码会生成以下错误消息:

Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: Shape is expected to be ...
Actual: Shape is ...

典型的检查例子

我们给出一个典型的函数检查类型。

首先检查参数的数量:

utils.type_check.expect(in_types.size() == 2)

n_types.size()返回表示参数个数的Expr对象。你可以用同样的方法检查它。

然后,获取每种类型:

x_type, y_type = in_types

检查in_types.size()之前不要获取每个值。当参数个数不合法时,type_check.expect可能会输出无用的错误消息。例如,当in_types的大小为0时,此代码不起作用:

utils.type_check.expect(
  in_types.size() == 2,
  in_types[0].ndim == 3,
)

之后,检查每种类型:

utils.type_check.expect(
  x_type.dtype == np.float32,
  x_type.ndim == 3,
  x_type.shape[1] == 2,
)

上述示例即使在x_type.ndim == 0的情况下也能正常工作,因为所有条件都将被懒惰地评估。