11use crate :: storage:: { BinaryOperation , UnaryOperation } ;
2- use crate :: { index:: StridedIndex , DType , Error , Result , Shape } ;
2+ use crate :: { index:: StridedIndex , DType , Shape } ;
3+ use crate :: storage:: StorageError ;
34
45#[ derive( Debug , Clone ) ]
56pub enum CPUStorage {
@@ -21,7 +22,7 @@ impl CPUStorage {
2122 stride : & [ usize ] ,
2223 mul : f64 ,
2324 add : f64 ,
24- ) -> Result < Self > {
25+ ) -> std :: result :: Result < Self , StorageError > {
2526 match self {
2627 Self :: F32 ( storage) => {
2728 let index = StridedIndex :: new ( shape. dims ( ) , stride) ;
@@ -42,7 +43,7 @@ impl CPUStorage {
4243 & self ,
4344 shape : & Shape ,
4445 stride : & [ usize ] ,
45- ) -> Result < Self > {
46+ ) -> std :: result :: Result < Self , StorageError > {
4647 match self {
4748 Self :: F32 ( storage) => {
4849 let index = StridedIndex :: new ( shape. dims ( ) , stride) ;
@@ -63,7 +64,7 @@ impl CPUStorage {
6364 shape : & Shape ,
6465 lhs_stride : & [ usize ] ,
6566 rhs_stride : & [ usize ] ,
66- ) -> Result < Self > {
67+ ) -> std :: result :: Result < Self , StorageError > {
6768 match ( self , rhs) {
6869 ( CPUStorage :: F32 ( lhs) , CPUStorage :: F32 ( rhs) ) => {
6970 let lhs_index = StridedIndex :: new ( shape. dims ( ) , lhs_stride) ;
@@ -85,16 +86,25 @@ impl CPUStorage {
8586
8687 Ok ( Self :: F64 ( data) )
8788 }
88- _ => Err ( Error :: BinaryOperationDTypeMismatch {
89+ _ => Err ( StorageError :: BinaryOperationDTypeMismatch {
8990 lhs : self . dtype ( ) ,
9091 rhs : rhs. dtype ( ) ,
9192 op : T :: NAME ,
9293 } ) ,
9394 }
9495 }
9596
96- pub ( crate ) fn transpose ( & self , shape : & Shape , stride : & [ usize ] ) -> Result < Self > {
97- let ( rows, cols) = shape. rank_two ( ) ?;
97+ pub ( crate ) fn transpose ( & self , shape : & Shape , stride : & [ usize ] ) -> std:: result:: Result < Self , StorageError > {
98+ let ( rows, cols) = match shape. rank_two ( ) {
99+ Ok ( rc) => rc,
100+ Err ( _) => {
101+ return Err ( StorageError :: BinaryOperationShapeMismatch {
102+ lhs : shape. clone ( ) ,
103+ rhs : shape. clone ( ) ,
104+ op : "transpose" ,
105+ } ) ;
106+ }
107+ } ;
98108 match self {
99109 CPUStorage :: F32 ( storage) => {
100110 let mut out = vec ! [ 0f32 ; rows * cols] ;
@@ -126,11 +136,11 @@ impl CPUStorage {
126136 rhs : & Self ,
127137 rhs_shape : ( usize , usize ) ,
128138 rhs_stride : & [ usize ] ,
129- ) -> Result < Self > {
139+ ) -> std :: result :: Result < Self , StorageError > {
130140 let ( m, k) = lhs_shape;
131141 let ( k_rhs, n) = rhs_shape;
132142 if k != k_rhs {
133- return Err ( Error :: BinaryOperationShapeMismatch {
143+ return Err ( StorageError :: BinaryOperationShapeMismatch {
134144 lhs : Shape :: from ( ( m, k) ) ,
135145 rhs : Shape :: from ( ( k_rhs, n) ) ,
136146 op : "matmul" ,
@@ -167,7 +177,7 @@ impl CPUStorage {
167177 }
168178 Ok ( Self :: F64 ( out) )
169179 }
170- _ => Err ( Error :: BinaryOperationDTypeMismatch {
180+ _ => Err ( StorageError :: BinaryOperationDTypeMismatch {
171181 lhs : self . dtype ( ) ,
172182 rhs : rhs. dtype ( ) ,
173183 op : "matmul" ,
0 commit comments