@@ -6,36 +6,40 @@ struct Matrix(Copyable, Movable, ImplicitlyCopyable, Sized):
66 var height : Int
77 var width : Int
88 var size : Int
9- var data : UnsafePointer[Float32]
9+ var data : UnsafePointer[Float32, MutAnyOrigin ]
1010 var order : String
1111
1212 # initialize from UnsafePointer
1313 @always_inline
14- fn __init__ (out self , data : UnsafePointer[Float32 ], height : Int, width : Int, order : String = ' c' ):
14+ fn __init__ [ src : DType = DType.float32] (out self , data : UnsafePointer[Scalar[src], MutAnyOrigin ], height : Int, width : Int, order : String = ' c' ):
1515 self .height = height
1616 self .width = width
1717 self .size = height * width
18- self .data = data
18+ if src == DType.float32:
19+ self .data = data.bitcast[Float32]()
20+ else :
21+ self .data = cast[src=src, des=DType.float32, width=self .simd_width](data, self .size)
22+ data.free()
1923 self .order = order.lower()
2024
2125 # initialize by copying from UnsafePointer
2226 @always_inline
23- fn __init__ (out self , height : Int, width : Int, data : UnsafePointer[Float32] = UnsafePointer[Float32](), order : String = ' c' ):
27+ fn __init__ (out self , height : Int, width : Int, data : UnsafePointer[Float32, MutAnyOrigin ] = UnsafePointer[Float32, MutAnyOrigin ](), order : String = ' c' ):
2428 self .height = height
2529 self .width = width
2630 self .size = height * width
27- self .data = UnsafePointer [Float32].alloc (self .size)
31+ self .data = alloc [Float32](self .size)
2832 self .order = order.lower()
2933 if data:
30- memcpy(self .data, data, self .size)
34+ memcpy(dest = self .data, src = data, count = self .size)
3135
3236 fn __copyinit__ (out self , other : Self):
3337 self .height = other.height
3438 self .width = other.width
3539 self .size = other.size
36- self .data = UnsafePointer [Float32].alloc (self .size)
40+ self .data = alloc [Float32](self .size)
3741 self .order = other.order
38- memcpy(self .data, other.data, self .size)
42+ memcpy(dest = self .data, src = other.data, count = self .size)
3943
4044 fn __moveinit__ (out self , deinit existing : Self):
4145 self .height = existing.height
@@ -45,7 +49,7 @@ struct Matrix(Copyable, Movable, ImplicitlyCopyable, Sized):
4549 self .order = existing.order
4650 # existing.height = existing.width = existing.size = 0
4751 # existing.order = ''
48- # existing.data = UnsafePointer[Float32]()
52+ # existing.data = UnsafePointer[Float32, MutAnyOrigin ]()
4953
5054 # access an element
5155 @always_inline
@@ -56,7 +60,7 @@ struct Matrix(Copyable, Movable, ImplicitlyCopyable, Sized):
5660 else :
5761 loc = (column * self .height) + row
5862 if loc > self .size - 1 or loc < 0 :
59- raise Error(" Error: Location is out of range!" )
63+ raise Error(" Location is out of range!" )
6064 return self .data[loc]
6165
6266 @always_inline
@@ -72,6 +76,24 @@ struct Matrix(Copyable, Movable, ImplicitlyCopyable, Sized):
7276 fn __mul__ (self , rhs : Self) raises -> Self:
7377 if self .width != rhs.height:
7478 raise Error(' Error: Cannot multiply matrices with shapes (' + String(self .height) + ' , ' + String(self .width) + ' ) and (' + String(rhs.height) + ' , ' + String(rhs.width) + ' )' )
79+
80+ if self .height == 1 and rhs.width == 1 :
81+ # Dot product
82+ var mat = Self(1 , 1 )
83+ mat.data[0 ] = self .ele_mul(rhs.T()).sum()
84+ return mat^
85+
86+ if self .height * self .width * rhs.width <= 4096 :
87+ # matmul naive
88+ var mat = Self(self .height, rhs.width)
89+ for i in range (self .size):
90+ var rhsr = i % self .width
91+ for j in range (rhsr * rhs.width, rhsr * rhs.width + rhs.width):
92+ if rhsr != 0 :
93+ mat.data[(Int(i / self .width) * mat.width) + (j % rhs.width)] += self .data[i] * rhs.data[j]
94+ else :
95+ mat.data[(Int(i / self .width) * mat.width) + (j % rhs.width)] = self .data[i] * rhs.data[j]
96+ return mat^
7597 var A = matmul.Matrix[DType.float32](self .data, (self .height, self .width))
7698 var B = matmul.Matrix[DType.float32](rhs.data, (rhs.height, rhs.width))
7799 var C = matmul.Matrix[DType.float32]((self .height, rhs.width))
@@ -91,7 +113,6 @@ struct Matrix(Copyable, Movable, ImplicitlyCopyable, Sized):
91113 return mat^
92114
93115 @ staticmethod
94- @always_inline
95116 fn random (height : Int, width : Int, order : String = ' c' ) -> Matrix:
96117 random.seed()
97118 var mat = Matrix(height, width, order = order)
0 commit comments