🙌

CMU10-414/714 作业总结

CMU10-414是一门教你从头构建深度学习系统的课,从最底层的“Numpy”,再到Tensor的构建,最后到深度学习库抽象的实现。做完作业后,感觉收获很多,重新从头梳理一遍构建一个真正的深度学习系统所需要的全部部件。

构建思路

从我的视角来说,一个深度学习系统的需要有如下几个部分:
notion image
在作业中并不是以这样的顺序构建的,我将以自下(在图中是上)而上的顺序重新审视整个在作业中构建的深度学习系统。

构建过程

底层数据结构

底层数据结构是整个深度学习系统的基础。在这个课程中,我们首先实现了一个类似于Numpy的多维数组库,称为NDArray。这个库提供了基本的数学运算和数组操作,为后续的张量计算和自动微分奠定了基础。NDArray支持各种数据类型和设备(如CPU和GPU),使得我们能够高效地进行大规模数值计算。以下的NDArray的构造函数:
class NDArray:
    """A generic ND array class that may contain multipe different backends
    i.e., a Numpy backend, a native CPU backend, or a GPU backend.

    This class will only contains those functions that you need to implement
    to actually get the desired functionality for the programming examples
    in the homework, and no more.

    For now, for simplicity the class only supports float32 types, though
    this can be extended if desired.
    """

    def __init__(self, other, device=None):
        """Create by copying another NDArray, or from numpy"""
        if isinstance(other, NDArray):
            # create a copy of existing NDArray
            if device is None:
                device = other.device
            self._init(other.to(device) + 0.0)  # this creates a copy
        elif isinstance(other, np.ndarray):
            # create copy from numpy array
            device = device if device is not None else default_device()
            array = self.make(other.shape, device=device)
            array.device.from_numpy(np.ascontiguousarray(other), array._handle)
            self._init(array)
        else:
            # see if we can create a numpy array from input
            array = NDArray(np.array(other), device=device)
            self._init(array)

    def _init(self, other):
        self._shape = other._shape
        self._strides = other._strides
        self._offset = other._offset
        self._device = other._device
        self._handle = other._handle
光看这段代码可能会难以理解,不妨试着解释一下,当一个NDArray对象被创建时,到底发生了什么。

数据开辟

x = nd.NDArray([1, 2, 3], device=nd.cpu())
x = nd.NDArray([1, 2, 3], device=nd.cuda())
首先,list并不是一个NDArray对象,构造函数先会把它转换成一个numpy.array,并调用核心方法make
@staticmethod
    def make(shape, strides=None, device=None, handle=None, offset=0):
        """Create a new NDArray with the given properties.  This will allocation the
        memory if handle=None, otherwise it will use the handle of an existing
        array."""
        array = NDArray.__new__(NDArray)
        array._shape = tuple(shape)
        array._strides = NDArray.compact_strides(shape) if strides is None else strides
        array._offset = offset
        array._device = device if device is not None else default_device()
        if handle is None:
            array._handle = array.device.Array(prod(shape))
        else:
            array._handle = handle
        return array
各个字段的意义不多赘述,其中strides的值默认为“正常”的strides(即跟shape对应的)。唯一需要说明的是handle字段,对应了NDArray的真实底层物理地址。要理解handle是哪里来的,我们首先要搞清楚device是什么:
class BackendDevice:
    """A backend device, wrapps the implementation module."""

    def __init__(self, name, mod):
        self.name = name
        self.mod = mod

    def __eq__(self, other):
        return self.name == other.name

    def __repr__(self):
        return self.name + "()"

    def __getattr__(self, name):
        return getattr(self.mod, name)
        
def cuda():
    """Return cuda device"""
    try:
        from . import ndarray_backend_cuda

        return BackendDevice("cuda", ndarray_backend_cuda)
    except ImportError:
        return BackendDevice("cuda", None)


def cpu_numpy():
    """Return numpy device"""
    return BackendDevice("cpu_numpy", ndarray_backend_numpy)


def cpu():
    """Return cpu device"""
    return BackendDevice("cpu", ndarray_backend_cpu)
