|
| 1 | +const std = @import("std"); |
| 2 | +const print = std.debug.print; |
| 3 | + |
| 4 | +const Mode = enum { c, d }; |
| 5 | + |
| 6 | +pub fn main() !void { |
| 7 | + var args = std.process.args(); |
| 8 | + _ = args.skip(); const arg = args.next(); |
| 9 | + if (arg == null) { print("Error: No args passed. Pass -c for compression , -d for decompression\n", .{}); std.os.exit(1); } |
| 10 | + |
| 11 | + const mode = if (std.mem.eql(u8, arg.?, "-d")) Mode.d |
| 12 | + else if (std.mem.eql(u8, arg.?, "-c")) Mode.c |
| 13 | + else null; |
| 14 | + if (mode == null) { print("Error: Invalid arg. Pass -c for compression , -d for decompression\n", .{}); std.os.exit(2); } |
| 15 | + |
| 16 | + var bufw = std.io.bufferedWriter(std.io.getStdOut().writer()); |
| 17 | + var writer = std.io.bitWriter(.Big, bufw.writer()); |
| 18 | + var model = Model.init(); |
| 19 | + |
| 20 | + if (mode.? == .c) { |
| 21 | + const fileName = args.next(); |
| 22 | + if (fileName == null) { print("Error: To compress use -c fileName\n", .{}); std.os.exit(3); } |
| 23 | + var path_buffer: [std.fs.MAX_PATH_BYTES]u8 = undefined; |
| 24 | + const path = try std.fs.realpathZ(fileName.?, &path_buffer); |
| 25 | + const file = try std.fs.openFileAbsolute(path, .{}); |
| 26 | + defer file.close(); |
| 27 | + const size = (try file.stat()).size; |
| 28 | + var bufr = std.io.bufferedReader(file.reader()); |
| 29 | + var reader = std.io.bitReader(.Big, bufr.reader()); |
| 30 | + |
| 31 | + try writer.writeBits(size, 64); |
| 32 | + var ac = initAC(writer, Mode.c); |
| 33 | + while (true) { |
| 34 | + const bit = reader.readBitsNoEof(u1, 1) catch { break; }; |
| 35 | + try ac.encode(bit, model.p()); |
| 36 | + model.update(bit); |
| 37 | + } |
| 38 | + try ac.flush(); |
| 39 | + try bufw.flush(); |
| 40 | + } else { |
| 41 | + var bufr = std.io.bufferedReader(std.io.getStdIn().reader()); |
| 42 | + var reader = std.io.bitReader(.Big, bufr.reader()); |
| 43 | + const size = try reader.readBitsNoEof(u64, 64); |
| 44 | + var ac = initAC(reader, Mode.d); |
| 45 | + |
| 46 | + var i: u64 = 0; |
| 47 | + while (i / 8 < size) : (i += 1) { |
| 48 | + const bit = ac.decode(model.p()); |
| 49 | + try writer.writeBits(bit, 1); |
| 50 | + model.update(bit); |
| 51 | + } |
| 52 | + try bufw.flush(); |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +const Model = struct { |
| 57 | + ctx: u12, |
| 58 | + data: [1 << 12]Counter, |
| 59 | + |
| 60 | + const Self = @This(); |
| 61 | + |
| 62 | + pub fn init() Self { return Self { .ctx = 0, .data = .{Counter.init()}**(1<<12) }; } |
| 63 | + pub fn p(self: Self) u16 { return self.data[self.ctx].p(); } |
| 64 | + pub fn update(self: *Self, bit: u1) void { |
| 65 | + self.data[self.ctx].update(bit); |
| 66 | + self.ctx <<= 1; self.ctx |= bit; self.ctx &= (1 << 12) - 1; |
| 67 | + } |
| 68 | +}; |
| 69 | + |
| 70 | +const Counter = struct { |
| 71 | + // c0: u16, c1: u16, |
| 72 | + c0: u12, c1: u12, |
| 73 | + |
| 74 | + const Self = @This(); |
| 75 | + pub fn init() Self { return Self { .c0 = 0, .c1 = 0 }; } |
| 76 | + pub fn p(self: Self) u16 { |
| 77 | + const n0 = @as(u64, self.c0); |
| 78 | + const n1 = @as(u64, self.c1); |
| 79 | + return @intCast(u16, (1 << 16) * (n1 + 1) / (n1 + n0 + 2)); |
| 80 | + } |
| 81 | + pub fn update(self: *Self, bit: u1) void { |
| 82 | + // const maxCount = (1 << 16) - 1; |
| 83 | + const maxCount = (1 << 12) - 1; |
| 84 | + if (self.c0 == maxCount or self.c1 == maxCount) { |
| 85 | + self.c0 >>= 1; |
| 86 | + self.c1 >>= 1; |
| 87 | + } |
| 88 | + if (bit == 1) self.c1 += 1 else self.c0 += 1; |
| 89 | + } |
| 90 | +}; |
| 91 | + |
| 92 | +fn initAC(writer: anytype, comptime mode: Mode) ArithmeticCoder(@TypeOf(writer), mode) { return ArithmeticCoder(@TypeOf(writer), mode).init(writer); } |
| 93 | +fn ArithmeticCoder(comptime T: type, comptime mode: Mode) type { return struct { |
| 94 | + io:T, x: if (mode == Mode.d) u32 else void, |
| 95 | + revBits: if (mode == Mode.c) u64 else void, |
| 96 | + x1: u32 = 0, x2: u32 = (1 << 32) - 1, |
| 97 | + |
| 98 | + const Self = @This(); |
| 99 | + pub fn init(io: T) Self { |
| 100 | + var self = if (mode == .c) Self { .io = io, .revBits = 0, .x = {} } |
| 101 | + else if (mode == .d) Self { .io = io, .x = 0, .revBits = {} }; |
| 102 | + if (mode == .d) self.readState(); |
| 103 | + return self; |
| 104 | + } |
| 105 | + pub fn encode(self: *Self, bit: u1, p: u16) !void { return self.code(bit, p); } |
| 106 | + pub fn decode(self: *Self, p: u16) u1 { return self.code({}, p); } |
| 107 | + pub fn flush(self: *Self) !void { |
| 108 | + try self.writeBit(self.x2 >> 31); |
| 109 | + while (self.io.bit_count != 0) { |
| 110 | + self.x2 <<= 1; try self.writeBit(self.x2 >> 31); |
| 111 | + } |
| 112 | + } |
| 113 | + |
| 114 | + fn readBit(self: *Self) u1 { return self.io.readBitsNoEof(u1, 1) catch 0; } |
| 115 | + fn incParity(self: *Self) void { self.revBits += 1; } |
| 116 | + fn writeBit(self: *Self, bit: u32) !void { |
| 117 | + try self.io.writeBits(bit, 1); |
| 118 | + while (self.revBits > 0) { |
| 119 | + try self.io.writeBits(bit ^ 1, 1); |
| 120 | + self.revBits -= 1; |
| 121 | + } |
| 122 | + } |
| 123 | + fn readState(self: *Self) void { |
| 124 | + var bitsRead: usize = 0; |
| 125 | + var state = self.io.readBits(u32, 32, &bitsRead) catch 0; |
| 126 | + self.x = state << @intCast(u5, 32 - bitsRead); |
| 127 | + } |
| 128 | + |
| 129 | + fn code(self: *Self, bit_: if (mode == .d) void else u1, prob: u16) if (mode == .d) u1 else anyerror!void { |
| 130 | + const p = if (prob == 0) 1 else @as(u64, prob) << 16; |
| 131 | + const xmid = @intCast(u32, self.x1 + ((@as(u64, self.x2 - self.x1) * p) >> 32)); |
| 132 | + |
| 133 | + const bit = if (mode == .c) bit_ else @boolToInt(self.x <= xmid); |
| 134 | + if (bit == 1) self.x2 = xmid else self.x1 = xmid + 1; |
| 135 | + |
| 136 | + while ((self.x1 ^ self.x2) >> 31 == 0) { |
| 137 | + if (mode == .c) try self.writeBit(self.x1 >> 31) |
| 138 | + else self.x = (self.x << 1) | self.readBit(); |
| 139 | + self.x1 <<= 1; |
| 140 | + self.x2 = (self.x2 << 1) | 1; |
| 141 | + } |
| 142 | + |
| 143 | + while (self.x1 >= (1 << 30) and self.x2 < (3 << 30)) { |
| 144 | + if (mode == .c) self.incParity() |
| 145 | + else self.x = ((self.x << 1) ^ (2 << 30)) | self.readBit(); |
| 146 | + self.x1 = (self.x1 << 1) & ((1 << 31) - 1); |
| 147 | + self.x2 = (self.x2 << 1) | ((1 << 31) + 1); |
| 148 | + } |
| 149 | + |
| 150 | + if (mode == .d) return bit; |
| 151 | + } |
| 152 | +};} |
0 commit comments