package main

import (
	"bytes"
	"crypto/hmac"
	crand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io/ioutil"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"
)

const (
	// Questions is the number of partial hash collisions (questions) issued
	// per challenge and required in every answer.
	Questions = 16
	// Expiration is added to the current time to produce the `expires`
	// timestamp handed out with each question set.
	Expiration = 10 * time.Second
	// QuestionLength is the number of random characters appended to
	// CollisionPrefix to form each question.
	QuestionLength = 8
	// MaxLenBody caps, in bytes, how much of the POST body the answer
	// endpoint will read.
	MaxLenBody = 16384
	// APIPrefix is prepended to every endpoint path.
	APIPrefix = "/v1/"
	// CollisionPrefix is the fixed leading part of every question.
	CollisionPrefix = "NIXUCOIN"
	// DigestPrefix is the prefix that the lowercase hex SHA-256 digest of each
	// proposal must start with.
	DigestPrefix = "aaaaaa"
)

// ResponseSerializer is implemented by the API response types in this file.
// Serialize renders the response as the bytes sent to the client; it receives
// the server context so implementations can use server state (for example to
// sign the payload). Name returns a short label used in log lines and panic
// messages.
type ResponseSerializer interface {
	Serialize(*ServerCtx) []byte
	Name() string
}

// ErrorResponse carries the success flag and human-readable message shared by
// every response. It is sent on its own for failures, and it is embedded in
// QuestionResponse and AnswerResponse so that its fields appear at the top
// level of their JSON.
type ErrorResponse struct {
	Ok      bool   `json:"ok"`
	Message string `json:"message"`
}

// QuestionResponse delivers a new set of questions to the client.
// Expires is the deadline as decimal Unix seconds. Signature is an
// HMAC-SHA256, keyed with the server secret, over Questions and Expires. It is
// filled in by Serialize and JSON-encoded as base64, since it is a []byte. The
// client must echo Questions, Expires and Signature back in its AnswerRequest.
type QuestionResponse struct {
	ErrorResponse
	Questions [Questions]string `json:"questions"`
	Expires   string            `json:"expires"`
	Signature []byte            `json:"signature"`
}

// AnswerRequest is the JSON body accepted by the answer endpoint.
// Questions, Signature and Expires must be echoed unchanged from a
// QuestionResponse; the signature is re-verified against them. Proposals[i]
// is the client's candidate input for Questions[i]. Signature and each
// proposal are base64 in JSON because they are []byte.
type AnswerRequest struct {
	Questions [Questions]string `json:"questions"`
	Signature []byte            `json:"signature"`
	Expires   string            `json:"expires"`
	Proposals [Questions][]byte `json:"proposals"`
}

// AnswerResponse reports whether a submitted answer was accepted. On success
// the embedded Message carries the reward; otherwise Ok is false.
type AnswerResponse struct {
	ErrorResponse
}

// doHMAC returns HMAC-SHA256, keyed with s.secret, over the JSON encoding of
// qs immediately followed by the expiration string. Both signing
// (QuestionResponse.Serialize) and verification (IsSignatureValid) go through
// here, so they always agree on the message layout.
func (s *ServerCtx) doHMAC(qs [Questions]string, expiration string) []byte {
	// The exact MarshalIndent output is part of the signed message; do not
	// switch encodings without invalidating outstanding signatures.
	encoded, err := json.MarshalIndent(qs, "", "")
	if err != nil {
		panic("unable to marshal questions")
	}
	h := hmac.New(sha256.New, s.secret)
	// One Write of the concatenation equals two sequential Writes.
	h.Write(append(encoded, expiration...))
	return h.Sum(nil)
}

// Serialize signs r's questions and expiry with the server secret, stores the
// MAC in r.Signature, and returns the JSON encoding of r.
func (r *QuestionResponse) Serialize(s *ServerCtx) []byte {
	sig := s.doHMAC(r.Questions, r.Expires)
	r.Signature = sig
	return GenericSerialize(r)
}

// Name identifies this response type in log lines and panic messages.
func (r *QuestionResponse) Name() string {
	const name = "question"
	return name
}

