rkaf provides Kolmogorov-Arnold Fourier Networks for R
users through the torch backend.
The package supports regression, binary classification, and multiclass
classification, through both a matrix interface (kaf_fit()) and a formula
interface (kaf_fit_formula()), with optional validation splits,
mini-batching, and early stopping.
This vignette gives a quick tour of the main workflow.
We first fit a KAF model to a synthetic one-dimensional function with both low-frequency and high-frequency structure.
x <- as.matrix(seq(-1, 1, length.out = 128))
y <- sin(8 * pi * x) +
0.35 * cos(3 * pi * x) +
0.15 * x^2

fit <- kaf_fit(
x = x,
y = y,
hidden = c(256, 256),
num_grids = 32,
use_layernorm = FALSE,
epochs = 1000,
lr = 1e-3,
standardize_x = FALSE,
standardize_y = TRUE,
fourier_init_scale = 5e-2,
restore_best = TRUE,
verbose = FALSE,
seed = 123
)
fit
#> <kaf_fit>
#> Task: regression
#> Architecture: 1 -> 256 -> 256 -> 1
#> Fourier grids: 32
#> Epochs: 1000
#> Batch size: 128
#> Validation: no
#> Standardize x: no
#> Standardize y: yes
#> Final train loss: 0.0227977
#> Best train loss: 0.0198582 at epoch 945

pred <- predict(fit, x)
head(data.frame(
observed = round(as.numeric(y), 3),
predicted = round(pred, 3)
))
#> observed predicted
#> 1 -0.200 0.195
#> 2 0.185 0.356
#> 3 0.517 0.474
#> 4 0.748 0.540
#> 5 0.842 0.545
#> 6 0.787 0.486

plot(
x,
y,
type = "l",
lwd = 2,
xlab = "x",
ylab = "f(x)",
main = "KAF regression fit"
)
lines(x, pred, lwd = 2, lty = 2)
legend(
"topright",
legend = c("Observed", "Predicted"),
lty = c(1, 2),
lwd = 2,
bty = "n"
)

For tabular data, rkaf also supports a formula interface.
fit_mtcars <- kaf_fit_formula(
mpg ~ wt + hp + cyl,
data = mtcars,
hidden = c(32, 32),
num_grids = 16,
epochs = 200,
verbose = FALSE,
seed = 123
)
fit_mtcars
#> <kaf_fit>
#> Task: regression
#> Formula: mpg ~ wt + hp + cyl
#> Architecture: 3 -> 32 -> 32 -> 1
#> Fourier grids: 16
#> Epochs: 200
#> Batch size: 32
#> Validation: no
#> Standardize x: yes
#> Standardize y: yes
#> Final train loss: 0.0542788
#> Best train loss: 0.0542788 at epoch 200

If the response is a factor with two classes, rkaf automatically
treats the problem as binary classification.
df <- mtcars
df$high_mpg <- factor(
ifelse(df$mpg > median(df$mpg), "yes", "no"),
levels = c("no", "yes")
)

fit_binary <- kaf_fit_formula(
high_mpg ~ wt + hp + cyl,
data = df,
hidden = c(32, 32),
num_grids = 16,
epochs = 200,
verbose = FALSE,
seed = 123
)
fit_binary
#> <kaf_fit>
#> Task: binary
#> Classes: no, yes
#> Formula: high_mpg ~ wt + hp + cyl
#> Architecture: 3 -> 32 -> 32 -> 1
#> Fourier grids: 16
#> Epochs: 200
#> Batch size: 32
#> Validation: no
#> Standardize x: yes
#> Standardize y: no
#> Final train loss: 0.00163561
#> Best train loss: 0.00163561 at epoch 200

Predicted probabilities and classes:
prob_binary <- predict(fit_binary, df, type = "prob")
class_binary <- predict(fit_binary, df, type = "class")
head(data.frame(
observed = df$high_mpg,
prob_yes = round(prob_binary, 3),
predicted = class_binary
))
#> observed prob_yes predicted
#> 1 yes 0.999 yes
#> 2 yes 0.999 yes
#> 3 yes 1.000 yes
#> 4 yes 0.983 yes
#> 5 no 0.000 no
#> 6 no 0.008 no

Confusion matrix:
table(
observed = df$high_mpg,
predicted = class_binary
)
#> predicted
#> observed no yes
#> no 17 0
#> yes 0 15

Raw logits:
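The logit-extraction code is not shown above. As a self-contained base-R sketch (hypothetical probability values, independent of the fitted model), binary probabilities and raw logits are related through the sigmoid and its inverse:

```r
# Probabilities like those reported by predict(..., type = "prob") above
p <- c(0.999, 0.983, 0.008)

# qlogis() is the inverse sigmoid: logit = log(p / (1 - p))
logits <- qlogis(p)

# plogis() (the sigmoid) maps logits back to probabilities
round(plogis(logits), 3)
#> [1] 0.999 0.983 0.008
```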
If the response is a factor with more than two classes,
rkaf fits a multiclass classifier.
fit_iris <- kaf_fit_formula(
Species ~ .,
data = iris,
hidden = c(32, 32),
num_grids = 16,
epochs = 300,
verbose = FALSE,
seed = 123
)
fit_iris
#> <kaf_fit>
#> Task: multiclass
#> Classes: setosa, versicolor, virginica
#> Formula: Species ~ .
#> Architecture: 4 -> 32 -> 32 -> 3
#> Fourier grids: 16
#> Epochs: 300
#> Batch size: 150
#> Validation: no
#> Standardize x: yes
#> Standardize y: no
#> Final train loss: 0.00819381
#> Best train loss: 0.00819381 at epoch 300

Confusion matrix:
class_iris <- predict(fit_iris, iris, type = "class")
table(
observed = iris$Species,
predicted = class_iris
)
#> predicted
#> observed setosa versicolor virginica
#> setosa 50 0 0
#> versicolor 0 50 0
#> virginica 0 0 50

Class probabilities:
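The probability call is not shown above. Assuming predict() accepts type = "prob" for multiclass fits as it does for binary ones, it would return one row of class probabilities per observation. The sketch below (base R only, hypothetical logits) shows how a softmax turns per-class scores into such rows:

```r
# Hypothetical per-class logits for two observations, three classes
logits <- rbind(c(4.1, -0.3, -2.0),
                c(-1.5, 2.2, 1.9))

# Softmax: exponentiate and normalise so each row sums to one
softmax <- function(z) exp(z) / sum(exp(z))
probs <- t(apply(logits, 1, softmax))

round(probs, 3)
rowSums(probs)
#> [1] 1 1
```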
kaf_fit() supports validation splits, mini-batches, and
early stopping.
fit_val <- kaf_fit(
x = x,
y = y,
hidden = c(64, 64),
num_grids = 16,
use_layernorm = FALSE,
epochs = 300,
lr = 5e-4,
batch_size = 64,
validation_split = 0.2,
patience = 100,
restore_best = TRUE,
verbose = FALSE,
seed = 123
)
plot(fit_val)

The fitted object stores both train_loss_history and
validation_loss_history, so users can inspect training and
validation behavior directly.
The KAF architecture contains a base/GELU branch and a Fourier branch. The package exposes helper functions to inspect the learned branch scales and Fourier parameters.
scales <- extract_kaf_scales(fit)
head(scales)
#> layer feature base_scale fourier_scale fourier_to_base_ratio
#> 1 1 1 1.0369691 0.11137812 0.10740737
#> 2 2 1 0.9652719 0.15750510 0.16317174
#> 3 2 2 0.9738784 0.39944243 0.41015638
#> 4 2 3 0.9945174 0.12360517 0.12428658
#> 5 2 4 0.9936960 0.08714252 0.08769535
#> 6 2 5 1.0083405 0.08873361 0.08799965

fourier_params <- extract_fourier_params(fit, layer = 1)
head(fourier_params)
#> layer input_feature grid weight bias
#> 1 1 1 1 0.1379520 4.5462580
#> 2 1 1 2 -0.1811931 4.3979001
#> 3 1 1 3 -0.1297648 5.7381864
#> 4 1 1 4 -0.4164377 2.7132401
#> 5 1 1 5 0.3679259 0.4899397
#> 6 1 1 6 0.3606938 2.3420620

Advanced users can use the low-level torch modules directly.
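To make the two-branch picture concrete, here is a plausible base-R sketch of a single KAF-style feature map: a scaled GELU base branch plus a scaled sum of sinusoids built from per-grid weights and phases. The exact functional form inside rkaf may differ; this only illustrates the roles of base_scale, fourier_scale, and the per-grid weight/bias values extracted above.

```r
# Illustrative one-feature KAF-style activation (not rkaf's exact form):
#   base_scale * gelu(x) + fourier_scale * sum_k w[k] * sin(k * x + b[k])
gelu <- function(x) x * pnorm(x)  # exact, Phi-based GELU

kaf_activation <- function(x, base_scale, fourier_scale, w, b) {
  k <- seq_along(w)  # one frequency per Fourier grid
  fourier <- vapply(x, function(xi) sum(w * sin(k * xi + b)), numeric(1))
  base_scale * gelu(x) + fourier_scale * fourier
}

x <- seq(-1, 1, length.out = 5)
set.seed(123)
round(kaf_activation(x, base_scale = 1.04, fourier_scale = 0.11,
                     w = rnorm(4, sd = 0.3), b = runif(4, 0, 2 * pi)), 3)
```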