这里使用了代理模式,当试图调用device这个属性时,会调用相应的模块(即ndarray_backend_cudandarray_backend_cpu等),随后调用相应模块的对应方法。以ndarray_backend_cpu为例,device.Array对应着:
py::class_<AlignedArray>(m, "Array")
      .def(py::init<size_t>(), py::return_value_policy::take_ownership)
      .def("ptr", &AlignedArray::ptr_as_int)
      .def_readonly("size", &AlignedArray::size);
pybind11Array与C++代码中的AlignedArray类相绑定,调用Array的构造方法就是调用AlignedArray的构造方法:
struct AlignedArray {
  AlignedArray(const size_t size) {
    int ret = posix_memalign((void**)&ptr, ALIGNMENT, size * ELEM_SIZE);
    if (ret != 0) throw std::bad_alloc();
    this->size = size;
  }
  ~AlignedArray() { free(ptr); }
  size_t ptr_as_int() {return (size_t)ptr; }
  scalar_t* ptr;
  size_t size;
};
明显看到,这个类的初始化函数调用posix_memalign方法(其实就是malloc)来开辟地址空间。array._handle对应的也并非一个具体的地址,而是一个“包装类“。
有了空间之后,就需要将数据写进地址中:
array.device.from_numpy(np.ascontiguousarray(other), array._handle)
np.ascontiguousarray(other)的具体功能是让other的存储连续。接着调用device.from_numpyother的内存拷贝到handle中:
  m.def("from_numpy", [](py::array_t<scalar_t> a, AlignedArray* out) {
    std::memcpy(out->ptr, a.request().ptr, out->size * ELEM_SIZE);
  });
可以看到就是一个简单的memcpy。至此,我们已经拥有了一个可以存储数据的NDArray了。

数据运算

有了数据之后,我们只是有了一个高级一点的“数组”而已,所以我们还需要让NDArray可以进行运算(以加法为例,其他是类似的):
def ewise_or_scalar(self, other, ewise_func, scalar_func):
        """Run either an elementwise or scalar version of a function,
        depending on whether "other" is an NDArray or scalar
        """
        out = NDArray.make(self.shape, device=self.device)
        if isinstance(other, NDArray):
            assert self.shape == other.shape, "operation needs two equal-sized arrays"
            ewise_func(self.compact()._handle, other.compact()._handle, out._handle)
        else:
            scalar_func(self.compact()._handle, other, out._handle)
        return out

def __add__(self, other):
        return self.ewise_or_scalar(
            other, self.device.ewise_add, self.device.scalar_add
        )

__radd__ = __add__
运算首先会调用ewise_or_scalar方法,用于判断是向量之间的运算还是向量和标量之间的运算。接着调用self.device.xxx方法实现运算。不过值得注意的是,在运算之前,我们首先调用了compact方法来让NDArray的底层数据连续:
def compact(self):
        """Convert a matrix to be compact"""
        if self.is_compact():
            return self
        else:
            out = NDArray.make(self.shape, device=self.device)
            self.device.compact(
                self._handle, out._handle, self.shape, self.strides, self._offset
            )
            return out
这段代码调用了各个device自己实现的compact方法,将目标NDArray的底层数据变得连续,以cpu的实现为例:
void Helper(const AlignedArray* a, AlignedArray* out, std::vector<int32_t> shape,
                    std::vector<int32_t> strides, size_t offset, scalar_t val, OP_MODE mode) {
  int size = shape.size();
  std::vector<uint32_t> loops(size, 0);
  int cnt = 0;

  while(true) {
    int index = offset;
    for (int i = 0; i < loops.size(); i++) {
      index += loops[i] * strides[i];
    }

    switch (mode)
    {
    case W_IN: out->ptr[cnt++] = a->ptr[index]; break;
    case W_OUT: out->ptr[index] = a->ptr[cnt++]; break;
    case SET: out->ptr[index] = val; break;
    default: break;
    }
    

    for (int i = loops.size() - 1; i >= 0; i--) {
      if (loops[i] < shape[i] - 1) {
        loops[i]++;
        break;
      } else if (i == 0) {
        return;
      } else {
        loops[i] = 0;
      }
    }
  }
}