// IsSignatureValid recomputes the HMAC over the submitted questions and expiry
// and compares it with the client-supplied signature using hmac.Equal, which
// does not leak timing information.
func (r *AnswerRequest) IsSignatureValid(s *ServerCtx) bool {
	expected := s.doHMAC(r.Questions, r.Expires)
	return hmac.Equal(r.Signature, expected)
}

// GenericSerialize JSON-encodes any response. Response types contain only
// JSON-safe fields, so an encoding failure is a programming error and panics,
// naming the offending response type.
func GenericSerialize(r ResponseSerializer) []byte {
	if out, err := json.Marshal(r); err == nil {
		return out
	}
	panic("unable to marshal " + r.Name() + " response")
}

// Serialize returns the JSON encoding of r; the server context is not needed.
func (r *AnswerResponse) Serialize(_ *ServerCtx) []byte {
	out := GenericSerialize(r)
	return out
}

// Name identifies this response type in log lines and panic messages.
func (r *AnswerResponse) Name() string {
	const name = "answer"
	return name
}

// Serialize returns the JSON encoding of r; the server context is not needed.
func (r *ErrorResponse) Serialize(_ *ServerCtx) []byte {
	out := GenericSerialize(r)
	return out
}

// Name identifies this response type in log lines and panic messages.
func (r *ErrorResponse) Name() string {
	const name = "error"
	return name
}

// ServerCtx holds the state shared by the HTTP handlers, which are its methods.
// flag is the reward text returned to clients that submit a valid proof.
// secret is the HMAC key used to sign and verify question sets.
// NewServerCtx sets only flag; the caller is responsible for assigning secret.
// While secret is nil, signatures use an empty key.
type ServerCtx struct {
	flag   string
	secret []byte
}

// Proof is one complete answer submission: every question paired with the
// client's proposal, built by NewProof and checked by Validate.
type Proof struct {
	questions [Questions]Question
}

// Question pairs a required input prefix (wantedPrefix) with a client's
// proposed input (proposal). The proposal is meant to be accepted when it is
// non-empty, begins with wantedPrefix, and its lowercase hex SHA-256 digest
// begins with DigestPrefix. The check itself lives in Proof.Validate.
type Question struct {
	wantedPrefix string
	proposal     []byte
}

// NewQuestion returns a Question demanding wantedPrefix, with no proposal yet.
func NewQuestion(wantedPrefix string) Question {
	var q Question
	q.wantedPrefix = wantedPrefix
	return q
}

// NewProof pairs each prefix ps[i] with the proposal prs[i] and returns the
// resulting Proof.
func NewProof(ps [Questions]string, prs [Questions][]byte) Proof {
	var proof Proof
	for i, prefix := range ps {
		q := NewQuestion(prefix)
		q.Propose(prs[i])
		proof.questions[i] = q
	}
	return proof
}

// Propose records p as q's proposed answer, replacing any earlier one. The
// slice is stored as-is, not copied.
func (q *Question) Propose(p []byte) {
	q.proposal = p
}

// Validate reports whether every Question of p is answered correctly (see
// isAnswered). Each question is checked in its own goroutine.
func (p *Proof) Validate() bool {
	qc := len(p.questions)
	// Every goroutine sends exactly one result into a buffer of size qc, so no
	// goroutine can block (and leak) even when we return on the first failure.
	rets := make(chan bool, qc)
	for i := range p.questions {
		// Pass the question by value. Capturing the range variable would, before
		// Go 1.22, make all goroutines check whichever question the shared
		// variable holds when they run (typically the last one).
		go func(q Question) {
			// Throttle a bit for DoS protection.
			time.Sleep(10 * time.Millisecond)
			rets <- q.isAnswered()
		}(p.questions[i])
	}

	for i := 0; i < qc; i++ {
		if !<-rets {
			return false
		}
	}
	return true
}

// isAnswered reports whether q's proposal is non-empty, starts with
// q.wantedPrefix, and has a lowercase hex SHA-256 digest starting with
// DigestPrefix. Each failed check returns immediately, so callers receive
// exactly one verdict.
func (q *Question) isAnswered() bool {
	// Be careful with empty values: an empty prefix would match anything.
	if len(q.wantedPrefix) == 0 || len(q.proposal) == 0 {
		return false
	}
	// The proposal must start with the prefix we issued.
	if !bytes.HasPrefix(q.proposal, []byte(q.wantedPrefix)) {
		return false
	}
	d := sha256.Sum256(q.proposal)
	return strings.HasPrefix(hex.EncodeToString(d[:]), DigestPrefix)
}

