ahead.Base

1from .Base import Base
2
3__all__ = ["Base"]
class Base:
 15class Base(object):
 16
 17    def __init__(self, h=5, level=95, date_formatting="ms", seed=123):
 18
 19        self.h = h
 20        self.level = level
 21        self.date_formatting = date_formatting
 22        self.seed = seed
 23        self.frequency = None
 24        self.series_names = None
 25        self.n_series = None
 26        self.type_input = "univariate"  # (or "multivariate")
 27        self.B = None
 28        self.input_df = None
 29        self.input_dates = None
 30        self.method = None
 31        self.weights = None
 32        self.type_pi = None 
 33        self.type_conformalize = None
 34        self.type_sim_conformalize = None
 35        self.type_aggregation = None
 36        self.type_clustering = None
 37        self.lags = None
 38        self.lags_ = None  # used for VAR
 39        self.seed = None 
 40
 41        self.input_ts_ = None  # input time series
 42        self.mean_ = None
 43        self.lower_ = None
 44        self.upper_ = None
 45        self.sims_ = None
 46        self.output_dates_ = None
 47        self.result_dfs_ = None
 48
 49        R_IS_INSTALLED = False
 50
 51        try:
 52            proc = Popen(["which", "R"], stdout=PIPE, stderr=PIPE)
 53            R_IS_INSTALLED = proc.wait() == 0
 54        except Exception as e:
 55            pass
 56
 57        if not R_IS_INSTALLED:
 58            raise ImportError("R is not installed! \n" + config.USAGE_MESSAGE)
 59
 60    def format_input(self):
 61        if self.input_df.shape[1] > 0:
 62            self.input_ts_ = compute_y_mts(self.input_df, self.frequency)
 63        else:
 64            self.input_ts_ = compute_y_ts(self.input_df, self.frequency)
 65
 66    def init_forecasting_params(self, df):
 67        self.input_df = df
 68        self.series_names = df.columns
 69        self.n_series = len(self.series_names)
 70        self.input_dates = compute_input_dates(df)
 71        self.type_input = "multivariate" if len(df.shape) > 0 else "univariate"
 72        self.output_dates_, self.frequency = compute_output_dates(df, self.h)
 73
 74    def getsims(self, input_tuple, ix):
 75        n_sims = len(input_tuple)
 76        res = [input_tuple[i].iloc[:, ix].values for i in range(n_sims)]
 77        return np.asarray(res).T
 78
 79    def get_forecast(self, method=None, xreg=None):
 80
 81        if method != None:
 82            self.method = method
 83
 84        if self.method == "armagarch":
 85            self.fcast_ = config.AHEAD_PACKAGE.armagarchf(
 86                y=self.input_ts_,
 87                h=self.h,
 88                level=self.level,
 89                B=self.B,
 90                cl=self.cl,
 91                dist=self.dist,
 92                seed=self.seed,
 93            )
 94
 95        if self.method in ("mean", "median", "rw"):
 96            self.fcast_ = config.AHEAD_PACKAGE.basicf(
 97                self.input_ts_,
 98                h=self.h,
 99                level=self.level,
100                method=self.method,
101                type_pi=self.type_pi,
102                block_length=self.block_length,
103                B=self.B,
104                seed=self.seed,
105            )
106
107        if self.method == "dynrm":
108            self.fcast_ = config.AHEAD_PACKAGE.dynrmf(
109                y=self.input_ts_,
110                h=self.h,
111                level=self.level,
112                type_pi=self.type_pi,
113            )
114
115        if self.method == "eat":
116            self.fcast_ = config.AHEAD_PACKAGE.eatf(
117                y=self.input_ts_,
118                h=self.h,
119                level=self.level,
120                type_pi=self.type_pi,
121                weights=config.FLOATVECTOR(self.weights),
122            )
123
124        if self.method == "ridge2":
125            if xreg is None:
126
127                self.fcast_ = config.AHEAD_PACKAGE.ridge2f(
128                    self.input_ts_,
129                    h=self.h,
130                    level=self.level,
131                    lags=self.lags,
132                    nb_hidden=self.nb_hidden,
133                    nodes_sim=self.nodes_sim,
134                    activ=self.activation,
135                    a=self.a,
136                    lambda_1=self.lambda_1,
137                    lambda_2=self.lambda_2,
138                    dropout=self.dropout,
139                    type_pi=self.type_pi,
140                    margins=self.margins,
141                    # can be NULL, but in R (use 0 in R instead of NULL for v0.7.0)
142                    block_length=self.block_length,
143                    B=self.B,
144                    type_aggregation=self.type_aggregation,
145                    # can be NULL, but in R (use 0 in R instead of NULL for v0.7.0)
146                    centers=self.centers,
147                    type_clustering=self.type_clustering,
148                    cl=self.cl,
149                    seed=self.seed,
150                )
151
152            else:  # xreg is not None:
153
154                try:
155                    self.xreg_ = xreg.values
156                except:
157                    self.xreg_ = config.DEEP_COPY(xreg)
158
159                is_matrix_xreg = len(self.xreg_.shape) > 1
160
161                numpy2ri.activate()
162
163                xreg_ = (
164                    r.matrix(
165                        FloatVector(self.xreg_.flatten()),
166                        byrow=True,
167                        nrow=self.xreg_.shape[0],
168                        ncol=self.xreg_.shape[1],
169                    )
170                    if is_matrix_xreg
171                    else r.matrix(
172                        FloatVector(self.xreg_.flatten()),
173                        byrow=True,
174                        nrow=self.xreg_.shape[0],
175                        ncol=1,
176                    )
177                )
178
179                self.fcast_ = config.AHEAD_PACKAGE.ridge2f(
180                    self.input_ts_,
181                    xreg=xreg_,
182                    h=self.h,
183                    level=self.level,
184                    lags=self.lags,
185                    nb_hidden=self.nb_hidden,
186                    nodes_sim=self.nodes_sim,
187                    activ=self.activation,
188                    a=self.a,
189                    lambda_1=self.lambda_1,
190                    lambda_2=self.lambda_2,
191                    dropout=self.dropout,
192                    type_pi=self.type_pi,
193                    margins=self.margins,
194                    # can be NULL, but in R (use 0 in R instead of NULL for v0.7.0)
195                    block_length=self.block_length,
196                    B=self.B,
197                    type_aggregation=self.type_aggregation,
198                    # can be NULL, but in R (use 0 in R instead of NULL for v0.7.0)
199                    centers=self.centers,
200                    type_clustering=self.type_clustering,
201                    cl=self.cl,
202                    seed=self.seed,
203                )
204
205        if self.method == "var":
206            self.fcast_ = config.AHEAD_PACKAGE.varf(
207                self.input_ts_,
208                h=self.h,
209                level=self.level,
210                lags=self.lags,
211                type_VAR=self.type_VAR,
212            )
213        
214        if self.method.lower() == "mlarch":
215            valid_type_pi = ("surrogate", "bootstrap", "kde")
216            type_pi = self.type_pi if self.type_pi in valid_type_pi else "surrogate"
217            valid_type_sim = ("surrogate", "block-bootstrap", "bootstrap", "kde", "fitdistr")
218            type_sim_conformalize = (
219                self.type_sim_conformalize if self.type_sim_conformalize in valid_type_sim else "surrogate"
220            )
221
222            mlarch_args = dict(
223                y=self.input_ts_,
224                h=self.h,
225                mean_model=getattr(self, "mean_model", None),
226                model_residuals=getattr(self, "model_residuals", None),
227                fit_func=getattr(self, "fit_func", None),
228                predict_func=getattr(self, "predict_func", None),
229                type_pi=type_pi,
230                type_sim_conformalize=type_sim_conformalize,
231                ml_method=getattr(self, "ml_method", None),
232                level=self.level,
233                B=self.B,
234                ml=True,
235                stat_model=getattr(self, "stat_model", None),
236                seed=self.seed,
237            )
238            # Remove keys with value None
239            mlarch_args = {k: v for k, v in mlarch_args.items() if v is not None}
240
241            self.fcast_ = config.AHEAD_PACKAGE.mlarchf(**mlarch_args)
242
243
244    def plot(self, series, type_axis="dates", type_plot="pi"):
245        """Plot time series forecast
246
247        Parameters:
248
249        series: {integer} or {string}
250            series index or name
251        """
252        assert all(
253            [
254                self.mean_ is not None,
255                self.lower_ is not None,
256                self.upper_ is not None,
257                self.output_dates_ is not None,
258            ]
259        ), "model forecasting must be obtained first (with `forecast` method)"
260
261        if isinstance(series, str):
262            assert (
263                series in self.series_names
264            ), f"series {series} doesn't exist in the input dataset"
265            series_idx = self.input_df.columns.get_loc(series)
266        else:
267            assert isinstance(series, int) and (
268                0 <= series < self.n_series
269            ), f"check series index (< {self.n_series})"
270            series_idx = series
271
272        y_all = list(self.input_df.iloc[:, series_idx]) + list(
273            self.result_dfs_[series_idx]["mean"].values
274        )
275
276        y_test = list(self.result_dfs_[series_idx]["mean"].values)
277        n_points_all = len(y_all)
278        n_points_train = self.input_df.shape[0]
279
280        if type_axis == "numeric":
281            x_all = [i for i in range(n_points_all)]
282            x_test = [i for i in range(n_points_train, n_points_all)]
283
284        if type_axis == "dates":  # use dates
285            x_train = [date.strftime("%Y-%m-%d") for date in self.input_dates]
286            x_test = [date.strftime("%Y-%m-%d") for date in self.output_dates_]
287            x_all = np.concatenate((x_train, x_test), axis=None)
288
289        if type_plot == "pi":
290            fig, ax = plt.subplots()
291            ax.plot(x_all, y_all, "-")
292            ax.plot(x_test, y_test, "-", color="orange")
293            ax.fill_between(
294                x_test,
295                self.result_dfs_[series_idx]["lower"].values,
296                self.result_dfs_[series_idx]["upper"].values,
297                alpha=0.2,
298                color="orange",
299            )
300            plt.title(
301                f"prediction intervals for {series}",
302                loc="left",
303                fontsize=12,
304                fontweight=0,
305                color="black",
306            )
307            plt.show()
308
309        if type_plot == "spaghetti":
310            palette = plt.get_cmap("Set1")
311            sims_ix = self.getsims(self.sims_, series_idx)
312            plt.plot(x_all, y_all, "-")
313            for col_ix in range(
314                sims_ix.shape[1]
315            ):  # avoid this when there are thousands of simulations
316                plt.plot(
317                    x_test,
318                    sims_ix[:, col_ix],
319                    "-",
320                    color=palette(col_ix),
321                    linewidth=1,
322                    alpha=0.9,
323                )
324            plt.plot(x_all, y_all, "-", color="black")
325            plt.plot(x_test, y_test, "-", color="blue")
326            # Add titles
327            plt.title(
328                f"{self.B} simulations of {series}",
329                loc="left",
330                fontsize=12,
331                fontweight=0,
332                color="black",
333            )
334            plt.xlabel("Time")
335            plt.ylabel("Values")
336            # Show the graph
337            plt.show()