void Compact(const AlignedArray& a, AlignedArray* out, std::vector<int32_t> shape,
             std::vector<int32_t> strides, size_t offset) {
  /**
   * Compact an array in memory
   *
   * Args:
   *   a: non-compact representation of the array, given as input
   *   out: compact version of the array to be written
   *   shape: shapes of each dimension for a and out
   *   strides: strides of the *a* array (not out, which has compact strides)
   *   offset: offset of the *a* array (not out, which has zero offset, being compact)
   *
   * Returns:
   *  void (you need to modify out directly, rather than returning anything; this is true for all the
   *  function will implement here, so we won't repeat this note.)
   */
  Helper(&a, out, shape, strides, offset, 0, W_IN);
}
简单的来说,就是将源NDArray中的内存依次按顺序地拷贝到out中即可,但是我们并不知道源NDArray的形状,因此我们不能写固定数量个for循环来实现拷贝,而是要在Helper函数中”手动“实现for循环。
回到上面,在实现了Compact后,数据已经在内存中变得连续,实现加法运算就比较简单了:
void EwiseAdd(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  /**
   * Set entries in out to be the sum of correspondings entires in a and b.
   */
  for (size_t i = 0; i < a.size; i++) {
    out->ptr[i] = a.ptr[i] + b.ptr[i];
  }
}

void ScalarAdd(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  /**
   * Set entries in out to be the sum of corresponding entry in a plus the scalar val.
   */
  for (size_t i = 0; i < a.size; i++) {
    out->ptr[i] = a.ptr[i] + val;
  }
}
到目前为止,我们已经拥有了一个类似于Numpy一样的数据结构了!

自动微分实现

关键数据结构:Tensor

拥有了NDArray之后,我们依旧无法做到实现自动微分。因为我们并没有在NDArray中记录有关计算的“上下文信息”,即计算图信息。
所以我们需要一个新的数据结构来封装NDArray并添加额外的功能。这个新的数据结构就是TensorTensor不仅包含了数据(通过NDArray),还记录了计算图信息,使得我们能够进行反向传播和梯度计算:
import numpy as array_api
NDArray = numpy.ndarray

class Tensor(Value):
    grad: "Tensor"

    def __init__(
        self,
        array,
        *,
        device: Optional[Device] = None,
        dtype=None,
        requires_grad=True,
        **kwargs
    ):
        if isinstance(array, Tensor):
            if device is None:
                device = array.device
            if dtype is None:
                dtype = array.dtype
            if device == array.device and dtype == array.dtype:
                cached_data = array.realize_cached_data()
            else:
                # fall back, copy through numpy conversion
                cached_data = Tensor._array_from_numpy(
                    array.numpy(), device=device, dtype=dtype
                )
        else:
            device = device if device else cpu()
            cached_data = Tensor._array_from_numpy(array, device=device, dtype=dtype)

        self._init(
            None,
            [],
            cached_data=cached_data,
            requires_grad=requires_grad,
        )

    @staticmethod
    def _array_from_numpy(numpy_array, device, dtype):
        if array_api is numpy:
            return numpy.array(numpy_array, dtype=dtype)
        return array_api.array(numpy_array, device=device, dtype=dtype)

    @staticmethod
    def make_from_op(op: Op, inputs: List["Value"]):
        tensor = Tensor.__new__(Tensor)
        tensor._init(op, inputs)
        if not LAZY_MODE:
            if not tensor.requires_grad:
                return tensor.detach()
            tensor.realize_cached_data()
        return tensor
这个Tensor类的实现包含了几个关键部分:
  1. 初始化方法:可以从现有的Tensor、NumPy数组或其他数据源创建Tensor
  1. 静态方法_array_from_numpy:用于将NumPy数组转换为设备特定的数组。
  1. 静态方法make_from_op:用于从操作和输入创建新的Tensor,这是构建计算图的关键。