// replyOk serializes responder, sends it with HTTP 200, and logs the status,
// the responder's Name and the body truncated to 60 characters.
func (s *ServerCtx) replyOk(w http.ResponseWriter, responder ResponseSerializer) {
	const status = http.StatusOK
	body := responder.Serialize(s)
	doResponseHeaders(w, status)
	w.Write(body)
	log.Printf("|%03d|%s|%-.60s", status, responder.Name(), body)
}

// replyFail sends msg as an ErrorResponse (ok=false) with HTTP 400 and logs
// it. responder supplies only the Name for the log line; its own Serialize is
// never called.
func (s *ServerCtx) replyFail(w http.ResponseWriter, msg string, responder ResponseSerializer) {
	const status = http.StatusBadRequest
	failure := ErrorResponse{Ok: false, Message: msg}
	body := failure.Serialize(nil)
	doResponseHeaders(w, status)
	w.Write(body)
	log.Printf("|%03d|%s|%-.60s", status, responder.Name(), body)
}

// doResponseHeaders marks the response as JSON and writes the status code.
// Headers must be set before WriteHeader, so the order here matters.
func doResponseHeaders(w http.ResponseWriter, code int) {
	hdr := w.Header()
	hdr.Set("Content-type", "application/json")
	w.WriteHeader(code)
}

// cryptoSource is a math/rand Source whose every value comes from crypto/rand,
// so a *rand.Rand built on it is unpredictable. It is stateless.
type cryptoSource struct{}

// Int63 returns a uniformly distributed non-negative int64 read from
// crypto/rand. It panics if the OS random source fails.
func (cryptoSource) Int63() int64 {
	var buf [8]byte
	if _, err := crand.Read(buf[:]); err != nil {
		panic("unable to read from crypto/rand")
	}
	var v uint64
	for _, b := range buf {
		v = v<<8 | uint64(b)
	}
	// Drop one bit so the result fits a non-negative int64.
	return int64(v >> 1)
}

// Seed satisfies rand.Source; crypto/rand cannot be seeded, so it is a no-op.
func (cryptoSource) Seed(int64) {}

// generateQuestions returns a fresh set of Questions challenge strings. Each
// is CollisionPrefix followed by QuestionLength random lowercase hex characters.
func (s *ServerCtx) generateQuestions() [Questions]string {
	// Use crypto/rand, not a clock-seeded math/rand source. A seed such as
	// time.Now().Unix()/100 repeats for 100 seconds and is guessable, which
	// hands every client the same questions and lets them precompute answers.
	r := rand.New(cryptoSource{})
	// Full 16-character hex alphabet (it previously lacked 'e').
	const alphabet = "0123456789abcdef"
	var qs [Questions]string
	for i := range qs {
		suffix := make([]byte, QuestionLength)
		for j := range suffix {
			suffix[j] = alphabet[r.Intn(len(alphabet))]
		}
		qs[i] = CollisionPrefix + string(suffix)
	}
	return qs
}

// question is the `question' endpoint. It issues a fresh question set and an
// expiry Expiration from now (decimal Unix seconds). replyOk, via
// QuestionResponse.Serialize, signs both before sending.
func (s *ServerCtx) question(w http.ResponseWriter, r *http.Request) {
	qs := s.generateQuestions()
	deadline := time.Now().Add(Expiration).Unix()
	resp := &QuestionResponse{
		ErrorResponse: ErrorResponse{
			Ok:      true,
			Message: "answer these and GET TO THE ROCKET!!!",
		},
		Questions: qs,
		Expires:   strconv.FormatInt(deadline, 10),
	}
	s.replyOk(w, resp)
}

