/ac/codec.go
// Package ac implements a 31-bit integer arithmetic coder, adapted from
// Eric Bodden's "Arithmetic Coding Revealed":
// http://www.sable.mcgill.ca/publications/techreports/2007-5/bodden-07-arithmetic-TR.pdf
package ac

import "io"

// Codec is an arithmetic coder driven by an external cumulative-frequency
// model. A symbol is described by the half-open range [lowCount, highCount)
// of cumulative counts out of total.
//
// Encoding: call SetFileOutput, then Encode once per symbol, then
// EncodeFinish to terminate the stream.
//
// Decoding: call SetFileInput and DecodeStart, then for each symbol call
// DecodeTarget to obtain a cumulative count, find the symbol whose range
// contains it, and pass that range to Decode.
type Codec interface {
	// SetFileOutput sets the writer that encoded bytes are written to.
	SetFileOutput(f io.Writer)
	// Encode codes one symbol occupying [lowCount, highCount) of total.
	Encode(lowCount uint32, highCount uint32, total uint32)
	// EncodeFinish writes the bits that terminate the stream and flushes
	// the final partial byte.
	EncodeFinish()

	// SetFileInput sets the reader that encoded bytes are read from.
	SetFileInput(f io.Reader)
	// DecodeStart primes the decoder; call it once before the first
	// DecodeTarget.
	DecodeStart()
	// DecodeTarget returns the cumulative count identifying the next
	// symbol, given the same total the encoder used for it.
	DecodeTarget(total uint32) uint32
	// Decode consumes the symbol occupying [lowCount, highCount), using the
	// total passed to the immediately preceding DecodeTarget call.
	Decode(lowCount uint32, highCount uint32)
}

// Quarter points of the coder's number space. Interval bounds live in the
// 31 least significant bits of a uint32, so the space is [0, 2^31).
const (
	g_FirstQuarter = 0x20000000 // 1/4 of 2^31
	g_ThirdQuarter = 0x60000000 // 3/4 of 2^31
	g_Half         = 0x40000000 // 1/2 of 2^31
)

// codecState is a 31-bit integer arithmetic coder implementing Codec.
// Use a value for one direction only: the bit buffer and the interval state
// are shared between encoding and decoding.
type codecState struct {
	// I/O endpoints: decoding reads from readFile, encoding writes to
	// writeFile.
	readFile  io.Reader
	writeFile io.Writer

	// Bit I/O buffer: bitCount bits are pending in bitBuffer.
	bitCount  uint8
	bitBuffer byte

	// Current interval [low, high] (both bounds inclusive), the width of one
	// cumulative count within it (step), and the number of E3 scalings whose
	// output bits are still pending (scale). The decoder updates scale but
	// never reads it.
	low, high, step, scale uint32

	// Decoder only: the 31-bit window of the encoded value, in the same
	// coordinates as low and high.
	buffer uint32
}

// NewCodec returns a codec whose interval spans the full 31-bit range. Use
// it either for SetFileOutput/Encode/EncodeFinish or for
// SetFileInput/DecodeStart/DecodeTarget/Decode, not both.
func NewCodec() *codecState {
	return &codecState{
		// Only 31 bits are used so that 2*high+1 in the scaling steps, and
		// step*total, always fit in a uint32.
		high: 0x7FFFFFFF,
	}
}

// SetFileInput sets the reader the decoder pulls encoded bytes from.
func (c *codecState) SetFileInput(f io.Reader) {
	c.readFile = f
}

// SetFileOutput sets the writer the encoder emits bytes to. Bytes are
// written one per Write call, so a buffered writer (flushed after
// EncodeFinish) may be worthwhile for files.
func (c *codecState) SetFileOutput(f io.Writer) {
	c.writeFile = f
}

// setBit appends one bit (which must be 0 or 1) to the output buffer and
// writes the buffer out once it holds a full byte. It panics if the write
// fails, since the Codec interface cannot report errors.
func (c *codecState) setBit(bit uint8) {
	c.bitBuffer = c.bitBuffer<<1 | bit
	c.bitCount++
	if c.bitCount < 8 {
		return
	}
	if _, err := c.writeFile.Write([]byte{c.bitBuffer}); err != nil {
		panic("Failed to write during setBit()")
	}
	c.bitCount = 0
}

// setBitFlush pads the pending partial byte with zero bits so that it is
// written out. It does nothing when no bits are pending.
func (c *codecState) setBitFlush() {
	for c.bitCount != 0 {
		c.setBit(0)
	}
}

// getBit returns the next input bit, most significant bit of each byte
// first. Past the end of the input it returns zeros, which stand in for the
// trailing zero bits the encoder never writes. Any other read error panics,
// as write errors do in setBit.
func (c *codecState) getBit() uint8 {
	if c.bitCount == 0 {
		var b [1]byte
		// io.ReadFull retries readers that legally return (0, nil); a plain
		// Read would mistake that for a zero byte and desynchronize.
		switch _, err := io.ReadFull(c.readFile, b[:]); err {
		case nil, io.EOF:
			c.bitBuffer = b[0] // still 0 on EOF: pad with zeros
		default:
			panic("Failed to read during getBit()")
		}
		c.bitCount = 8
	}
	bit := c.bitBuffer >> 7
	c.bitBuffer <<= 1
	c.bitCount--
	return bit
}