这里的关键是,如何使用make_from_op实现计算图的构建。

计算图的构建

让我们更详细地探讨make_from_op方法如何构建计算图:
  1. 当执行操作时(如加法、乘法等),会调用make_from_op方法。(仍然以加法为例,下文代码块中的函数调用顺序自上而下)。
def __add__(self, other):
        if isinstance(other, Tensor):
            return needle.ops.EWiseAdd()(self, other)
        else:
            return needle.ops.AddScalar(other)(self)
            
class EWiseAdd(TensorOp):
    def compute(self, a: NDArray, b: NDArray):
        return a + b

    def gradient(self, out_grad: Tensor, node: Tensor):
        return out_grad, out_grad


def add(a, b):
    return EWiseAdd()(a, b)
  1. 在初始化EWiseAdd对象并直接调用时,会调用它父类(TensorOp)的__call__函数:
    1. class TensorOp(Op):
          def __call__(self, *args):
              return Tensor.make_from_op(self, args)
  1. 让我们深入解析 make_from_op 方法的实现:
@staticmethod
def make_from_op(op: Op, inputs: List["Value"]):
    tensor = Tensor.__new__(Tensor)
    tensor._init(op, inputs)
    if not LAZY_MODE:
        if not tensor.requires_grad:
            return tensor.detach()
        tensor.realize_cached_data()
    return tensor
  
def _init(
        self,
        op: Optional[Op],
        inputs: List["Tensor"],
        *,
        num_outputs: int = 1,
        cached_data: List[object] = None,
        requires_grad: Optional[bool] = None
    ):
        global TENSOR_COUNTER
        TENSOR_COUNTER += 1
        if requires_grad is None:
            requires_grad = any(x.requires_grad for x in inputs)
        self.op = op
        self.inputs = inputs
        self.num_outputs = num_outputs
        self.cached_data = cached_data
        self.requires_grad = requires_grad
这个方法是构建计算图的核心。它的工作原理如下:
  • 创建一个新的 Tensor 对象,但不调用 __init__ 方法。
  • 调用 _init 方法,将操作 op 和输入 inputs 与这个新的 Tensor 关联起来。这一步骤建立了计算图的连接。
  • 如果不是懒惰模式(LAZY_MODE 为 False):
    • 如果新 Tensor不需要梯度,直接返回其分离版本。
    • 否则,实现缓存数据(即,实际执行计算)。
  • 返回新创建的 Tensor
通过这个方法,每个 Tensor 都记录了它是如何被创建的(通过哪个操作),以及它的输入是什么。这种信息的记录使得我们能够在反向传播时追踪计算图,并正确地计算梯度。

反向传播实现:Reverse mode AD

notion image
反向传播(Reverse mode AD)是自动微分的核心部分,它允许我们高效地计算复杂函数的梯度。在我们的实现中,反向传播主要通过 Tensor 类的 backward 方法来实现。这个方法会遍历计算图,从输出节点开始,逐步计算每个节点的梯度。以下是 backward 方法的核心实现:
def backward(self, out_grad=None):
        out_grad = (
            out_grad
            if out_grad
            else init.ones(*self.shape, dtype=self.dtype, device=self.device)
        )
        compute_gradient_of_variables(self, out_grad)
        
def compute_gradient_of_variables(output_tensor, out_grad):
    """Take gradient of output node with respect to each node in node_list.

    Store the computed result in the grad field of each Variable.
    """
    # a map from node to a list of gradient contributions from each output node
    node_to_output_grads_list: Dict[Tensor, List[Tensor]] = {}
    # Special note on initializing gradient of
    # We are really taking a derivative of the scalar reduce_sum(output_node)
    # instead of the vector output_node. But this is the common case for loss function.
    node_to_output_grads_list[output_tensor] = [out_grad]

    # Traverse graph in reverse topological order given the output_node that we are taking gradient wrt.
    reverse_topo_order = list(reversed(find_topo_sort([output_tensor])))

    for node in reverse_topo_order:
        ajoint = sum_node_list(node_to_output_grads_list[node])
        node.grad = ajoint
        
        if node.op is None:
            continue

        parial_ajoint = node.op.gradient_as_tuple(ajoint, node)
        for input, parial in zip(node.inputs, parial_ajoint):
            if input not in node_to_output_grads_list:
                node_to_output_grads_list[input] = []
            node_to_output_grads_list[input].append(parial)
            
 def gradient_as_tuple(self, out_grad: "Value", node: "Value") -> Tuple["Value"]:
        """Convenience method to always return a tuple from gradient call"""
        output = self.gradient(out_grad, node)
        if isinstance(output, tuple):
            return output
        elif isinstance(output, list):
            return tuple(output)
        else:
            return (output,)