// answer is the `answer' endpoint. It proceeds in four steps:
//  1. Decode the JSON AnswerRequest, reading at most MaxLenBody bytes.
//  2. Verify that the questions and expiry were signed by this server.
//  3. Reject the request if that signed expiry has passed.
//  4. Validate every proposal.
//
// Malformed, forged or expired requests get HTTP 400 via replyFail. An
// incorrect proof gets HTTP 200 with ok=false. A correct proof gets the
// reward flag.
func (s *ServerCtx) answer(w http.ResponseWriter, r *http.Request) {
	fail := func(msg string) {
		s.replyFail(w, msg, &AnswerResponse{})
	}

	body := http.MaxBytesReader(w, r.Body, MaxLenBody)
	req := AnswerRequest{}
	if err := json.NewDecoder(body).Decode(&req); err != nil {
		fail("invalid answer body")
		return
	}

	// No count check is needed. Questions and Proposals are fixed-size arrays:
	// extra JSON elements are dropped and missing ones decode as zero values,
	// which then fail signature or proof validation.

	if !req.IsSignatureValid(s) {
		fail("invalid signature")
		return
	}

	// Expires is covered by the HMAC, so after the signature check it can be
	// trusted. Without this check a signed question set never expires and can
	// be solved offline at leisure.
	expires, err := strconv.ParseInt(req.Expires, 10, 64)
	if err != nil {
		fail("invalid expiration")
		return
	}
	if time.Now().Unix() > expires {
		fail("questions expired")
		return
	}

	// Validate the answer set.
	p := NewProof(req.Questions, req.Proposals)
	if !p.Validate() {
		s.replyOk(w, &AnswerResponse{
			ErrorResponse: ErrorResponse{
				Ok:      false,
				Message: "no ICO for you!",
			}})
		return
	}

	s.replyOk(w, &AnswerResponse{
		ErrorResponse: ErrorResponse{
			Ok:      true,
			Message: s.flag}})
}

// NewServerCtx builds a ServerCtx whose flag is the whitespace-trimmed content
// of the file fn. It terminates the process via log.Fatal if the file cannot
// be read. secret is left unset; callers must assign it.
func NewServerCtx(fn string) *ServerCtx {
	raw, err := ioutil.ReadFile(fn)
	if err != nil {
		log.Fatal("Unable to get reward: ", err)
	}
	sc := new(ServerCtx)
	sc.flag = strings.TrimSpace(string(raw))
	return sc
}

// getRouter builds the gorilla/mux router. Each endpoint is mounted under
// APIPrefix and restricted to its HTTP method. POST routes additionally match
// only requests whose Content-Type matches "application/json".
func (s *ServerCtx) getRouter() *mux.Router {
	type endpoint struct {
		uri     string
		handler http.HandlerFunc
		method  string
	}

	router := mux.NewRouter()
	for _, ep := range []endpoint{
		{uri: "question", handler: s.question, method: "GET"},
		{uri: "answer", handler: s.answer, method: "POST"},
	} {
		path := APIPrefix + ep.uri
		log.Printf("Adding endpoint `%s' (%s)", path, ep.method)
		route := router.HandleFunc(path, ep.handler).Methods(ep.method)
		if ep.method == "POST" {
			route.HeadersRegexp("Content-Type", "application/json")
		}
	}
	return router
}

// main reads the HMAC secret from VERY_SECRET (it must be at least 10 bytes)
// and the reward from reward.txt. It then serves the API on ADDR, defaulting
// to ":5000".
func main() {
	secret := os.Getenv("VERY_SECRET")
	if len(secret) < 10 {
		log.Fatal("Give a decent secret.")
	}
	s := NewServerCtx("reward.txt")
	// NewServerCtx only loads the reward, so the key must be wired in here.
	// Otherwise every HMAC is computed with an empty, publicly known key and
	// clients can forge question signatures.
	s.secret = []byte(secret)
	r := s.getRouter()

	addr := ":5000"
	if oa := os.Getenv("ADDR"); len(oa) > 0 {
		addr = oa
	}

	// Use explicit timeouts: http.ListenAndServe has none, so slow or idle
	// clients could hold connections open indefinitely.
	srv := &http.Server{
		Addr:              addr,
		Handler:           r,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	log.Print("Starting listener at ", addr)
	log.Fatal(srv.ListenAndServe())
}
