#
#   Spam Server -- Cofunction version
#

from cofunctions import *
from socket import *
from scheduler import *

# TCP port that listener() binds and accepts client connections on.
port = 4200

class BadRequest(Exception):
  """Raised by parse_request when a client line is not a valid SPAM request."""

@cofunction
def sock_accept(cocall, sock):
  """Accept one incoming connection on the listening socket *sock*.

  Suspends via the scheduler's block_for_reading first (presumably until
  *sock* has a pending connection), then coreturns the (conn, address)
  pair produced by socket.accept().
  """
  yield cocall(block_for_reading, sock)
  accepted = sock.accept()
  coreturn(accepted)

@cofunction
def sock_readline(cocall, sock):
  """Read a single newline-terminated line from *sock* and coreturn it.

  Returns the bytes up to and including the first b"\\n".  If the peer
  closes the connection first, whatever partial data was received is
  returned instead; if nothing at all was received, the socket is handed
  to close_fd() and b"" is coreturned, which callers treat as end-of-stream.

  Bytes following the newline are deliberately left unread so that the next
  call (and the scheduler's readiness check) still sees them.
  """
  buf = b""
  while buf[-1:] != b"\n":
    yield cocall(block_for_reading, sock)
    # Peek rather than consume: a single segment may carry several
    # pipelined requests, and anything after the first newline must stay in
    # the kernel buffer for the next readline call.  Keeping it there (rather
    # than in a local buffer) also keeps block_for_reading's readiness
    # signal accurate.
    pending = sock.recv(1024, MSG_PEEK)
    if not pending:
      break  # orderly shutdown by the peer
    newline_at = pending.find(b"\n")
    want = newline_at + 1 if newline_at >= 0 else len(pending)
    data = sock.recv(want)
    if not data:
      break
    # If recv returned fewer bytes than asked, the loop simply peeks again.
    buf += data
  if not buf:
    close_fd(sock)
  coreturn(buf)

@cofunction
def sock_write(cocall, sock, data):
  """Send every byte of *data* on *sock*.

  Before each send() the cofunction suspends via block_for_writing; since
  send() may accept only part of the buffer, the unsent tail is retried
  until nothing remains.
  """
  remaining = data
  while len(remaining) > 0:
    yield cocall(block_for_writing, sock)
    sent = sock.send(remaining)
    remaining = remaining[sent:]

@cofunction
def listener(cocall):
  """Accept clients on the module-level *port* and start a handler for each.

  Never returns: each accepted socket is wrapped in a new handler
  cofunction via costart() and passed to the scheduler with schedule().
  """
  server = socket(AF_INET, SOCK_STREAM)
  # Allow quick restarts without waiting for old connections in TIME_WAIT.
  server.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  # NOTE: host "" means INADDR_ANY, i.e. all IPv4 interfaces, not only
  # localhost as the message below suggests.
  server.bind(("", port))
  print("Server listening for connections on localhost port", port)
  server.listen(5)
  while True:
    conn, peer = yield cocall(sock_accept, server)
    print("Listener: Accepted connection from", peer)
    task = costart(handler, conn)
    schedule(task)

@cofunction
def handler(cocall, sock):
  """Serve one client connection until sock_readline reports end-of-stream.

  For each request line: a valid "SPAM <n>" gets a "100 SPAM FOLLOWS"
  header followed by n spam lines; anything parse_request rejects gets a
  single "400 WE ONLY SERVE SPAM" reply.
  """
  while True:
    line = yield cocall(sock_readline, sock)
    if not line:
      return  # empty read: sock_readline has already closed the socket
    # Only parsing can raise BadRequest, so keep it alone in the try.
    try:
      count = parse_request(line)
    except BadRequest:
      yield cocall(sock_write, sock, b"400 WE ONLY SERVE SPAM\n")
      continue
    yield cocall(sock_write, sock, b"100 SPAM FOLLOWS\n")
    for _ in range(count):
      yield cocall(sock_write, sock, b"spam glorious spam\n")

def parse_request(line):
  """Parse a b"SPAM <count>" request line and return the count.

  *line* is the raw bytes line from the client.  bytes.split() with no
  argument discards surrounding whitespace, including the trailing newline.

  Returns:
    The count as an int, always >= 1.

  Raises:
    BadRequest: if the line is not exactly two tokens, the first token is
      not b"SPAM", the second token is not accepted by int(), or the count
      is less than 1.
  """
  tokens = line.split()
  if len(tokens) != 2 or tokens[0] != b"SPAM":
    raise BadRequest
  try:
    n = int(tokens[1])
  except ValueError:
    # The ValueError is an internal detail of parsing; suppress it so the
    # protocol error isn't reported as "during handling of ValueError".
    raise BadRequest from None
  if n < 1:
    raise BadRequest
  return n

# Script entry: register the listener cofunction with the scheduler, then
# enter the scheduler's run loop (run2 comes from the scheduler module;
# presumably it runs until the program is interrupted -- confirm there).
coschedule(listener)
run2()