这个反向传播的实现过程可以总结为以下几个步骤:
  1. 初始化输出梯度:如果没有提供输出梯度,则默认为1。
  1. 构建反向拓扑排序:从输出节点开始,按照计算图的依赖关系逆序排列所有节点。(DFS)
def find_topo_sort(node_list: List[Value]) -> List[Value]:
    anwser = []
    visited = set()
    for node in node_list:
            topo_sort_dfs(node, visited, anwser)
    return anwser


def topo_sort_dfs(node, visited, topo_order):
    """Post-order DFS"""
    if node in visited:
        return
    visited.add(node)
    for input in node.inputs:
        if input not in visited:
            topo_sort_dfs(input, visited, topo_order)
    topo_order.append(node)
  1. 逐节点计算梯度:对于每个节点,计算其梯度并存储,然后使用链式法则计算其输入节点的梯度贡献。
这种反向传播的实现方式非常高效,因为它只需要遍历计算图一次就能计算出所有节点的梯度。通过使用拓扑排序,我们确保了在计算某个节点的梯度时,其所有依赖节点的梯度都已经被计算出来。这种方法不仅计算速度快,而且内存效率高,因为我们可以在计算完某个节点的梯度后立即释放不再需要的中间结果。
此外,这种方法与原地实现的反向传播不同,它在计算梯度时构建了新的计算图。这种方式有以下几个优点:
  • 允许我们计算高阶导数:由于每次反向传播都会创建新的计算图,我们可以对梯度再次应用反向传播,从而计算二阶甚至更高阶的导数。这在一些高级优化算法和某些机器学习模型中非常有用。
  • 增加了灵活性:我们可以在不影响原始计算图的情况下,对梯度计算过程进行修改或扩展。这对于实现一些复杂的优化策略或自定义的训练流程很有帮助。
  • 便于调试:由于梯度计算过程被显式地表示为一个计算图,我们可以更容易地检查和调试梯度计算的中间步骤。
 

深度学习库抽象

在深度学习库抽象层面,我们实现了一系列常用的组件,这些组件构成了深度学习框架的核心。以下是我们实现的关键组件:
  • 模块(Module):这是深度学习模型的基本构建块。我们实现了各种常用的神经网络层,如线性层(Linear)、卷积层(Conv)等,以及激活函数如ReLU、Sigmoid等。Module类提供了forward方法用于前向传播。
  • 损失函数(Loss):用于衡量模型输出与目标之间的差异。我们实现了常用的损失函数,如均方误差(MSELoss)和交叉熵损失(CrossEntropyLoss)。
  • 优化器(Optimizer):负责更新模型参数以最小化损失函数。我们实现了常见的优化算法,如随机梯度下降(SGD)和Adam优化器。
  • 数据加载器(DataLoader):用于高效地批量加载和预处理数据。它支持随机打乱、批量处理和并行加载等功能,使得数据处理过程更加高效。
这些组件共同工作,使得我们能够轻松构建、训练和评估复杂的深度学习模型。通过这种抽象,我们不仅简化了模型的构建过程,还提高了代码的可读性和可维护性。
现在,我们来逐一查看这些是如何实现的。

模块(nn.Module)

