diff --git a/Project.toml b/Project.toml index 494a4b0..90666cf 100644 --- a/Project.toml +++ b/Project.toml @@ -21,4 +21,4 @@ LinearAlgebra = "1.9" NNlib = "0.9" Statistics = "1.9" Zygote = "0.6" -julia = "1.9" +julia = "1.10" diff --git a/src/losses.jl b/src/losses.jl index 54f63a4..6bcad30 100644 --- a/src/losses.jl +++ b/src/losses.jl @@ -9,7 +9,7 @@ function cross_entropy(probs, label; dims = 1, agg = mean) return agg(.-sum(ce_summands; dims = dims)) end -function class_error(probs, label; dims = 1, agg = mean) +function class_error(probs::Array, label::Array; dims = 1, agg = mean) class_predicted = argmax(probs; dims = dims) class_actual = argmax(label; dims = dims) return agg(1 .- (class_predicted .== class_actual))