diff --git a/cdb_test.go b/cdb_test.go index fe19b1b..df37bfb 100644 --- a/cdb_test.go +++ b/cdb_test.go @@ -1,12 +1,11 @@ package cdb import ( - "bufio" "bytes" + "errors" "fmt" "io" "io/ioutil" - "os" "testing" ) @@ -23,27 +22,55 @@ var records = []rec{ var data []byte // set by init() -func TestCdb(t *testing.T) { - tmp, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("Failed to create temp file: %s", err) +func init() { + b := bytes.NewBuffer(nil) + for _, rec := range records { + key := rec.key + for _, value := range rec.values { + b.WriteString(fmt.Sprintf("+%d,%d:%s->%s\n", len(key), len(value), key, value)) + } } + b.WriteByte('\n') + data = b.Bytes() +} - defer os.Remove(tmp.Name()) +func TestCdbMake(t *testing.T) { + make_from_data(data, t) +} - // Test Make - err = Make(tmp, bytes.NewBuffer(data)) +func TestCdbDump(t *testing.T) { + f, buf := make_from_data(data, t), bytes.NewBuffer(nil) + err := Dump(buf, f.bytesReader()) if err != nil { - t.Fatalf("Make failed: %s", err) + t.Fatalf("Dump failed: %s", err) } + if !bytes.Equal(buf.Bytes(), data) { + t.Fatalf("Dump round-trip failed") + } +} - // Test reading records - c, err := Open(tmp.Name()) - if err != nil { - t.Fatalf("Error opening %s: %s", tmp.Name(), err) +func TestCdbGet(t *testing.T) { + f := make_from_data(data, t) + c, buf := New(f.bytesReader()), bytes.NewBuffer(nil) + for _, rec := range records { + for skip, val := range rec.values { + buf.Reset() + if _, err := Get(buf, c, []byte(rec.key), skip); err != nil { + t.Fatalf("cdb.Get failed: %s", err) + } + if !bytes.Equal(buf.Bytes(), []byte(val)) { + t.Fatalf("cdb.Get failed: expected %q, got %q", val, buf.Bytes()) + } + t.Logf("%q => %d: %q", rec.key, skip, val) + } } +} + +func TestCdbDataAndFind(t *testing.T) { + f := make_from_data(data, t) + c := New(f.bytesReader()) - _, err = c.Data([]byte("does not exist")) + _, err := c.Data([]byte("does not exist")) if err != io.EOF { t.Fatalf("non-existent key should return io.EOF") } @@ -60,6 +87,7 @@ func TestCdb(t *testing.T) { if !bytes.Equal(v, []byte(values[0])) { t.Fatal("Incorrect value returned") } + t.Logf("%q => %q", key, v) c.FindStart() for _, value := range values { @@ -68,8 +96,7 @@ func TestCdb(t *testing.T) { t.Fatalf("Record read failed: %s", err) } - data := make([]byte, sr.Size()) - _, err = sr.Read(data) + data, err := ioutil.ReadAll(sr) if err != nil { t.Fatalf("Record read failed: %s", err) } @@ -77,6 +104,8 @@ func TestCdb(t *testing.T) { if !bytes.Equal(data, []byte(value)) { t.Fatal("value mismatch") } + + t.Logf(" %q => %q", key, data) } // Read all values, so should get EOF _, err = c.FindNext(key) @@ -84,43 +113,11 @@ func TestCdb(t *testing.T) { t.Fatalf("Expected EOF, got %s", err) } } - - // Test Dump - if _, err = tmp.Seek(0, 0); err != nil { - t.Fatal(err) - } - - buf := bytes.NewBuffer(nil) - err = Dump(buf, tmp) - if err != nil { - t.Fatalf("Dump failed: %s", err) - } - - if !bytes.Equal(buf.Bytes(), data) { - t.Fatalf("Dump round-trip failed") - } } func TestEmptyFile(t *testing.T) { - tmp, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("Failed to create temp file: %s", err) - } - - defer os.Remove(tmp.Name()) - - // Test Make - err = Make(tmp, bytes.NewBuffer([]byte("\n\n"))) - if err != nil { - t.Fatalf("Make failed: %s", err) - } - - // Check that all tables are length 0 - if _, err = tmp.Seek(0, 0); err != nil { - t.Fatal(err) - } - rb := bufio.NewReader(tmp) - readNum := makeNumReader(rb) + f := make_from_data([]byte("\n\n"), t) + readNum := makeNumReader(f.bytesReader()) for i := 0; i < 256; i++ { _ = readNum() // table pointer tableLen := readNum() @@ -129,26 +126,62 @@ func TestEmptyFile(t *testing.T) { } } - // Test reading records - c, err := Open(tmp.Name()) - if err != nil { - t.Fatalf("Error opening %s: %s", tmp.Name(), err) - } - - _, err = c.Data([]byte("does not exist")) + c := New(f.bytesReader()) + _, err := c.Data([]byte("does not exist")) if err != io.EOF { t.Fatalf("non-existent key should return io.EOF") } } -func init() { - b := bytes.NewBuffer(nil) - for _, rec := range records { - key := rec.key - for _, value := range rec.values { - b.WriteString(fmt.Sprintf("+%d,%d:%s->%s\n", len(key), len(value), key, value)) - } +func make_from_data(d []byte, t *testing.T) *memFile { + writer := &memFile{} + if err := Make(writer, bytes.NewBuffer(d)); err != nil { + t.Fatalf("Make failed: %s", err) } - b.WriteByte('\n') - data = b.Bytes() + return writer +} + +// 'memFile' is a naive implementation of a io.WriteSeeker (backed by +// a []byte-buffer) to be used in the tests without creating any real files +// +// NOTE: it might be usefull elsewhere, but the .Seek() method might +// move memFile.i (the write-position) behind the len(buf). +// memFile.growIfNeeded takes care of growing the buffer +type memFile struct { + buf []byte + i int64 +} + +func (f *memFile) Write(data []byte) (int, error) { + f.growIfNeeded(int64(len(data))) + n := copy(f.buf[f.i:], data) + f.i += int64(n) + return n, nil } + +func (f *memFile) Seek(offset int64, whence int) (abs int64, _ error) { + switch whence { + default: + return 0, errors.New("bufWriteSeeker.Seek: invalid whence") + case 0: + abs = offset + case 1: + abs = f.i + offset + case 2: + abs = int64(len(f.buf)) + offset + } + if abs < 0 { + return 0, errors.New("bufWriteSeeker.Seek: negative position") + } + f.i = abs + return +} + +// grows the buffer to hold (mw.i + n) bytes +func (f *memFile) growIfNeeded(n int64) { + if needed := ((f.i + n) - int64(len(f.buf))); needed > 0 { + f.buf = append(f.buf, make([]byte, needed)...) + } +} + +func (f *memFile) bytesReader() *bytes.Reader { return bytes.NewReader(f.buf) } diff --git a/cdbget/cdbget.go b/cdbget/cdbget.go new file mode 100644 index 0000000..6d37d1a --- /dev/null +++ b/cdbget/cdbget.go @@ -0,0 +1,50 @@ +package main + +import ( + "io" + "os" + "strconv" + + "github.com/jbarham/go-cdb" +) + +const usage = "usage: cdbget key [skip]" + +func main() { + + var ( + key []byte + skip int + err error + ) + + if len(os.Args) < 2 { + exitWithMsg(1, usage) + } + if len(os.Args) > 1 { + key = []byte(os.Args[1]) + } + if len(os.Args) > 2 { + if skip, err = strconv.Atoi(os.Args[2]); err != nil { + exitWithMsg(2, "error:", "parsing 'skip'-error", err.Error()) + } + if skip < 0 { + exitWithMsg(2, "error:", "skip parameter is invalid (negativ)") + } + } + + c := cdb.New(os.Stdin) + if _, err := cdb.Get(os.Stdout, c, key, skip); err != nil { + exitWithMsg(3, err.Error()) + } +} + +func exitWithMsg(c int, msg ...string) { + if len(msg) > 0 { + for _, m := range msg { + io.WriteString(os.Stderr, m) + } + io.WriteString(os.Stderr, "\n") + } + os.Exit(c) +} diff --git a/dump.go b/dump.go index 6772cd2..c8e9b95 100644 --- a/dump.go +++ b/dump.go @@ -44,12 +44,11 @@ func Dump(w io.Writer, r io.Reader) (err error) { } func makeNumReader(r io.Reader) func() uint32 { - buf := make([]byte, 4) - return func() uint32 { - if _, err := r.Read(buf); err != nil { + return func() (n uint32) { + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { panic(err) } - return binary.LittleEndian.Uint32(buf) + return } } diff --git a/get.go b/get.go new file mode 100644 index 0000000..6c4be23 --- /dev/null +++ b/get.go @@ -0,0 +1,19 @@ +package cdb + +import "io" + +// Get seeks 'key' in the cdb, similar to Find(), but it resembles the +// interface describe by the cdbget programm +func Get(w io.Writer, c *Cdb, key []byte, skip int) (n int64, err error) { + var record io.Reader + c.FindStart() + for ; skip >= 0; skip-- { + record, err = c.FindNext(key) + if err == io.EOF { + return 0, nil + } else if err != nil { + return 0, err + } + } + return io.Copy(w, record) +}