// narrow shrinks [low, high] to the counts [lowCount, highCount) in units
// of the current step. high is computed from the old low, so order matters.
func (c *codecState) narrow(lowCount, highCount uint32) {
	c.high = c.low + c.step*highCount - 1 // -1: high is inclusive
	c.low += c.step * lowCount
}

// emitPendingBits writes one copy of bit per pending E3 scaling and clears
// the pending count.
func (c *codecState) emitPendingBits(bit uint8) {
	for ; c.scale > 0; c.scale-- {
		c.setBit(bit)
	}
}

// Encode narrows the interval to the symbol occupying cumulative counts
// [lowCount, highCount) out of total and writes every output bit that the
// narrowing has settled.
//
// total must be nonzero and below 2^29. After scaling, the interval always
// spans more than 2^29 values, so every step is then at least one.
func (c *codecState) Encode(lowCount, highCount, total uint32) {
	c.step = (c.high - c.low + 1) / total // +1: high is inclusive
	c.narrow(lowCount, highCount)

	// E1/E2 scaling: while the interval lies in one half, its leading bit
	// is settled. Emit it, then the pending E3 bits (its complement), and
	// double the interval.
	for c.high < g_Half || c.low >= g_Half {
		if c.high < g_Half {
			c.setBit(0)
			c.low = 2 * c.low
			c.high = 2*c.high + 1
			c.emitPendingBits(1)
		} else {
			c.setBit(1)
			c.low = 2 * (c.low - g_Half)
			// Shift a 1 into high exactly as the E1 and E3 steps do;
			// omitting the +1 loses the top value on every E2 step.
			c.high = 2*(c.high-g_Half) + 1
			c.emitPendingBits(0)
		}
	}

	// E3 scaling: the interval straddles the midpoint within the middle
	// half. Expand it around the midpoint and remember that one
	// complementary bit is owed after the next settled bit.
	for g_FirstQuarter <= c.low && c.high < g_ThirdQuarter {
		c.scale++
		c.low = 2 * (c.low - g_FirstQuarter)
		c.high = 2*(c.high-g_FirstQuarter) + 1
	}
}

// EncodeFinish terminates the stream. It writes just enough bits to pin a
// value inside the final interval, then pads the last byte with zeros. The
// decoder reads zeros past the end of its input (see getBit), so trailing
// zero bits are never written explicitly.
func (c *codecState) EncodeFinish() {
	// After Encode's scaling loops only two layouts are possible.
	if c.low < g_FirstQuarter {
		// low < 1/4 < 1/2 <= high: emit the value 1/4, i.e. 0, then the
		// pending E3 bits (ones), then 1.
		c.setBit(0)
		c.emitPendingBits(1)
		c.setBit(1)
	} else {
		// low < 1/2 < 3/4 <= high: emit the value 1/2, i.e. 1; the pending
		// E3 bits would be zeros, which the decoder's padding supplies.
		c.setBit(1)
	}
	c.setBitFlush()
}

// DecodeStart fills the 31-bit window with the first input bits. Call it
// once after SetFileInput and before the first DecodeTarget.
func (c *codecState) DecodeStart() {
	for i := 0; i < 31; i++ {
		c.buffer = c.buffer<<1 | uint32(c.getBit())
	}
}

// DecodeTarget returns the cumulative count of the next symbol. For a
// stream produced by Encode with the same model it lies in [0, total). The
// caller must pass the range containing it to Decode. total has the same
// constraints as in Encode.
func (c *codecState) DecodeTarget(total uint32) uint32 {
	c.step = (c.high - c.low + 1) / total
	return (c.buffer - c.low) / c.step
}

// Decode consumes the symbol occupying [lowCount, highCount), reusing the
// step computed by the preceding DecodeTarget call. It mirrors Encode's
// interval updates exactly, shifting one input bit into buffer for every
// doubling of the interval.
func (c *codecState) Decode(lowCount, highCount uint32) {
	c.narrow(lowCount, highCount)

	// E1/E2 scaling, mirroring Encode.
	for c.high < g_Half || c.low >= g_Half {
		if c.high < g_Half {
			c.low = 2 * c.low
			c.high = 2*c.high + 1
			c.buffer = 2*c.buffer + uint32(c.getBit())
		} else {
			c.low = 2 * (c.low - g_Half)
			c.high = 2*(c.high-g_Half) + 1 // must match Encode exactly
			c.buffer = 2*(c.buffer-g_Half) + uint32(c.getBit())
		}
		c.scale = 0
	}

	// E3 scaling, mirroring Encode.
	for g_FirstQuarter <= c.low && c.high < g_ThirdQuarter {
		c.scale++
		c.low = 2 * (c.low - g_FirstQuarter)
		c.high = 2*(c.high-g_FirstQuarter) + 1
		c.buffer = 2*(c.buffer-g_FirstQuarter) + uint32(c.getBit())
	}
}