How to Train Deep Neural Networks Over Data Streams
We know how to train neural networks on regular datasets. But what do we do if we only have access to a bit of data at a time?
Historically, many machine learning algorithms have been developed to handle, and learn from, incoming streams of data. For example, models such as SVMs and logistic regressors have been generalized to settings in which the entire dataset is not available to the learner and training must be conducted over an incoming, sequential stream of data [1, 2]. Similarly, many clustering algorithms have been proposed for learning over data streams [3]. These methodologies force the underlying model to learn from a continuous data stream that becomes available one example at a time, eliminating the need for the entire dataset to be available at once. Interestingly, although approaches for streaming learning have been developed for more traditional machine learning algorithms, streaming learning is not widely explored for deep neural networks, where offline training (i.e., performing several loops/epochs over the full dataset) dominates.
Aiming to close this gap, recent work in the deep learning community has explored the possibility of training deep networks via streaming. Within the deep learning domain, streaming can be described as a learning setup in which (i) the dataset is learned one example at a time (i.e., the dataset is presented as an incoming stream of data), (ii) each unique example in the dataset is seen only once, (iii) the ordering of data in the stream is arbitrary, and (iv) the model being trained can be evaluated at any point within the data stream. In comparison to the typical, offline method of training neural networks, this setup may seem pretty harsh, which explains why achieving high performance in the streaming setting is often difficult. Nonetheless, streaming learning is reflective of many applications in industry and, if utilized correctly, provides a powerful tool for the deep learning practitioner.
Within this post, I will provide an overview of streaming within deep learning. I will begin by motivating the use of streaming learning, focusing upon the utility of streaming learning in practical applications. I will then overview existing approaches for training deep neural networks in the streaming setting, emphasizing the approaches that are most useful in practice. Through this post, I aim to i) illustrate the relationship between streaming learning and deep learning in general and ii) outline useful findings for leveraging the power of streaming learning in actual applications.
Why streaming learning?
How does streaming relate to other training paradigms? Prior to the proposal of streaming learning [4], numerous training setups were studied that explore different strategies of sequentially exposing partial, disjoint subsets of data to the model instead of training over the entire dataset in an offline fashion. Within this post, I will refer to all such methodologies (including streaming learning) collectively as “online” learning setups. Typically, online learning divides the dataset into several disjoint “batches” of data, and the model is exposed to one batch at a time. Once a batch of data is learned, the model cannot return to it later in the training process (i.e., only “current” data can be directly accessed to train the model).
Many different variants of online learning have been proposed (e.g., lifelong learning, continual learning, batch/class-incremental learning, streaming learning, etc.). Each of these different variants refers to the same concept of online learning described above, usually with a minor change to the experimental setup. For example, lifelong learning tends to learn different tasks in sequence (i.e., each task could be considered a subset/batch of data from the dataset containing all tasks), while class-incremental learning learns a strict subset of classes for an overall classification problem within each batch (e.g., the first 20 classes of CIFAR100). Notably, each batch of data during the online learning process can be quite large. For example, class-incremental learning experiments are often performed on ImageNet, where each batch contains a 100-class subset of the full dataset (i.e., ~130K data examples).
Streaming learning can also be interpreted as a variant of the above description for online learning, where each “batch” of data is just a single example from the dataset. Yet, streaming learning seems to deviate from the usual online learning setup more than other, related methodologies. Namely, streaming learning, because it is restricted to learning the dataset one example at a time, cannot perform arbitrary amounts of offline training each time a new batch of data becomes available. This is because only a single data example is made available to the model at a time, and performing several updates over the same data in sequence can quickly deteriorate model performance. Thus, streaming learning methodologies tend to perform brief, real-time model updates as new data becomes available, whereas other online learning methodologies often perform expensive, offline training procedures to learn each new batch.
Why is online learning difficult? When one first encounters the topic of online learning, they may think that solving this problem is quite easy. Why not just fine-tune the model on new data as it becomes available? Such a naive approach works in some cases. In particular, it will work if the incoming data stream is i.i.d., meaning that each incoming example is sampled independently from the same underlying data distribution. In this case, learning from the data stream is the same as uniformly sampling an example from the dataset for each training iteration (i.e., this is just stochastic gradient descent!), and naive fine-tuning works quite well. However, we cannot always ensure that incoming data is i.i.d. In fact, many notable practical applications are characterized by non-i.i.d. streams of data (e.g., video streams, personalized behavior tracking, object tracking, etc.).
Although online learning is easy to solve when the incoming data is i.i.d., an interesting phenomenon occurs when the data is made non-i.i.d.: the model learns from the new data, but quickly forgets everything that was learned previously. For example, in the case of class-incremental learning, a model may begin learning how to classify horses (i.e., some class it has not encountered before), but completely forget how to classify dogs, cats, squirrels, and all other animals that it had learned to classify in the past. This problem, commonly referred to as catastrophic forgetting [5, 6], is the fundamental issue faced by all online learning techniques. Namely, because data streams are often non-i.i.d., models learned in an online fashion typically suffer from catastrophic forgetting, which significantly impacts their performance (especially on data that was not recently observed).
Streaming learning is more practical. Now that we have a better understanding of streaming learning, online learning in general, and the issues faced by both, one may ask the question: why focus on streaming in particular? The main reasons are that streaming learning i) occurs in real time and ii) better reflects common learning paradigms that arise in practice.
Because learning occurs one example at a time within streaming, model updates tend to be brief (i.e., one or a few forward/backward passes per example). As a result, minimal latency exists between the arrival of a new data example and the adaptation of the underlying model to that data example — the model is updated in real time as new data becomes available. In comparison, other commonly-studied online learning setups may suffer latency when i) waiting for a sufficiently-large batch of new data to accumulate or ii) updating the model after the new batch of data becomes available — many forward/backward passes must be performed to update the model over a large batch of data, especially if several loops over the data are performed. Though such alternative experimental setups for online learning are interesting from a research perspective, why would any practitioner wait for data to accumulate when they have the ability to update the model after the arrival of each new sample?
The ability of streaming learning to adapt models to streams of data in real time also has wide applications in industry. Consider, for example, a recommendation system that performs dynamic updates each time a user interacts with a website (e.g., a purchase, a click, or even a movement of the mouse). Alternatively, one could utilize a deep network to perform video interpolation (e.g., predict where a person will be in the next frame given the previous several frames) and leverage streaming learning to update this model over the video stream based on the mistakes it makes in each frame. The possibilities of streaming learning are nearly endless, as it applies to any situation in which a deep network should immediately learn from incoming data. Thus, it is (in my opinion) a topic worthy of focus for deep learning practitioners.
Methodologies for Deep Streaming Learning
Now that streaming learning has been defined and motivated, it’s time to learn how neural networks can actually be trained over data streams without severely degrading their performance. Recently, several algorithms for training deep neural networks in a streaming fashion have been proposed within the deep learning community [4, 7, 8]. For each algorithm, I will overview the main components and details of the methodology, as well as highlight major practical considerations for implementing it. Throughout this section, I focus on the details of each approach that help practitioners understand which methodologies are most suitable for a given application.
ExStream [4]
ExStream, proposed in February of 2019, is a replay-based methodology for streaming learning. Here, “replay” (also referred to as rehearsal) is used to describe methods that store previously-encountered data from the incoming data stream within a buffer. Then, when new data becomes available, replay-based methodologies mix the new data with samples of data from the replay buffer and use this mix of new and old data to update the model, thus ensuring that previous knowledge is retained. Put simply, these methodologies train the network with a mix of new and old data to ensure that the network is exposed to a balanced set of examples from the data stream. Replay is a widely-used methodology within online learning that is both simple and effective, but it does require storage of previous data, which creates a non-negligible memory footprint.
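To make the replay idea concrete, here is a minimal sketch, assuming a linear softmax classifier trained on fixed feature vectors and an unbounded (“full”) buffer. The class name `ReplayLinearClassifier` and its parameters (`replay_size`, `lr`) are hypothetical and are not taken from any of the papers discussed here. Each incoming example is mixed with a few uniformly sampled old examples, and a single gradient step is taken on that mixed minibatch.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

class ReplayLinearClassifier:
    """Hypothetical linear softmax head updated with replay, one example at a time."""

    def __init__(self, num_features, num_classes, replay_size=8, lr=0.1, seed=0):
        self.W = np.zeros((num_features, num_classes))
        self.b = np.zeros(num_classes)
        self.buffer_x, self.buffer_y = [], []
        self.replay_size = replay_size
        self.lr = lr
        self.rng = np.random.default_rng(seed)

    def sgd_step(self, X, y):
        # One cross-entropy gradient step: dL/dlogits = softmax - one_hot
        grad = softmax(X @ self.W + self.b)
        grad[np.arange(len(y)), y] -= 1.0
        self.W -= self.lr * X.T @ grad / len(y)
        self.b -= self.lr * grad.mean(axis=0)

    def learn_one(self, x, y):
        X, Y = [x], [y]
        if self.buffer_x:
            # Mix the new example with uniformly sampled examples from the buffer
            k = min(self.replay_size, len(self.buffer_x))
            idx = self.rng.choice(len(self.buffer_x), size=k, replace=False)
            X += [self.buffer_x[i] for i in idx]
            Y += [self.buffer_y[i] for i in idx]
        self.sgd_step(np.array(X), np.array(Y))
        # "Full" replay: every example is kept, so memory grows with the stream
        self.buffer_x.append(x)
        self.buffer_y.append(y)

    def predict(self, X):
        return np.argmax(X @ self.W + self.b, axis=1)
```

Feeding this model a class-ordered stream (all examples of one class, then the next) is exactly the non-i.i.d. setting described earlier; the replayed samples are what keep older classes in every update.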
Although ExStream demonstrates that “full” replay (i.e., storing all incoming data in the replay buffer and looping through all examples in the replay buffer each time new data is encountered) eliminates catastrophic forgetting, more memory-efficient replay mechanisms that don’t require storage of the entire data stream are also explored. Using a ResNet50 [9] to perform image classification, ExStream pre-trains and fixes the convolutional layers of the underlying model (i.e., none of the convolutional layer parameters are updated during the streaming process) and focuses upon learning the final, fully-connected layer in a streaming fashion. Thus, all examples stored within the replay buffer are simply vectors (i.e., the output of the ResNet’s feature extractor/backbone). Using these feature vectors as input, ExStream maintains a separate, fixed-size replay buffer for each class of data (i.e., only c vectors can be stored per class) and aims to discover an algorithm that i) minimizes the number of vectors that must be stored for replay (i.e., this limits memory overhead) and ii) maintains state-of-the-art classification performance (i.e., comparable to full replay).
To achieve these goals, ExStream leverages the following rule set for maintaining its replay buffer during streaming:
Maintain c cluster centroids (i.e., just vectors!) per class, each of which has a “count” associated with it.
Until c examples for a class have been encountered within the data stream, simply add each vector to the replay buffer with a count of one.
Once the buffer for a given class is full and a new example for that class arrives, find the two closest cluster centroids (based on Euclidean distance) and merge them by taking a weighted average of the vectors based on their respective counts. Set the count of the resulting centroid to the sum of the counts of the two merged centroids.
Once the two centroids have been merged (thus making room for a new one), add the new data example to the buffer with a count of one.
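The merge rule above can be sketched as follows. This is an illustrative reading of the rules as described in this post, not the reference ExStream code; `ExStreamBuffer` and `capacity_per_class` are hypothetical names, and the closest-pair search is the naive version that compares all pairs of centroids.

```python
import numpy as np

class ExStreamBuffer:
    """Hypothetical per-class buffer of count-weighted cluster centroids."""

    def __init__(self, capacity_per_class):
        assert capacity_per_class >= 2, "merging needs at least two centroids"
        self.c = capacity_per_class
        self.centroids = {}  # class label -> list of centroid vectors
        self.counts = {}     # class label -> list of counts (same order)

    def add(self, x, label):
        cents = self.centroids.setdefault(label, [])
        counts = self.counts.setdefault(label, [])
        if len(cents) >= self.c:
            # Find the two closest centroids by Euclidean distance
            C = np.stack(cents)
            dists = np.linalg.norm(C[:, None, :] - C[None, :, :], axis=-1)
            np.fill_diagonal(dists, np.inf)  # ignore distance of a centroid to itself
            i, j = np.unravel_index(np.argmin(dists), dists.shape)
            i, j = min(i, j), max(i, j)
            # Count-weighted average; the merged count is the sum of both counts
            n_i, n_j = counts[i], counts[j]
            cents[i] = (n_i * cents[i] + n_j * cents[j]) / (n_i + n_j)
            counts[i] = n_i + n_j
            del cents[j], counts[j]
        # The new example always enters the buffer with a count of one
        cents.append(np.asarray(x, dtype=float))
        counts.append(1)
```

Because a full buffer merges two centroids into one before appending the new example, each class never holds more than c centroids, and the counts for a class always sum to the number of examples seen for it.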
Using the replay buffer that’s maintained as described above, ExStream then performs replay by sampling and mixing cluster centroids from different classes with incoming data to perform updates on model parameters. In comparison to several other algorithms for maintaining cluster centroids within the replay buffer (e.g., online k-means, CluStream [10], HPStream [11], etc.), ExStream is shown to yield the best performance. Further, the memory footprint of the algorithm can be easily tuned by adjusting the number of centroids permitted for each class within the replay buffer (although a replay buffer that is too small could yield poor performance). ExStream is shown to perform well on the CORe50 and iCub datasets, but it was not applied to large-scale classification problems (e.g., ImageNet) until later work was published.
Deep SLDA [7]
Deep Streaming Linear Discriminant Analysis (SLDA), proposed in April of 2020, is another methodology for deep streaming learning that moves away from replay-based methodologies (i.e., it does not maintain any replay buffer). As a result, it is quite memory efficient in comparison to methods like ExStream that require a replay buffer, thus (potentially) making it appropriate for memory-constrained learning scenarios (e.g., on-device learning). SLDA is an already-established algorithm [12] that has been used for classification of data streams within the data mining community. Within Deep SLDA, the SLDA algorithm is combined with deep neural networks by i) using a fixed ResNet18 [9] backbone to obtain a feature vector for each data example and ii) employing SLDA to incrementally classify such feature vectors in a streaming fashion. Again, all network layers are fixed during the streaming process for Deep SLDA aside from the final classification module (i.e., the SLDA component).
The specifics of classifying feature vectors with SLDA in an incremental fashion are beyond the scope of this post; the algorithm is complex and probably deserves an entire blog post of its own to truly develop an understanding. At a high level, however, SLDA operates by i) maintaining a single mean vector per class with an associated count and ii) constructing a single covariance matrix over the feature dimensions that is shared across all classes. The mean vector is updated for each class as new data becomes available, while the covariance matrix can either be kept fixed (after some base initialization over a subset of training data) or updated incrementally during the streaming process. At test time, the class output for a new data example can be inferred in closed form: the inverse of the covariance matrix is multiplied with the class mean vectors to form the weights (and biases) of a linear classifier, which is then applied to the example’s feature vector.
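Below is a minimal NumPy sketch of a streaming LDA head, assuming fixed feature vectors as input. It keeps a running mean and count per class, accumulates the pooled within-class scatter with a Welford-style update (so the covariance is updated incrementally rather than kept fixed), and shrinks the covariance toward the identity so it can be inverted. The class name, the `shrinkage` parameter, and this particular covariance update are my assumptions and may differ from the exact formulation in [7] and [12].

```python
import numpy as np

class StreamingLDA:
    """Hypothetical streaming LDA classifier over fixed feature vectors."""

    def __init__(self, num_features, num_classes, shrinkage=1e-4):
        self.means = np.zeros((num_classes, num_features))     # one mean per class
        self.counts = np.zeros(num_classes, dtype=np.int64)    # examples seen per class
        self.scatter = np.zeros((num_features, num_features))  # pooled within-class scatter
        self.shrinkage = shrinkage

    def fit_one(self, x, y):
        # Welford-style update using the deviation from the class mean *before* updating it
        n_k = self.counts[y]
        diff = x - self.means[y]
        self.scatter += (n_k / (n_k + 1)) * np.outer(diff, diff)
        self.means[y] += diff / (n_k + 1)  # running mean
        self.counts[y] += 1

    def predict(self, X):
        # Shared covariance estimate, shrunk toward the identity so it is invertible
        sigma = self.scatter / max(int(self.counts.sum()), 1)
        sigma = (1 - self.shrinkage) * sigma + self.shrinkage * np.eye(len(sigma))
        precision = np.linalg.inv(sigma)
        W = precision @ self.means.T                 # (num_features, num_classes)
        b = -0.5 * np.sum(self.means.T * W, axis=0)  # -0.5 * mu_k^T Sigma^-1 mu_k
        return np.argmax(X @ W + b, axis=1)
```

Each streaming update touches only one class mean and one d×d matrix, and prediction is a single matrix multiplication once the inverse is computed; in practice that inverse would be cached between predictions rather than recomputed on every call.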
SLDA has benefits over replay-based methodologies like ExStream because memory requirements are reduced significantly: only a single vector per class and a shared covariance matrix must be stored. Furthermore, SLDA is shown to yield impressive classification performance even on large-scale datasets like ImageNet, outperforming popular methods like ExStream [4], iCaRL [13], and End-to-End Incremental Learning [14]; see Table 1 of [8]. Additionally, in comparison to normal, offline neural network training over large-scale datasets, the wall-clock training time of Deep SLDA is nearly negligible. Overall, the method is surprisingly effective at scale given its minimal computation and memory requirements.
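A quick back-of-the-envelope calculation shows how small this state is. The sketch assumes ImageNet’s 1,000 classes, the 512-dimensional penultimate features of a ResNet18, and float32 storage, and it ignores the 1,000 per-class counts, which are negligible.

```python
# Rough storage estimate for Deep SLDA on ImageNet (assumes float32 storage)
feature_dim = 512   # ResNet18 penultimate feature size
num_classes = 1000  # ImageNet-1k classes
bytes_per_float = 4

mean_floats = num_classes * feature_dim  # one mean vector per class
cov_floats = feature_dim * feature_dim   # one covariance matrix shared by all classes
total_bytes = (mean_floats + cov_floats) * bytes_per_float

print(mean_floats, cov_floats, total_bytes / 2**20)  # prints: 512000 262144 2.953125
```

Under these assumptions the entire classifier state fits in roughly 3 MiB, independent of how many examples have been streamed.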
REMIND [8]
REMIND, published in July of 2020, is a recently proposed, replay-based methodology for deep streaming learning. Instead of maintaining cluster centroids within the replay buffer like ExStream, REMIND stores separate buffer entries for each data example that is encountered within the incoming stream. In prior work, this is typically done by storing raw images within the replay buffer [13, 14]. Within REMIND, however, the authors propose that intermediate activations within the neural network (i.e., these activations are not just a vector; they may have spatial dimensions) should be stored within the replay buffer instead of raw images, which i) drastically reduces per-sample memory requirements and ii) mimics the replay of compressed memories within the brain as outlined by hippocampal indexing theory.
To make the above strategy possible, REMIND adopts a ResNet18 [9] architecture and freezes the initial layers of the network so that the parameters of these layers do not change during the streaming process. Similar to Deep SLDA, the values of these frozen parameters are set using some base initialization phase over a subset of training data. Then, for each example encountered during streaming, REMIND i) passes the example through the network’s frozen layers to extract an intermediate activation, ii) compresses the activation tensor using a product quantization (PQ) strategy [15], and iii) stores the resulting quantized codes within the replay buffer. Online updates are then performed on the final network layers (i.e., those that are not frozen), using as input a combination of new data (after it has been quantized) and activations sampled from the replay buffer. In practice, REMIND performs random crops and mixup on activations sampled from the replay buffer, which provides regularization benefits and yields moderate performance improvements.
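Product quantization itself can be sketched in a few lines of NumPy. This is a minimal illustration, not REMIND’s implementation: each vector is split into `num_subvectors` equal chunks, a small codebook is learned for each chunk with plain k-means (Lloyd iterations from a random initialization), and each chunk is then stored as a single `uint8` index into its codebook. The names `ProductQuantizer` and `kmeans` are hypothetical.

```python
import numpy as np

def kmeans(X, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns a (k, dim) array of centers."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(iters):
        # Assign each point to its nearest center
        assign = ((X[:, None, :] - centers[None]) ** 2).sum(-1).argmin(axis=1)
        for j in range(k):
            members = X[assign == j]
            if len(members):  # keep the old center if its cluster is empty
                centers[j] = members.mean(axis=0)
    return centers

class ProductQuantizer:
    def __init__(self, num_subvectors, num_centroids=256):
        assert num_centroids <= 256, "codes are stored as uint8"
        self.m = num_subvectors
        self.k = num_centroids
        self.codebooks = []

    def fit(self, X):
        # One independent codebook per sub-space of the feature dimension
        assert X.shape[1] % self.m == 0, "dim must be divisible by num_subvectors"
        subs = np.split(X, self.m, axis=1)
        self.codebooks = [kmeans(s, self.k, seed=i) for i, s in enumerate(subs)]
        return self

    def encode(self, X):
        # Replace each sub-vector with the index of its nearest codeword (1 byte each)
        subs = np.split(X, self.m, axis=1)
        codes = [((s[:, None, :] - cb[None]) ** 2).sum(-1).argmin(axis=1)
                 for s, cb in zip(subs, self.codebooks)]
        return np.stack(codes, axis=1).astype(np.uint8)

    def decode(self, codes):
        # Look up and concatenate codewords to approximately reconstruct the vectors
        return np.concatenate([cb[codes[:, i]] for i, cb in enumerate(self.codebooks)], axis=1)
```

For a REMIND-style buffer, an activation tensor of shape (N, H, W, C) would first be reshaped to (N·H·W, C) so that each spatial location is quantized independently; stored codes are decoded back into approximate activations whenever an entry is sampled for replay.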
Due to REMIND’s memory-efficient approach to replay, it can maintain incredibly large replay buffers with limited overhead. For example, on the ImageNet dataset, REMIND can maintain a buffer of ~1M examples with the same memory footprint as a replay buffer of 10K raw images. As a result, REMIND significantly outperforms both ExStream and Deep SLDA on numerous common benchmarks for online learning; see Table 1 in [8]. The performance benefits of REMIND are especially noticeable on large-scale datasets (e.g., ImageNet), where the ability to store many replay examples with limited memory allows REMIND to truly differentiate itself. At the time of writing, REMIND is the best-performing approach for training deep networks in the streaming domain.
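The ~1M versus 10K comparison can be sanity-checked with simple arithmetic. The sizes below are assumptions for illustration: 224×224 RGB images stored as 8-bit values, and a 7×7×512 intermediate tensor per example compressed with 32 one-byte PQ codes per spatial location. Consult [8] for the exact configuration.

```python
# Back-of-the-envelope check of the replay-buffer claim (all sizes are assumptions)
raw_image_bytes = 224 * 224 * 3        # one 224x224 RGB image stored as uint8
raw_budget = 10_000 * raw_image_bytes  # memory used by 10K raw images

pq_entry_bytes = 7 * 7 * 32            # 7x7 grid, 32 one-byte PQ codes per location
float_entry_bytes = 7 * 7 * 512 * 4    # same 7x7x512 tensor kept as float32

print(raw_budget)                       # prints: 1505280000 (about 1.5 GB)
print(raw_budget // pq_entry_bytes)     # prints: 960000 (quantized entries)
print(raw_budget // float_entry_bytes)  # prints: 15000 (uncompressed entries)
```

Under these assumptions, about 960K quantized entries fit in the memory of 10K raw images, while the same tensors stored uncompressed in float32 would fit only 15K. The compression, not merely the choice of intermediate activations, is what makes the large buffer possible.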
Connections to general online learning
Now that the existing approaches to deep streaming learning have been outlined, one may begin to ask how these approaches relate to methodologies that have been proposed for other online learning setups (e.g., batch-incremental learning or lifelong learning). For a more comprehensive description of all methodologies that have been proposed for online learning, I recommend my previous post on this topic. However, I try to overview these methodologies at a high level below.
Replay: Widely-used in both streaming learning and online learning in general.
Knowledge Distillation: Widely utilized within online learning techniques, but not yet explored for streaming learning. Although knowledge distillation may provide some benefit to streaming learning techniques, several recent papers argue that knowledge distillation provides minimal benefit when combined with replay [16, 17].
Bias Correction: Not tested within the streaming setting, but very beneficial for incremental learning.
Several other methodologies for online learning exist that have not been explored within the streaming setting (e.g., architectural modifications or regularization-based approaches). However, many of these methodologies are less popular within current online learning research, as they do not perform well when applied to larger-scale problems. As such, it is unlikely that these approaches would outperform proven methodologies for large-scale streaming learning, such as REMIND.
What’s best to use in practice?
Though all of the methodologies for deep streaming learning overviewed within this post are useful, a practitioner may wonder which method is most appropriate to implement for their application. In terms of performance, REMIND is the best-performing methodology that has been proposed for deep streaming learning to date. This can be seen within Table 1 of the paper for REMIND [8], where REMIND is shown to significantly outperform both ExStream and Deep SLDA. Additionally, REMIND has even been extended to problem domains such as object detection [16], thus showcasing its utility in applications beyond image classification.
Both ExStream and REMIND require the storage of a replay buffer and have similar computational efficiency, thus making REMIND the obvious choice between the two. However, Deep SLDA does not require such a replay buffer to be maintained and can be trained very quickly. As such, even though REMIND achieves better performance, Deep SLDA may be favorable in scenarios with limited memory or computational resources. Otherwise, REMIND is the best option for deep streaming learning in practice, as it achieves impressive performance and minimizes the memory footprint of the replay buffer through quantization.
For those who are interested, implementations of REMIND [17] and Deep SLDA [18] are both publicly available on GitHub.
Conclusions
In this post, I overviewed the training of deep neural networks over data streams, including a discussion of why such a training paradigm is practically relevant and a description of existing methodologies for training deep networks in this fashion. Of the relevant methodologies for deep streaming learning, REMIND achieves the best performance, while approaches such as Deep SLDA may be useful in cases with limited memory or computational resources.
Thank you so much for reading this post. If you have any feedback or generally enjoyed the post and want to keep up with my future work, feel free to follow me on twitter or visit my website. This work was done as part of my job as a Research Scientist at Alegion and a PhD student at Rice University. If you enjoyed the material in this post, I encourage you to check out open positions at Alegion, or get in contact with my research lab!
Citations
[1] https://arxiv.org/abs/1412.2485
[2] https://ieeexplore.ieee.org/abstract/document/8622392
[3] https://epubs.siam.org/doi/abs/10.1137/1.9781611974317.7
[4] https://arxiv.org/abs/1809.05922
[5] https://www.sciencedirect.com/science/article/abs/pii/S0079742108605368
[6] https://arxiv.org/abs/1708.02072
[7] https://arxiv.org/abs/1909.01520
[8] https://arxiv.org/abs/1910.02509
[9] https://arxiv.org/abs/1512.03385
[10] https://www.vldb.org/conf/2003/papers/S04P02.pdf
[12] https://ieeexplore.ieee.org/document/1510767
[13] https://arxiv.org/abs/1611.07725
[14] https://arxiv.org/abs/1807.09536
[15] https://lear.inrialpes.fr/pubs/2011/JDS11/jegou_searching_with_quantization.pdf
[16] https://arxiv.org/abs/2008.06439