is-a-dev.com / autoapi

Automatic API generation from an SQL database, complete with HTTP API endpoint scaffolding code and precondition checking.

checksumGenerator.go (2.0 KB)

package lib

import (
	"bytes"
	"fmt"
	"go/format"
	"io"
	"os"
	"text/template"

	"golang.org/x/tools/imports"
)

// checksumGenerator emits db/checksum.go via its Generate method.
// It holds no state; the zero value is ready to use.
type checksumGenerator struct{}

// Generate writes db/checksum.go: a generated file in package db that embeds
// the checksum of tables (as computed by codeChecksum, hex-encoded) and
// defines Checksum, ValidateChecksum, MustValidateChecksum and
// ErrBadDatabaseChecksum.
//
// The "db" directory is created relative to the current working directory if
// it does not already exist. The output is rendered, gofmt-formatted and run
// through imports.Process entirely in memory before the file is created, so a
// rendering or formatting failure does not truncate an existing checksum.go.
//
// It returns the first error encountered while creating the directory,
// rendering the template, formatting, processing imports, or writing and
// closing the file.
func (g *checksumGenerator) Generate(tables map[string]tableInfo) error {
	const (
		outDir  = "db"
		outFile = "db/checksum.go"
	)

	if err := os.Mkdir(outDir, 0755); err != nil && !os.IsExist(err) {
		return err
	}

	// The template deliberately has no import block; imports.Process below is
	// relied on to add the imports the generated code uses (sql, errors, fmt,
	// and the package providing lib.DatabaseChecksum).
	// NOTE(review): confirm goimports can resolve `lib` in the target module.
	tmpl := template.Must(template.New("dbchecksum").Parse(`//WARNING.
//THIS HAS BEEN GENERATED AUTOMATICALLY BY AUTOAPI.
//IF THERE WAS A WARRANTY, MODIFYING THIS WOULD VOID IT.

package db

//Checksum is an autoapi-generated checksum of the state of the database, at time of generation.
func Checksum() string{
    return "{{.}}"
}

//ValidateChecksum compares the checksum generated by Autoapi to the current state of the db,
//returning an error if they don't match.
func ValidateChecksum(db *sql.DB, dbName string) error {
     b, err := lib.DatabaseChecksum(db, dbName)	
     if err != nil {
         return err
     }

	if fmt.Sprintf("%x", b) != Checksum() {
		fmt.Println(fmt.Sprintf("%x", b))
		fmt.Println(Checksum())
		return ErrBadDatabaseChecksum
	}
     return nil
}

//MustValidateChecksum compares the checksum against the database, and panics if they don't match.
//Useful when you absolutely don't want to run the software against a non-matching version of the db.
func MustValidateChecksum(db *sql.DB, dbName string) {
    if err := ValidateChecksum(db, dbName); err != nil{
       panic(err)
    }
}

var ErrBadDatabaseChecksum = errors.New("The code doesn't match the database's structure.")

`))

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, fmt.Sprintf("%x", codeChecksum(tables))); err != nil {
		return err
	}

	src, err := format.Source(rendered.Bytes())
	if err != nil {
		return fmt.Errorf("formatting generated %s: %w", outFile, err)
	}

	// The filename gives goimports the package context used to resolve imports.
	src, err = imports.Process(outFile, src, nil)
	if err != nil {
		return fmt.Errorf("processing imports for %s: %w", outFile, err)
	}

	// Create (and truncate) the output only once its full content is known.
	f, err := os.Create(outFile)
	if err != nil {
		return err
	}
	if _, err := io.Copy(f, bytes.NewReader(src)); err != nil {
		f.Close() // the write error is the one worth reporting
		return err
	}
	// Close can surface a deferred write failure, so its error is returned.
	return f.Close()
}