首先,让我们来看一下模块(Module)的实现。Module 是深度学习模型的基本构建单元,它封装了模型的参数和前向传播逻辑。以下是 Module 类的基本结构:
class Module:
    def __init__(self):
        self.training = True

    def parameters(self) -> List[Tensor]:
        """Return the list of parameters in the module."""
        return _unpack_params(self.__dict__)

    def _children(self) -> List["Module"]:
        return _child_modules(self.__dict__)

    def eval(self):
        self.training = False
        for m in self._children():
            m.training = False

    def train(self):
        self.training = True
        for m in self._children():
            m.training = True

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
这个基础的 Module 类提供了一些核心功能:
  1. parameters() 方法返回模块中所有可训练的参数。
  1. _children() 方法返回所有子模块。
  1. eval()train() 方法用于切换模块的训练和评估模式。
  1. call() 方法允许我们像调用函数一样使用模块实例。
基于这个基类,我们可以实现各种特定的神经网络层。例如,让我们看一下线性层(Linear)的实现:
class Linear(Module):
    def __init__(
        self, in_features, out_features, bias=True, device=None, dtype="float32"
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.have_bias = bias
        self.weight = Parameter(init.kaiming_uniform(in_features, out_features,device = device, dtype = dtype))
        if self.have_bias:
            self.bias = Parameter(init.kaiming_uniform(out_features, 1, device = device, dtype = dtype).reshape((1, out_features)))

    def forward(self, X: Tensor) -> Tensor:
        if self.have_bias:
            bias = ops.broadcast_to(self.bias, (X.shape[0], self.out_features))
            return ops.matmul(X, self.weight) + bias
        else:
            return ops.matmul(X, self.weight)

损失函数(Loss)

损失函数是衡量模型预测结果与真实标签之间差异的关键组件。在我们的实现中,我们创建了一些常用的损失函数,如均方误差(MSE)和交叉熵损失。这些损失函数不仅计算损失值,还能计算梯度,为反向传播提供必要的信息。
class SoftmaxLoss(Module):
    def forward(self, logits: Tensor, y: Tensor):
        one_hot_y = init.one_hot(logits.shape[1], y)
        return (ops.summation(ops.logsumexp(logits, (1,)) / logits.shape[0])) - ops.summation(one_hot_y * logits / logits.shape[0])

优化器(Optimizer)

优化器是深度学习训练过程中的关键组件,负责更新模型参数以最小化损失函数。在我们的实现中,我们创建了几种常用的优化算法,如随机梯度下降(SGD)和Adam优化器。这些优化器接收模型参数和计算得到的梯度,然后根据特定的算法更新参数。
class Optimizer:
    def __init__(self, params):
        self.params = params

    def step(self):
        raise NotImplementedError()

    def reset_grad(self):
        for p in self.params:
            p.grad = None
基于这个基类,我们可以实现具体的优化器。例如,SGD(随机梯度下降)优化器的实现如下:
class SGD(Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.0, weight_decay=0.0):
        super().__init__(params)
        self.lr = lr
        self.momentum = momentum
        self.u = {}
        self.weight_decay = weight_decay

    def step(self):
        for param in self.params:
            grad  = self.u.get(param, 0) * self.momentum + (1-self.momentum) * (param.grad.data + self.weight_decay * param.data)
            grad = ndl.Tensor(grad, dtype = param.dtype)
            self.u[param] = grad
            param.data =  param.data - self.lr * grad.data
这个实现包含了动量和权重衰减的功能,使得优化过程更加稳定和高效。通过这样的抽象,我们可以轻松地实现和使用各种优化算法,为深度学习模型的训练提供灵活性和高效性。

数据加载器和数据集

数据加载器(DataLoader)和数据集(Dataset)是深度学习训练过程中处理数据的关键组件。DataLoader负责批量加载数据,并可以实现数据的随机打乱和并行加载,从而提高训练效率。Dataset则定义了如何获取单个数据样本,通常包括数据的读取、预处理和转换等操作。这两个组件的结合使得数据处理变得高效且灵活,能够适应各种不同的数据格式和训练需求。
接下来,让我们来看一下 Dataset 的基本结构:
class Dataset:
    r"""An abstract class representing a `Dataset`.

    All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses must also overwrite
    :meth:`__len__`, which is expected to return the size of the dataset.
    """

    def __init__(self, transforms: Optional[List] = None):
        self.transforms = transforms

    def __getitem__(self, index) -> object:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError
    
    def apply_transforms(self, x):
        if self.transforms is not None:
            # apply the transforms
            for tform in self.transforms:
                x = tform(x)
        return x

class MNISTDataset(Dataset):
    def __init__(
        self,
        image_filename: str,
        label_filename: str,
        transforms: Optional[List] = None,
    ):
        with gzip.open(image_filename, 'rb') as img_file:
            magic_num, img_num, row, col = struct.unpack(">4i", img_file.read(16))
            assert magic_num == 2051
            X = np.frombuffer(img_file.read(img_num * row * col), dtype=np.uint8).astype(np.float32).reshape((img_num, row * col))
            X = X.reshape(img_num, row, col, 1)
            X /= 255.0
    
        with gzip.open(label_filename, 'rb') as label_file:
            magic_num, label_num = struct.unpack(">2i", label_file.read(8))
            assert magic_num == 2049
            y = np.frombuffer(label_file.read(label_num), dtype=np.uint8)
    
        self.img = X
        self.label = y
        self.transforms = transforms

    def __getitem__(self, index) -> object:
        return self.apply_transforms(self.img[index]), self.label[index]

    def __len__(self) -> int:
        return len(self.img)
Dataset 类是一个抽象基类,定义了数据集的基本接口。它要求子类实现 getitem 方法来获取单个数据样本,以及可选的 len 方法来返回数据集的大小。
MNISTDatasetDataset 的一个具体实现,专门用于处理 MNIST 数据集。它在初始化时读取图像和标签文件,并支持对数据进行变换操作。getitem 方法返回指定索引的图像和标签对,同时应用所有指定的变换。
通过这种方式,我们可以轻松地创建自定义数据集,并将其与 DataLoader结合使用,以高效地批量加载和预处理数据。这种灵活的设计使得我们能够处理各种不同类型和格式的数据集,同时保持代码的整洁和可维护性。
以下是 DataLoader 的基本结构:
class DataLoader:
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
     """
    dataset: Dataset
    batch_size: Optional[int]

    def __init__(
        self,
        dataset: Dataset,
        batch_size: Optional[int] = 1,
        shuffle: bool = False,
    ):

        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
        if not self.shuffle:
            self.ordering = np.array_split(np.arange(len(dataset)), 
                                           range(batch_size, len(dataset), batch_size))

    def __iter__(self):
        if self.shuffle:
            self.ordering = np.array_split(np.random.permutation(len(self.dataset)), 
                                           range(self.batch_size, len(self.dataset), self.batch_size))
        self.t = 0
        return self

    def __next__(self):
        if self.t < len(self.ordering):
            train_set = self.dataset[self.ordering[self.t]]
            self.t += 1
            train_set = [Tensor(example) for example in train_set]
        else:
            raise StopIteration()
        return train_set

总结

至此,我们构建了一个完整的深度学习系统,涵盖了从底层数据结构到高级模型抽象的各个方面:
  • 底层数据结构:实现了类似Numpy的NDArray,为张量计算和自动微分奠定基础。
  • 张量和自动微分:构建了Tensor类,支持自动求导,是深度学习计算的核心。
  • 神经网络模块:创建了各种层(如Linear、ReLU)和激活函数,用于构建复杂的神经网络。
  • 损失函数:实现了常用的损失函数,如MSE和交叉熵损失,用于评估模型性能。
  • 优化器:开发了SGD、Adam等优化算法,用于更新模型参数。
  • 数据处理:实现了Dataset和DataLoader,高效处理和加载训练数据。
这个过程不仅让我深入理解了深度学习系统的各个组件,还帮助我们掌握了它们之间的交互和整体架构。这种从底层到高层的实现经验,为未来开发和优化更复杂的深度学习模型奠定了坚实的基础。