
BITESIZE | Why is Bayesian Deep Learning so Powerful?

Episode Transcript

[00:00:03] So I'm curious if you can define what a Gaussian process is, because I think my audience has a good idea of what a Bayesian neural network is. Especially recently, I had Vincent Fortuin talk about that on the show; I'll put that episode in the show notes too. So Bayesian deep learning, I think people are familiar with. Can you tell us what a deep Gaussian process is? I think people see what a Gaussian process is, but what makes it a deep one?

[00:00:40] Great episode, the one with Vincent, by the way. I checked it out. Thank you. I guess he said a lot of things that I would probably also say in my episode, so it was great to see it.

[00:00:52] So yeah, there are many ways you can see a Gaussian process. The easiest is probably to start from a linear model; I really like the construction from a linear model. If we start from a linear model and make it Bayesian, so we put a prior on the parameters, then we have analytical forms for the posterior and the predictions. Everything is nice and Gaussian. Now, one nice thing we can do is start thinking about linear regression with basis functions. So we start introducing linear combinations not just of the covariates, or features if you want to call them that.
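The Bayesian linear model with basis functions described here can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the episode: the polynomial basis, the prior variance, the noise variance, and the synthetic data are all assumptions chosen for the example. With a Gaussian prior on the weights and Gaussian noise, the posterior and the predictive distribution come out in closed form.

```python
import numpy as np

def poly_features(x, degree):
    """Assumed basis functions phi(x) = [1, x, x^2, ..., x^degree]."""
    return np.vander(x, degree + 1, increasing=True)

def blr_posterior(x, y, degree=3, prior_var=1.0, noise_var=0.01):
    """Closed-form Gaussian posterior over weights for y = phi(x) @ w + noise,
    with prior w ~ N(0, prior_var * I) and noise ~ N(0, noise_var)."""
    Phi = poly_features(x, degree)
    precision = Phi.T @ Phi / noise_var + np.eye(degree + 1) / prior_var
    cov = np.linalg.inv(precision)
    mean = cov @ Phi.T @ y / noise_var
    return mean, cov

def blr_predict(x_new, mean, cov, degree=3, noise_var=0.01):
    """Gaussian predictive mean and variance at new inputs."""
    Phi_new = poly_features(x_new, degree)
    pred_mean = Phi_new @ mean
    # diag(Phi_new @ cov @ Phi_new.T) without forming the full matrix, plus noise
    pred_var = np.sum((Phi_new @ cov) * Phi_new, axis=1) + noise_var
    return pred_mean, pred_var

rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 30)
y = np.sin(3.0 * x) + 0.1 * rng.standard_normal(x.shape)  # synthetic data
mean, cov = blr_posterior(x, y)
pred_mean, pred_var = blr_predict(np.array([0.0, 0.5]), mean, cov)
print(pred_mean, pred_var)
```

The model stays linear in `w` whatever basis is chosen, which is why everything remains Gaussian.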
[00:01:29] Instead, you use a transformation of them: sines and cosines, trigonometric functions of any kind, or polynomials. It turns out that you can use the kernel trick to work out what the predictive distribution is going to be. The model is still linear in the parameters, but now we can take the number of basis functions to infinity. So we can make an infinitely large polynomial, and the number of parameters becomes infinite. What we can do then is use the so-called kernel trick to express everything in terms of scalar products among these mappings of the inputs into the polynomial features.

[00:02:14] If you do that, then instead of working with polynomials or basis functions, you can define a so-called kernel function, which takes the input features and spits out the scalar product of these induced polynomials in this very large, infinite-dimensional space. So the kernel trick lets you work with something that is, in a way, infinitely powerful, because it is infinitely flexible: you now have an infinite number of parameters. But the great thing is that if you have only n observations, all you need to care about is what happens at those n observations. So you can construct this n-by-n covariance matrix, and everything is Gaussian again. It's very nice.
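To make the kernel trick concrete, the sketch below (an illustration under assumed data, kernels, and noise level, not code from the episode) computes the same predictions twice: once in weight space, with explicit features `phi`, and once in function space, using only the n-by-n kernel matrix. For the degree-2 polynomial kernel the two agree exactly; the RBF kernel at the end has no finite feature map, so only the kernel route is available.

```python
import numpy as np

def phi(x):
    """Explicit features for scalar x, chosen so that phi(x) . phi(z) = (x z + 1)^2."""
    return np.stack([np.ones_like(x), np.sqrt(2.0) * x, x**2], axis=1)

def poly_kernel(a, b):
    """The same inner products, computed without building the features."""
    return (np.outer(a, b) + 1.0) ** 2

def rbf_kernel(a, b, lengthscale=0.5):
    """RBF kernel: its implicit feature space is infinite-dimensional."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def gp_predict(kernel, x, y, x_new, noise_var=0.01):
    """Function-space prediction: only the n x n matrix K(x, x) is needed."""
    K = kernel(x, x) + noise_var * np.eye(len(x))
    K_star = kernel(x_new, x)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    v = np.linalg.solve(L, K_star.T)
    mean = K_star @ alpha
    var = np.diag(kernel(x_new, x_new)) - np.sum(v**2, axis=0)  # latent-function variance
    return mean, var

rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, 20)
y = x**2 - 0.5 * x + 0.1 * rng.standard_normal(20)
x_new = np.linspace(-1.0, 1.0, 5)

# Weight space: Bayesian linear regression on phi(x), prior w ~ N(0, I), noise variance 0.01.
Phi, Phi_new = phi(x), phi(x_new)
cov_w = np.linalg.inv(Phi.T @ Phi / 0.01 + np.eye(3))
ws_mean = Phi_new @ (cov_w @ Phi.T @ y / 0.01)
ws_var = np.sum((Phi_new @ cov_w) * Phi_new, axis=1)

# Function space: same answer through the kernel alone.
fs_mean, fs_var = gp_predict(poly_kernel, x, y, x_new)
print(np.allclose(ws_mean, fs_mean), np.allclose(ws_var, fs_var))

rbf_mean, rbf_var = gp_predict(rbf_kernel, x, y, x_new)  # no finite phi exists for this kernel
```

Both comparisons should print `True` up to floating-point tolerance; the cost of the kernel route depends on n, not on the number of features.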
[00:02:59] The first time you generate a function from a Gaussian process, it's beautiful, because you get these nice-looking functions, and really it's just a multivariate normal. That's all it is. I still remember the first time I generated a function from a GP, because it was a eureka moment, where you realize how simple and beautiful this is.

[00:03:21] So then you can think of it this way: this represents a distribution over functions. If you draw from this GP, you obtain samples that are functions. And now what you can do is say: what if I take this function and, instead of just observing it on its own, put it in as the input to another Gaussian process? In a GP, the inputs are your input data, where you have observations, and you're mapping them into functions. That function can then become the input to another GP. You can even say: let's take these inputs and map them not just to a univariate Gaussian process, where we have one function, but into 10 functions, and then these 10 functions become the input to a new Gaussian process. That would be a one-layer deep Gaussian process: you have one layer of hidden functions that then enter as input to another Gaussian process.
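Both ideas here, a GP draw being "just a multivariate normal" and feeding 10 sampled functions into a second GP, fit in a short prior-sampling sketch. The RBF kernel, unit lengthscale, jitter value, and grid are assumptions for illustration; this samples from the prior only and does no inference.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0, variance=1.0):
    """RBF covariance between the rows of a (n, d) and b (m, d)."""
    sq_dist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)

def sample_gp(inputs, n_functions, rng, lengthscale=1.0, jitter=1e-6):
    """Draw n_functions independent GP functions at the rows of inputs (n, d).
    At finitely many points a GP draw is just a multivariate normal sample."""
    K = rbf(inputs, inputs, lengthscale) + jitter * np.eye(len(inputs))
    L = np.linalg.cholesky(K)
    return L @ rng.standard_normal((len(inputs), n_functions))  # shape (n, n_functions)

def sample_deep_gp(x, widths, rng):
    """Compose GPs: the functions drawn at one layer are the inputs of the next."""
    h = x[:, None]
    for width in widths:
        h = sample_gp(h, width, rng)
    return h

rng = np.random.default_rng(2)
x = np.linspace(-3.0, 3.0, 100)
shallow = sample_gp(x[:, None], 3, rng)            # three draws from a single GP
deep = sample_deep_gp(x, widths=[10, 1], rng=rng)  # 1 input -> 10 hidden functions -> 1 output
```

The small `jitter` on the diagonal only keeps the Cholesky factorization numerically stable; `widths=[10, 1]` mirrors the "10 hidden functions into a new GP" example.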
[00:04:27] What's the advantage of this? Why do we do it? Well, with a Gaussian process, the characteristics of the functions you generate are determined by the choice of the covariance function. If you take an RBF covariance function, you generate infinitely smooth functions, and their length scale and amplitude are determined by the parameters you put in the covariance function. And of course, there might be problems with non-stationarity: in one part of the space, functions should be nice and smooth, while in other parts of the space you may want more flexibility. A Gaussian process with a standard covariance function cannot achieve that.

[00:05:24] So, to increase flexibility, you can spend time designing kernels that can do crazy things, which is possible but relatively hard, because now you have a lot of choices. You can combine kernels in multiple ways, and if you have a space of possible kernels to choose from, combining them becomes a combinatorial problem. So you may say instead: let's just compose functions. Composition is very powerful.
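The role of the covariance parameters, and the fact that kernels can be combined, can be checked numerically. In this sketch (the kernels, parameter values, and grid are assumptions for illustration) the expected squared change of a GP draw over a small step is computed directly from the kernel: a shorter lengthscale means wigglier functions, and the amplitude scales everything. Sums and products of valid kernels still give positive semi-definite Gram matrices. Note that every kernel here depends only on the distance `r`, so it behaves the same everywhere in the input space, which is the stationarity limitation mentioned above.

```python
import numpy as np

def rbf(r, lengthscale=1.0, amplitude=1.0):
    """RBF kernel written as a function of the distance r = |x - x'|."""
    return amplitude**2 * np.exp(-0.5 * r**2 / lengthscale**2)

def periodic(r, period=1.0, lengthscale=1.0, amplitude=1.0):
    """Standard periodic kernel, also a function of r only."""
    return amplitude**2 * np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / lengthscale**2)

def expected_sq_increment(h, lengthscale, amplitude=1.0):
    """E[(f(x + h) - f(x))^2] = 2 (k(0) - k(h)) for a zero-mean stationary GP."""
    return 2.0 * (rbf(0.0, lengthscale, amplitude) - rbf(h, lengthscale, amplitude))

for ls in [0.1, 1.0, 10.0]:
    print(ls, expected_sq_increment(0.1, ls))  # shorter lengthscale -> larger local changes

# Sums and products of valid kernels are valid kernels (PSD Gram matrices).
x = np.linspace(0.0, 5.0, 50)
r = np.abs(x[:, None] - x[None, :])
K_sum = rbf(r) + periodic(r)
K_prod = rbf(r, lengthscale=3.0) * periodic(r)  # a "locally periodic" combination
min_eigs = [np.linalg.eigvalsh(K).min() for K in (K_sum, K_prod)]
print(min_eigs)  # non-negative up to rounding error
```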
[00:05:51] And this is why deep learning works, because in deep learning you essentially have function compositions. Even if you compose simple things, the result is something very complicated, and you can try it yourself: take a sine function and put it into another sine function. If you play around with the parameters, you can get things that oscillate in a crazy way. It's very simple, but very powerful. And the idea of deep Gaussian processes is exactly this: to enrich the class of functions you can obtain by composing functions, that is, by composing Gaussian processes. And of course, in a Gaussian process all the marginals are nice and Gaussian; if you compose, these marginals become non-Gaussian. That really gets you to the point where you start thinking: why should we restrict ourselves to composing processes that are Gaussian? Maybe we can do something else, and think about other ways to be flexible in how we parametrize these complicated conditional distributions.

[00:06:58] Okay. Yeah. Damn, this is super fun. So it sounds to me like Fourier decomposition on steroids, basically: decomposing everything through these basis functions and plugging them into each other.
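The claim that composition makes the marginals non-Gaussian can be checked by simulation, with one caveat: with a stationary kernel such as the RBF, the value of g(f(x)) at a single input is still Gaussian, because g has the same variance at every input. The non-Gaussianity shows up jointly, so this sketch looks at the difference between two inputs, which would be exactly Gaussian if the pair were jointly Gaussian. The two-layer RBF setup, unit lengthscales, and the inputs 0 and 1 are assumptions for illustration.

```python
import numpy as np

def rbf_corr(d, lengthscale=1.0):
    """Correlation of a unit-variance RBF GP at two inputs a distance d apart."""
    return np.exp(-0.5 * d**2 / lengthscale**2)

def increments(n, rng, deep, x1=0.0, x2=1.0):
    """n samples of f(x1) - f(x2) for one GP, or g(f(x1)) - g(f(x2)) for two composed GPs."""
    # Layer 1: (f(x1), f(x2)) is bivariate normal with correlation rho.
    rho = rbf_corr(x1 - x2)
    z1, z2 = rng.standard_normal((2, n))
    f1, f2 = z1, rho * z1 + np.sqrt(1.0 - rho**2) * z2
    if not deep:
        return f1 - f2
    # Layer 2: given f, (g(f1), g(f2)) is bivariate normal with correlation k(f1 - f2).
    # That correlation is itself random, so the result is a scale mixture of Gaussians.
    rho2 = rbf_corr(f1 - f2)
    w1, w2 = rng.standard_normal((2, n))
    g1, g2 = w1, rho2 * w1 + np.sqrt(1.0 - rho2**2) * w2
    return g1 - g2

def excess_kurtosis(v):
    """About 0 for a Gaussian; positive for heavier-than-Gaussian tails."""
    v = v - v.mean()
    return np.mean(v**4) / np.mean(v**2) ** 2 - 3.0

rng = np.random.default_rng(3)
k_shallow = excess_kurtosis(increments(100_000, rng, deep=False))
k_deep = excess_kurtosis(increments(100_000, rng, deep=True))
print(k_shallow, k_deep)  # near 0 for one GP, clearly positive for the composition
```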
[00:07:14] Like matryoshka dolls of Gaussian processes, basically. So yeah, I can definitely see the power of that. It's like having very deep neural networks, basically, so I definitely see the connection and why that would be super helpful. I'm guessing that helps uncover very complex nonlinear patterns that are very hard to express in a functional form, where the functional form would be, well, the kernels you have to choose. And sometimes, as you were saying, the out-of-the-box kernels can't express the complexity you have in the data, so having the machine basically discover the kernels by itself is much easier.

[00:08:06] And it's really also about the marginals. If you believe that your marginals can be Gaussian and you're happy with that, then it's all fine. You can do kernel design: you can spend a bit of time trying to find a good kernel that gives you a good fit to the data, good modeling, good uncertainties. But there's still this constraint that you're working with a Gaussian process: in the end, marginally, everything is Gaussian. You may not want that in certain applications.
[00:08:34] For example, the distributions may be very skewed, and maybe the skewness is also input-dependent. Again, you can encode this non-stationarity in certain kernels, but it's just so much easier to compose, at least from the principle of mathematical composition. Then, of course, how to handle this computationally is another matter.

[00:08:58] Yeah, yeah, yeah. No, exactly. You're trading something that's more comfortable for the user for something that's much harder for the computer to compute. But in the end, that can also be more transferable, because unless you're a deep expert in Gaussian processes, coming up with your own kernels each time you work on a project is very time-consuming. So it can actually be worth your time to turn to the deep Gaussian process framework, throw computing power at it, and go your merry way working on something else while the computer samples.

[00:09:42] That definitely makes sense. But again, the deep aspect carries other design choices. Now you have to choose how many layers to use and the dimensionality of each layer. And then there's the other problem of what kind of inference you choose, which definitely has an effect.
[00:10:03] We've done some studies on this, trying to compare various approaches. We did this a few years ago now; I think we started working on it right after TensorFlow came out, so this was 2016. We did our deep GP with a certain kind of approximation that is not very popular. I think the community seems to have agreed that inducing-point methods are very powerful for doing approximations. I've also done some work on that with some great people, particularly James Hensman, who developed GPflow with some other great people.

[00:10:43] But random features are what you were pointing at before, when you mentioned the Fourier transform on steroids. The idea is really that, for certain classes of kernels, you can do some sort of expansion and linearize the Gaussian process. Before, I was talking about going from a linear model to something with an infinite number of basis functions; now the idea is just to truncate the number of basis functions. You can do it in various ways; there is a randomized version, which is what we do with random features, and then you truncate. So now, instead of working with the Gaussian process directly, you turn it into a linear model with a large number of basis functions, and linear models are nice to work with.
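The "randomized truncation" described here can be sketched with random Fourier features for the RBF kernel, a standard construction (Rahimi and Recht) that pairs a cosine and a sine for each sampled frequency. This is an illustration under assumed settings (unit lengthscale and variance, 1-D inputs), not the speakers' implementation: the frequencies are drawn from the RBF kernel's spectral density, and the inner products of the resulting finite features approach the exact kernel as the number of features grows. Other kernels, such as Matérn, would need frequencies from a different spectral density.

```python
import numpy as np

def random_fourier_features(x, n_features, rng, lengthscale=1.0, variance=1.0):
    """Map inputs x (n, d) to 2 * n_features features whose inner products
    approximate the RBF kernel (random Fourier features with cos/sin pairs)."""
    # Frequencies from the RBF kernel's spectral density, N(0, I / lengthscale^2).
    omega = rng.standard_normal((x.shape[1], n_features)) / lengthscale
    proj = x @ omega
    return np.sqrt(variance / n_features) * np.hstack([np.cos(proj), np.sin(proj)])

def rbf(a, b, lengthscale=1.0, variance=1.0):
    """Exact RBF kernel, for comparison."""
    sq_dist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)

rng = np.random.default_rng(4)
x = rng.uniform(-2.0, 2.0, size=(50, 1))
K_exact = rbf(x, x)
errors = {}
for n_features in [10, 100, 5000]:
    Phi = random_fourier_features(x, n_features, rng)
    errors[n_features] = np.abs(Phi @ Phi.T - K_exact).max()
    print(n_features, errors[n_features])  # error shrinks roughly like 1 / sqrt(n_features)
```

Bayesian linear regression on `Phi` is then an approximate GP regression whose cost grows linearly in the number of data points rather than cubically.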
[00:11:28] Then, if you compose them, that's when you get the deep Gaussian process. Essentially you get a deep neural network with some stochasticity in the layers, and that's all there is to it. When we did this, we implemented it in TensorFlow because it was the new thing, and it was very scalable. We compared against some competitors, and we were really fast at converging to good solutions and getting good results. So we have an implementation out there in TensorFlow, unfortunately. I think we should now maybe port it to PyTorch, which is what we work with more these days.

[00:12:10] No, for sure. I'll definitely link to that TensorFlow implementation you have, because I'm very big on pointing people toward how they can apply this in practice: building the bridge between frontier research like yours and helping people implement it in their own modeling workflows and problems. So let's definitely do that. And I was actually going to ask you... Okay, so that's a great explanation, and thank you so much for laying it out so clearly.
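Stacking random-feature layers gives the "deep neural network with some stochasticity in the layers" described above. The sketch below only samples from the prior, under assumed settings (cos/sin RBF features, standard normal output weights, fresh frequencies per draw, 100 features per layer). It is not the speakers' TensorFlow implementation; turning it into a trained model would require inference over the weights (for example, variational), which is not shown.

```python
import numpy as np

def rff_gp_layer(h, out_dim, rng, n_features=100, lengthscale=1.0):
    """One approximate GP layer: random Fourier features of the input, then a
    linear map with Gaussian weights w ~ N(0, I), i.e. a truncated basis expansion."""
    omega = rng.standard_normal((h.shape[1], n_features)) / lengthscale
    proj = h @ omega
    phi = np.hstack([np.cos(proj), np.sin(proj)]) / np.sqrt(n_features)
    weights = rng.standard_normal((2 * n_features, out_dim))
    return phi @ weights

def deep_gp_forward(x, widths, rng, n_features=100):
    """Stack layers: structurally a deep network with cos/sin activations whose
    weights are random variables rather than point estimates."""
    h = x
    for width in widths:
        h = rff_gp_layer(h, width, rng, n_features)
    return h

rng = np.random.default_rng(5)
x = np.linspace(-3.0, 3.0, 200)[:, None]
samples = [deep_gp_forward(x, widths=[10, 1], rng=rng) for _ in range(5)]  # five prior draws
```

The design choices mentioned in the conversation (number of layers, dimensionality of each layer) correspond to the length and entries of `widths`.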
[00:12:47] I think it's awesome to start from the linear representation, as you were saying, and go all the way to the very big deep GPs, which in a way are easier for me to picture, because it's like working in the infinite limit. I find that easier to work with than deep neural networks, for instance. But yes, can you give us a lay of the land of where the field is right now? Let's start with the practicality of it. What would you recommend to people? First, in which cases would deep GPs be useful? And second, why wouldn't people just use deep neural networks instead of deep GPs? Let's start with that. I have a lot of other questions, but I think that's the most general one.

[00:13:49] Yeah. I think it's a great question. It's the mother of all questions, really: what kind of model should you choose for your data? I think a lot of great work is going to happen soon that will maybe let us give more definite answers to this. We're starting to realize that the overparameterization we see in deep learning is not so bad after all.
[00:14:17] For someone working in business statistics, I think we have this image in mind where we should find the right complexity for the data we have: there's a sweet spot, a model that is parsimonious with respect to the data and not too heavily parameterized. But deep learning is now telling us a different story, which is actually not different from the story we know from nonparametric modeling, from Gaussian processes. In Gaussian processes, we push the number of parameters to infinity, and in deep learning we're now doing something similar, just in a mathematically different form.

[00:14:54] So we're getting to a point where this enormous complexity is, in a way, facilitating certain behaviors that let these models represent our data in a very simple way. The emergence of simplicity seems to be connected to this explosion in parameters. I think Andrew Wilson has done some amazing work on this; it was published recently, and I can link you to the paper, which says that deep learning is not so mysterious. It's something I was reading recently, and it's a beautiful read.

[00:15:30] To go back to your question: today, what should we do? Should we stick to a GP? Should we go for a deep neural network? I think for certain problems, we may have some understanding of the kind of functions we want.
[00:15:43] For those, if it's possible and easy to encode that understanding with GPs, I think it's definitely a good idea to go for it. But there might be other problems where we have no idea, or where there are too many complications in how we can think about the uncertainties and other things. Then, if we have a lot of data, maybe we can go for an approach that is data-hungry and leverage that, and deep learning seems like it may be the right choice there.

[00:16:17] But of course, there's also a lot happening in other spaces, for instance foundation models. There's this new breed of models that have been trained on a lot of data, and with some fine-tuning on your small data, you can actually adapt them. This transfer learning actually works. There's a paper, again by Andrew Wilson, on predicting time series with language models: you take ChatGPT, you discretize your time series, you tokenize it and give it to GPT, then you look at the predictions, invert the transformation, and get back scalar values. And this actually seems to work quite well. We've now tried multivariate, probabilistic versions of this, so we've done some work on that too.
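The discretize, tokenize, and invert loop described here can be sketched as a plain string round trip. This is a hypothetical serialization loosely in the spirit of that recipe: the scaling rule, the two-digit precision, and the space/comma separators are assumptions rather than the paper's exact settings, the function names are made up for this example, and the language-model call itself is omitted.

```python
def encode_series(values, scale, digits=2):
    """Serialize floats as text: divide by scale, keep `digits` decimals,
    drop the decimal point, space-separate digits, comma-separate time steps."""
    tokens = []
    for value in values:
        scaled = round(value / scale * 10**digits)
        sign = "-" if scaled < 0 else ""
        tokens.append(sign + " ".join(str(abs(scaled))))
    return " , ".join(tokens)

def decode_series(text, scale, digits=2):
    """Invert encode_series: parse each comma-separated step back to a float."""
    values = []
    for token in text.split(","):
        token = token.replace(" ", "")
        if token:
            values.append(int(token) / 10**digits * scale)
    return values

series = [0.52, 1.37, -0.8, 2.05]
scale = max(abs(v) for v in series)
prompt = encode_series(series, scale)
print(prompt)  # 2 5 , 6 7 , -3 9 , 1 0 0
recovered = decode_series(prompt, scale)  # equal to series up to the rounding step
```

In a forecasting setup, the model would continue `prompt` with further comma-separated steps, and `decode_series` would map its continuation back to numbers.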
[00:17:15] But I just want to say that this is also something kind of new that is happening, because before, it was maybe really hard to train these models at such a large scale. But now, if you train a model on the entire web, with all the language, and language is Markovian in a way, these Markovian structures are sort of learned by these models. And if you then feed these models something that is Markovian, they will try to make a prediction that is actually going to be reasonable. This is what we've seen in the literature. And all these things are, I think, going to change a lot of how we think about designing a model for the data we have, how we do inference, and all these things.

[00:18:00] So as of today, I think it's maybe still relevant to think: if I have a particular type of data, I know it makes sense to use a Gaussian process, because I want certain properties in the functions. The Matérn kernel, for example, gives us smoothness up to a certain degree, and it's easy to encode the length scales of these functions in the prior over functions. This is great, and for neural networks it's very hard to do. So we've done some work trying to map the two, right?
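The Matérn kernels mentioned here have simple closed forms for half-integer smoothness ν, and the lengthscale enters only through r / lengthscale, which is what makes length scales easy to encode in the prior. The sketch below uses the standard forms; the choice of ν values and the small-distance test are illustrative assumptions. One way to see "smoothness up to a certain degree" is how fast 1 − k(r) vanishes as r → 0: linearly for ν = 1/2 (rough, non-differentiable paths), quadratically for ν = 3/2 and ν = 5/2, whose paths are once and twice mean-square differentiable respectively.

```python
import numpy as np

def matern(r, nu, lengthscale=1.0, variance=1.0):
    """Matern kernels with half-integer nu, as functions of distance r.
    Sample paths are ceil(nu) - 1 times mean-square differentiable."""
    s = np.abs(r) / lengthscale
    if nu == 0.5:
        return variance * np.exp(-s)
    if nu == 1.5:
        return variance * (1.0 + np.sqrt(3.0) * s) * np.exp(-np.sqrt(3.0) * s)
    if nu == 2.5:
        return variance * (1.0 + np.sqrt(5.0) * s + 5.0 * s**2 / 3.0) * np.exp(-np.sqrt(5.0) * s)
    raise ValueError("this sketch only covers nu in {0.5, 1.5, 2.5}")

def small_distance_ratio(nu, h=1e-3):
    """(1 - k(2h)) / (1 - k(h)): about 2 if 1 - k grows linearly near 0, about 4 if quadratically."""
    return (1.0 - matern(2.0 * h, nu)) / (1.0 - matern(h, nu))

for nu in (0.5, 1.5, 2.5):
    print(nu, small_distance_ratio(nu))
```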
[00:18:29] So we asked: can we make a neural network imitate what Gaussian processes do, so that we gain the interpretability and the nice properties of a Gaussian process, but also inherit the flexibility and power of deep learning models, so that they really perform well and also give us sound uncertainty quantification?

[00:18:52] Yeah. Okay, yeah, yeah. [inaudible]
