Skip to content

Commit 0255443

Browse files
authored
Modularize errors (#4)
1 parent 6c5cd9b commit 0255443

File tree

15 files changed

+285
-176
lines changed

15 files changed

+285
-176
lines changed

Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@ members = [
55
]
66

77
[workspace.dependencies]
8-
anyhow = "1.0.75"
9-
thiserror = "1"

crates/phantom-core/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,3 @@ license = "MIT"
88
readme = "../../README.md"
99

1010
[dependencies]
11-
anyhow = { workspace = true }
12-
thiserror = { workspace = true }

crates/phantom-core/src/backend/cpu_backend.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::storage::{BinaryOperation, UnaryOperation};
2-
use crate::{index::StridedIndex, DType, Error, Result, Shape};
2+
use crate::{index::StridedIndex, DType, Shape};
3+
use crate::storage::StorageError;
34

45
#[derive(Debug, Clone)]
56
pub enum CPUStorage {
@@ -21,7 +22,7 @@ impl CPUStorage {
2122
stride: &[usize],
2223
mul: f64,
2324
add: f64,
24-
) -> Result<Self> {
25+
) -> std::result::Result<Self, StorageError> {
2526
match self {
2627
Self::F32(storage) => {
2728
let index = StridedIndex::new(shape.dims(), stride);
@@ -42,7 +43,7 @@ impl CPUStorage {
4243
&self,
4344
shape: &Shape,
4445
stride: &[usize],
45-
) -> Result<Self> {
46+
) -> std::result::Result<Self, StorageError> {
4647
match self {
4748
Self::F32(storage) => {
4849
let index = StridedIndex::new(shape.dims(), stride);
@@ -63,7 +64,7 @@ impl CPUStorage {
6364
shape: &Shape,
6465
lhs_stride: &[usize],
6566
rhs_stride: &[usize],
66-
) -> Result<Self> {
67+
) -> std::result::Result<Self, StorageError> {
6768
match (self, rhs) {
6869
(CPUStorage::F32(lhs), CPUStorage::F32(rhs)) => {
6970
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
@@ -85,16 +86,25 @@ impl CPUStorage {
8586

8687
Ok(Self::F64(data))
8788
}
88-
_ => Err(Error::BinaryOperationDTypeMismatch {
89+
_ => Err(StorageError::BinaryOperationDTypeMismatch {
8990
lhs: self.dtype(),
9091
rhs: rhs.dtype(),
9192
op: T::NAME,
9293
}),
9394
}
9495
}
9596

96-
pub(crate) fn transpose(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
97-
let (rows, cols) = shape.rank_two()?;
97+
pub(crate) fn transpose(&self, shape: &Shape, stride: &[usize]) -> std::result::Result<Self, StorageError> {
98+
let (rows, cols) = match shape.rank_two() {
99+
Ok(rc) => rc,
100+
Err(_) => {
101+
return Err(StorageError::BinaryOperationShapeMismatch {
102+
lhs: shape.clone(),
103+
rhs: shape.clone(),
104+
op: "transpose",
105+
});
106+
}
107+
};
98108
match self {
99109
CPUStorage::F32(storage) => {
100110
let mut out = vec![0f32; rows * cols];
@@ -126,11 +136,11 @@ impl CPUStorage {
126136
rhs: &Self,
127137
rhs_shape: (usize, usize),
128138
rhs_stride: &[usize],
129-
) -> Result<Self> {
139+
) -> std::result::Result<Self, StorageError> {
130140
let (m, k) = lhs_shape;
131141
let (k_rhs, n) = rhs_shape;
132142
if k != k_rhs {
133-
return Err(Error::BinaryOperationShapeMismatch {
143+
return Err(StorageError::BinaryOperationShapeMismatch {
134144
lhs: Shape::from((m, k)),
135145
rhs: Shape::from((k_rhs, n)),
136146
op: "matmul",
@@ -167,7 +177,7 @@ impl CPUStorage {
167177
}
168178
Ok(Self::F64(out))
169179
}
170-
_ => Err(Error::BinaryOperationDTypeMismatch {
180+
_ => Err(StorageError::BinaryOperationDTypeMismatch {
171181
lhs: self.dtype(),
172182
rhs: rhs.dtype(),
173183
op: "matmul",

crates/phantom-core/src/backprop.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
use crate::tensor::{Tensor, TensorID};
2-
use crate::Operation;
3-
use crate::Result;
2+
use crate::{Operation};
3+
use crate::tensor::TensorError;
4+
5+
#[derive(Debug)]
6+
pub enum BackpropError {
7+
MissingGradient { tensor: TensorID },
8+
Tensor(TensorError),
9+
}
10+
11+
impl From<TensorError> for BackpropError {
12+
fn from(e: TensorError) -> Self {
13+
BackpropError::Tensor(e)
14+
}
15+
}
16+
17+
impl std::fmt::Display for BackpropError {
18+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19+
match self {
20+
BackpropError::MissingGradient { tensor } => {
21+
write!(f, "missing gradient for tensor {:?}", tensor)
22+
}
23+
BackpropError::Tensor(e) => write!(f, "{e}"),
24+
}
25+
}
26+
}
27+
28+
impl std::error::Error for BackpropError {}
429
use std::collections::HashMap;
530

631
impl Tensor {
@@ -84,9 +109,9 @@ impl Tensor {
84109
/// let gradients = z.backward()?;
85110
/// assert_eq!(gradients.len(), 1);
86111
///
87-
/// # Ok::<(), phantom_core::Error>(())
112+
/// # Ok::<(), Box<dyn std::error::Error>>(())
88113
/// ```
89-
pub fn backward(&self) -> Result<HashMap<TensorID, Tensor>> {
114+
pub fn backward(&self) -> std::result::Result<HashMap<TensorID, Tensor>, BackpropError> {
90115
let sorted_nodes = self.sorted_nodes();
91116
let mut gradients = HashMap::new();
92117

@@ -97,7 +122,9 @@ impl Tensor {
97122
continue;
98123
}
99124

100-
let gradient = gradients.remove(&node.id()).unwrap();
125+
let gradient = gradients
126+
.remove(&node.id())
127+
.ok_or(BackpropError::MissingGradient { tensor: node.id() })?;
101128

102129
if let Some(op) = node.op() {
103130
match op {

crates/phantom-core/src/device.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::backend::cpu_backend::CPUStorage;
2-
use crate::{storage::Storage, DType, Result, Shape};
2+
use crate::{storage::Storage, DType, Shape};
3+
use crate::shape::ShapeError;
34

45
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
56
pub enum Device {
@@ -41,12 +42,12 @@ impl Device {
4142
}
4243

4344
pub trait NDArray {
44-
fn shape(&self) -> Result<Shape>;
45+
fn shape(&self) -> std::result::Result<Shape, ShapeError>;
4546
fn to_cpu(&self) -> CPUStorage;
4647
}
4748

4849
impl<S: crate::WithDType> NDArray for S {
49-
fn shape(&self) -> Result<Shape> {
50+
fn shape(&self) -> std::result::Result<Shape, ShapeError> {
5051
Ok(Shape::from(()))
5152
}
5253

@@ -56,7 +57,7 @@ impl<S: crate::WithDType> NDArray for S {
5657
}
5758

5859
impl<S: crate::WithDType> NDArray for &[S] {
59-
fn shape(&self) -> Result<Shape> {
60+
fn shape(&self) -> std::result::Result<Shape, ShapeError> {
6061
Ok(Shape::from(self.len()))
6162
}
6263

@@ -66,7 +67,7 @@ impl<S: crate::WithDType> NDArray for &[S] {
6667
}
6768

6869
impl<S: crate::WithDType, const N: usize> NDArray for &[S; N] {
69-
fn shape(&self) -> Result<Shape> {
70+
fn shape(&self) -> std::result::Result<Shape, ShapeError> {
7071
Ok(Shape::from(self.len()))
7172
}
7273

@@ -76,7 +77,7 @@ impl<S: crate::WithDType, const N: usize> NDArray for &[S; N] {
7677
}
7778

7879
impl<S: crate::WithDType, const N: usize, const M: usize> NDArray for &[[S; N]; M] {
79-
fn shape(&self) -> Result<Shape> {
80+
fn shape(&self) -> std::result::Result<Shape, ShapeError> {
8081
Ok(Shape::from((M, N)))
8182
}
8283

crates/phantom-core/src/dtype.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,25 @@
1-
use crate::{CPUStorage, Error, Result};
1+
use crate::CPUStorage;
2+
3+
#[derive(Debug)]
4+
pub enum DTypeError {
5+
UnexpectedDType { expected: DType, actual: DType },
6+
}
7+
8+
impl std::fmt::Display for DTypeError {
9+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
10+
match self {
11+
DTypeError::UnexpectedDType { expected, actual } => {
12+
write!(
13+
f,
14+
"unexpected dtype, expected: {:?}, actual: {:?}",
15+
expected, actual
16+
)
17+
}
18+
}
19+
}
20+
}
21+
22+
impl std::error::Error for DTypeError {}
223

324
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
425
pub enum DType {
@@ -22,7 +43,7 @@ pub trait WithDType: Sized + Copy {
2243
fn to_cpu(data: &[Self]) -> CPUStorage {
2344
Self::to_cpu_owned(data.to_vec())
2445
}
25-
fn storage_slice(storage: &CPUStorage) -> Result<&[Self]>;
46+
fn storage_slice(storage: &CPUStorage) -> std::result::Result<&[Self], DTypeError>;
2647
}
2748

2849
macro_rules! with_dtype {
@@ -34,10 +55,10 @@ macro_rules! with_dtype {
3455
CPUStorage::$dtype(data)
3556
}
3657

37-
fn storage_slice(storage: &CPUStorage) -> Result<&[Self]> {
58+
fn storage_slice(storage: &CPUStorage) -> std::result::Result<&[Self], DTypeError> {
3859
match storage {
3960
CPUStorage::$dtype(data) => Ok(data),
40-
_ => Err(Error::UnexpectedDType {
61+
_ => Err(DTypeError::UnexpectedDType {
4162
expected: DType::$dtype,
4263
actual: storage.dtype(),
4364
}),

crates/phantom-core/src/error.rs

Lines changed: 0 additions & 37 deletions
This file was deleted.

crates/phantom-core/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ mod backend;
22
mod backprop;
33
mod device;
44
mod dtype;
5-
mod error;
65
mod index;
76
mod operation;
87
mod shape;
@@ -12,9 +11,12 @@ mod tensor;
1211
pub use backend::cpu_backend::CPUStorage;
1312
pub use device::Device;
1413
pub use dtype::{DType, WithDType};
15-
pub use error::{Error, Result};
14+
pub use dtype::DTypeError;
1615
pub use index::StridedIndex;
1716
pub use operation::Operation;
1817
pub use shape::Shape;
18+
pub use shape::ShapeError;
1919
pub use storage::Storage;
20-
pub use tensor::{Tensor, TensorID};
20+
pub use storage::StorageError;
21+
pub use tensor::{Tensor, TensorError, TensorID};
22+
pub use backprop::BackpropError;

crates/phantom-core/src/shape.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
use crate::{Error, Result};
1+
2+
#[derive(Debug)]
3+
pub enum ShapeError {
4+
UnexpectedRank { expected: usize, actual: usize, shape: Shape },
5+
}
6+
7+
impl std::fmt::Display for ShapeError {
8+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9+
match self {
10+
ShapeError::UnexpectedRank { expected, actual, .. } => {
11+
write!(f, "unexpected rank, expected: {}, actual: {}", expected, actual)
12+
}
13+
}
14+
}
15+
}
16+
17+
impl std::error::Error for ShapeError {}
218

319
#[derive(Clone, PartialEq, Eq)]
420
pub struct Shape(pub(crate) Vec<usize>);
@@ -47,9 +63,9 @@ impl From<&Shape> for Shape {
4763

4864
macro_rules! get_rank {
4965
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
50-
pub fn $fn_name(&self) -> Result<$out_type> {
66+
pub fn $fn_name(&self) -> std::result::Result<$out_type, ShapeError> {
5167
if self.0.len() != $cnt {
52-
Err(Error::UnexpectedRank {
68+
Err(ShapeError::UnexpectedRank {
5369
expected: $cnt,
5470
actual: self.0.len(),
5571
shape: self.clone(),

0 commit comments

Comments
 (0)