#!/usr/bin/python
"""
Usage: ale <action> [options...]

Actions:
  init      Initialize an empty ale namespace, rooted at this dir
  freeze    Update the cask metadata DB from the file system state
            After a file is checksummed, it is made read-only.
            TODO: accept arg for efficiency?
            --progress shows sha1sum progress
            --preview shows the files that haven't yet been frozen
  backup    backup any new files, using the configuration file in
            .ale/config.txt.  Potentially invokes 'send' multiple times.
  send      Write a list of blobs on stdout.  Piped to 'cask receive'.
  scp       Alternative to ale send / cask receive, using plain scp (may incur
            connection latency)
  manifest  Print a manifest that can be used as input to "cask migrate".
            (The checksums are read from the metadata)

File Management
  mv        'ale mv' is a more efficient way to do Unix 'mv' and then
            'ale freeze'
  rm        'ale rm' is a more efficient way to do Unix 'rm' and then 'ale
             freeze'
  ls        List files in the repo (frozen files only; Unix 'ls' is for
            unfrozen files)

  du        show total size of files, but faster.  Consults the cask
            metadata rather than the file system.
  space     Show total disk space used
            (algorithm: scan DB)

  log       Dump the entry log
  check     Compare files on disk with recorded checksums.

  config    Show/edit the config that lists backing casks
  casks     List casks?  Run "cask status" remotely, in parallel, with
            shell script.
            --verbose will pass xargs --verbose to see the commands.
            NOTE: This could be another entry point:

            ale casks        # status by default - rowid/counter, space, if
                             # it's online/valid
            ale casks status # explicit
            ale casks stats  # checksum stats, i.e. last verified
            ale casks init   # could even recursively init according to the
                             # backup.cfg file!

  clean-casks print out a list of commands to delete old files?
              probably log these to a text file?

  debug     Dump state for debugging purposes.

"""

import errno
import os
import subprocess
import sys
import time

import util


class RepoCorruptError(Exception):
  """Raised when invariants on the ale repo are violated.

  For example, an entry_log row with an unrecognized action character, or
  an entry_log table with no rows when one was expected.
  """


# Don't use autoincrement, just use the row ID
# https://www.sqlite.org/autoinc.html

# NOTE:
# - maximum row ID can serve as a key for a cache of the current directory
# structure


# Schema for .ale/metadata.sqlite3.  Executed once by 'ale init' via
# cursor.executescript() (see AleInit).  Per the note below, executescript()
# implicitly commits; the trailing BEGIN TRANSACTION apparently reopens a
# transaction afterward -- TODO confirm that is still needed.
CREATE_SQL = """\
-- NOTE: implicit END TRANSACTION here by executescript, which is retarded!

-- Immutable log of files in the ale repo
CREATE TABLE entry_log(
  action      TEXT,    -- '+', or '-' for add/remove
  timestamp   INTEGER, -- timestamp of action
  rel_path    TEXT,    -- e.g. path within the ale namespace, e.g. 'foo/bar'
  num_bytes   INTEGER, -- file size
  mtime       INTEGER, -- file system mtime of the file, if add.  This is used
                       -- as a sanity check.
  sha1        BLOB     -- checksum of the file, if add
);

-- This whole table is for auditing purposes I guess?  Could get rid of it.
CREATE TABLE DEBUG_freeze_log(
  timestamp       INTEGER, -- timestamp of freeze, for UI only
  num_files       INTEGER, -- number of files frozen, for UI only
  total_bytes     INTEGER, -- total bytes, for UI only
  wall_time_secs  REAL,    -- how long it took to freeze
  ale_counter     INTEGER  -- latest rowid written to the metadata DB
);

-- Immutable log of successful backup jobs
CREATE TABLE backup_log(
  cask_host    TEXT,    -- id of cask
  cask_dir     TEXT,
  ale_counter  INTEGER, -- latest rowid backed up
  timestamp    INTEGER  -- time backed up, for UI only
);

BEGIN TRANSACTION;
"""

def FindAleRootDir(current_dir):
  """Walk upward from current_dir looking for a dir with an '.ale' subdir.

  Returns the repo root directory, or None if the file system root is
  reached without finding one.
  """
  d = current_dir
  while not os.path.exists(os.path.join(d, '.ale')):
    parent = os.path.dirname(d)
    if parent == d:  # dirname() is a fixed point at the root, e.g. '/'
      return None
    d = parent
  return d


def FindAleRepoOrDie():
  """
  Locate the ale repo that the current directory lives in.  An ale repo is
  rooted at a dir with an '.ale' subdirectory.

  Returns:
    (repo_root, metadata_db_path) tuple.

  Raises:
    RuntimeError: if no enclosing directory has an '.ale' subdir.
  """
  start_dir = os.getcwd()
  root = FindAleRootDir(start_dir)
  if root is None:
    raise RuntimeError("Couldn't find .ale directory (%s)" % start_dir)
  return root, os.path.join(root, '.ale/metadata.sqlite3')


def ParseConfig(f):
  """Parse the backup.cfg file.

  Each meaningful line has the form 'cask_host cask_dir'.  Blank lines and
  lines starting with '#' are skipped.

  Args:
    f: open file object (or any iterable of lines)

  Returns:
    A list of (cask_host, cask_dir) string tuples.

  Raises:
    ValueError: if a line has a host but no dir.
  """
  config = []
  for line in f:
    line = line.strip()
    if not line:
      continue
    if line.startswith('#'):
      continue
    # maxsplit=1: there are exactly two fields, so split at the FIRST run of
    # whitespace only.  (The previous maxsplit of 2 produced three pieces --
    # and an unpack ValueError -- whenever cask_dir contained whitespace.)
    cask_host, cask_dir = line.split(None, 1)
    config.append((cask_host, cask_dir))
  return config


def ParseConfigOrDie(ale_root):
  """Load and parse .ale/backup.cfg under ale_root.

  Returns the parsed config (see ParseConfig).  A missing file becomes a
  RuntimeError with a friendly message; other I/O errors propagate as-is.
  """
  config_path = os.path.join(ale_root, '.ale/backup.cfg')
  try:
    with open(config_path) as f:
      return ParseConfig(f)
  except IOError as e:
    if e.errno != errno.ENOENT:
      raise
    raise RuntimeError('No casks configured (%s not found).' %
        config_path)


# mv algorithm:
# - search for rel_path that starts with the path.  Delete them all, and then
# add them back?


def AleInit(cursor):
  """Create the metadata tables for a brand-new ale repo."""
  cursor.executescript(CREATE_SQL)


def List(cursor, args):
  """List files that are known to the metadata DB."""
  db_state = ReconstructState(cursor)

  debug_unicode = 0
  if debug_unicode:
    fmt = '%s %r'  # see raw bytes to see if it is utf-8.  YUP!
  else:
    fmt = '%s %s'

  # TODO: maybe ls -l shows mtime and size
  for name in sorted(db_state):
    num_bytes, mtime, sha1_bytes = db_state[name]
    print fmt % (sha1_bytes.encode('hex'), name)


class UpdateHandler(object):
  """
  Receives notice about files on the file system, compares it to existing
  state, and figures out which rows to write back into the database.
  """

  def __init__(self, db_state, printer):
    """
    Args:
      db_state: existing state
    """
    self.db_state = db_state
    self.printer = printer

    self.files_read = 0
    self.bytes_read = 0
    self.rows_to_insert = []

  def OnFile(self, full_path, rel_path, lstat):
    """Called with each file in the ale repo."""
    if rel_path in self.db_state:
      return  # do nothing

    num_bytes = lstat.st_size
    self.printer.OnChecksumBegin(rel_path, num_bytes)

    with open(full_path) as f:
      sha1_bytes = util.ChecksumFile(f, self.printer)

    # After the progress bar
    sys.stdout.write(' ' * 15)
    print sha1_bytes.encode('hex')

    # Whole second resolution is enough, and I think more portable.
    timestamp = int(time.time())  # no time zone for now
    mtime_int = int(lstat.st_mtime)

    row = ('+', timestamp, rel_path, num_bytes, mtime_int, buffer(sha1_bytes))
    self.rows_to_insert.append(row)

    self.files_read += 1
    self.bytes_read += num_bytes

  def GetResults(self):
    return self.rows_to_insert, self.files_read, self.bytes_read


def ReconstructState(cursor, rowid_range=None):
  """Play back the entry log and reconstruct the current files from it.

  Args:
    cursor: sqlite cursor
    rowid_range: An inclusive range as a tuple (min, max); or None if all rows
       should be returned

  Returns:
    A dict of { rel_path -> (num_bytes, mtime, sha1) }

  Raises:
    RepoCorruptError: if the log has an unknown action, or removes a path
       that was never added.
  """
  # NOTE:
  # - rel_path should be normalized so it doesn't end with / ever

  base_query = """
    SELECT action, rel_path, num_bytes, mtime, sha1
    FROM entry_log
    """
  if rowid_range:
    # The range is inclusive on both sides, like [3, 5].
    query_str = base_query + 'WHERE ? <= rowid AND rowid <= ?'
    args = rowid_range
  else:
    query_str = base_query
    args = ()

  state = {}
  for action, rel_path, num_bytes, mtime, sha1 in cursor.execute(
      query_str, args):
    if action == '+':
      state[rel_path] = (num_bytes, mtime, str(sha1))
    elif action == '-':
      try:
        del state[rel_path]
      except KeyError:
        # A '-' with no prior '+' means the log itself is inconsistent;
        # report that instead of leaking a bare KeyError.
        raise RepoCorruptError('Removal of unknown path %r' % rel_path)
    else:
      raise RepoCorruptError('Invalid action %r' % action)
  return state


def Freeze(cursor, ale_root, args):
  """Make the metadata.sqlite3 DB reflect the file system state.

  Walks the repo tree, checksums files not already in the DB, and appends
  '+' rows to entry_log.  Most work is in checksumming the files.

  Args:
    cursor: sqlite cursor for the metadata DB
    ale_root: repo root directory
    args: remaining command line args (currently unused)
  """
  # TODO:
  # - don't checksum if the mtime matches that of the file in db_state.  if
  # the name matches, but the mtime doesn't, raise an error?  Files should not
  # be modified.
  # - Make the files read-only
  # - Also need handle deletions!  UpdateHandler should keep another dict and
  # compare it against db_state.  Set difference.

  start_time = time.time()

  db_state = ReconstructState(cursor)

  # TODO: Parameterize this on whether sys.stdout.isatty()
  printer = util.FancyPrinter()

  handler = UpdateHandler(db_state, printer)
  file_count, total_bytes = util.WalkTree(ale_root, '', handler)

  wall_time_secs = time.time() - start_time

  new_rows, files_read, bytes_read = handler.GetResults()

  mb_read = float(bytes_read) / 1e6
  # Guard the rate against division by zero: the walk can finish within
  # clock resolution when there is little or nothing to checksum.
  mb_per_sec = mb_read / wall_time_secs if wall_time_secs else 0.0
  util.stderr('Found %d files.', file_count)
  util.stderr(
      'Checksummed %d files of %.1f MB in %.1f seconds (%.1f MB/s).',
      files_read, mb_read, wall_time_secs, mb_per_sec)

  if not new_rows:
    util.stderr('No new files in repo.')
    return

  cursor.executemany(
      'INSERT INTO entry_log VALUES(?, ?, ?, ?, ?, ?);', new_rows)

  util.stderr('Added %d files to repo.', len(new_rows))

  # Record the latest rowid so 'ale backup' knows how far the log extends.
  ale_counter = None
  for (latest_rowid,) in cursor.execute('SELECT MAX(rowid) FROM entry_log'):
    ale_counter = latest_rowid
  assert ale_counter is not None  # we just inserted rows above

  wall_time_secs = time.time() - start_time
  freeze_row = (
      int(start_time), file_count, total_bytes, wall_time_secs,
      ale_counter)

  cursor.execute(
      'INSERT INTO DEBUG_freeze_log VALUES(?, ?, ?, ?, ?);', freeze_row)


def GetFilesToBackup(cursor, latest_rowid, last_backup_rowid):
  """Yield (rel_path, sha1_bytes) for entries added after the last backup."""
  # Skip everything up to and including last_backup_rowid; the range passed
  # to ReconstructState is inclusive on both ends.
  first_new = last_backup_rowid + 1
  db_state = ReconstructState(cursor, rowid_range=(first_new, latest_rowid))
  for rel_path, (_, _, sha1_bytes) in db_state.iteritems():
    yield rel_path, sha1_bytes


def Backup(cursor, ale_root, config):
  """
  TODO: 
  - 'cask status' for everything in the config file
  - figure out what to copy where
  - call 'ale scp' to copy them


  Model:
  - A sequence of casks backs an ale.
  - A cask repo belongs to a single ale repo

  Should I accept explicit cask args?  Will that change the state?
  Policy:
  - always choose the last one, except when it's full.  Then choose a new one
    based on the config file?

  - Keep state in the ale about the last rowid backed up
  - And then when you to "ale backup", print a message like "ale was last
    backed up to cask homer.local:foo.cask at 2/01/16 at 9:34pm".  And then
    make sure the latest_rowid in cask and ale matches.  Otherwise you have to
    print a "missing cask" message.
  """
  # Later algorithm:
  # - Parse the config file to get a list of cask (host, dir)
  # - Run 'cask space' in parallel to get (host, dir, space)
  #
  # - Yeah it would be convenient to get (space, rowid) for each cask.  And
  # then choose the biggest one?  And then back up since the biggest.
  #   - Example: Current ale row ID is 500
  #   - One cask has 100-400, and the other one has 400-450
  #   - Then you know you need to sent rows 450-500
  #     - Do you care about deletions, or is that a separate "cask
  #     reclaim-space" step/
  #
  # What about offline casks?  I think you can comment them out of the
  # config file.

  # Get the maximum
  latest_rowid = -1
  for (rowid,) in cursor.execute('SELECT MAX(rowid) FROM entry_log'):
    latest_rowid = rowid
  if latest_rowid == -1:
    raise RepoCorruptError('Missing rowid in entry_log table')
  print 'LATEST rowid', latest_rowid

  # TODO: Find the last TWO freeze counters?
  # difference this vs. the last backup coutner!

  # Now find the cask with the maximum ID.  That is the last one we wrote to.
  #
  # Write to that one, unless it's full!
  last_backup_rowid = None
  last_cask_host = None
  last_cask_dir = None
  for (last_cask_host, last_cask_dir, last_backup_rowid,) in cursor.execute(
      """
      SELECT   cask_host, cask_dir, MAX(ale_counter)
      FROM     backup_log
      GROUP BY ale_counter
      """):
    pass

  if last_backup_rowid is None:
    last_backup_rowid = 0
    # Just use the first one for now
    cask_host, cask_dir = config[0]
    util.stderr('No backups yet!')
  else:
    cask_host = last_cask_host
    cask_dir = last_cask_dir

  print cask_host, cask_dir, last_backup_rowid

  path_sha1_pairs = list(
      GetFilesToBackup(cursor, latest_rowid, last_backup_rowid))

  Scp(ale_root, cask_host, cask_dir, path_sha1_pairs)

  new_ale_counter = latest_rowid
  backup_timestamp = int(time.time())
  backup_row = (cask_host, cask_dir, new_ale_counter, backup_timestamp)
  cursor.execute(
      'INSERT INTO backup_log VALUES(?, ?, ?, ?);', backup_row)


def Scp(ale_root, cask_host, cask_dir, path_sha1_pairs):
  """Copy each file to the cask over scp, one file at a time.

  Files land at <cask_dir>/<first 3 hex chars>/<remaining 37 hex chars>,
  i.e. the content-addressed layout keyed by sha1.

  Raises:
    RuntimeError: if any ssh or scp invocation exits non-zero.
  """
  for rel_path, sha1_bytes in path_sha1_pairs:
    sha1_hex = sha1_bytes.encode('hex')
    dir_part, name_part = sha1_hex[:3], sha1_hex[3:]

    # Make the dir via SSH first.
    ssh_argv = ['ssh', cask_host, 'mkdir', '-p',
                os.path.join(cask_dir, dir_part)]
    exit_code = subprocess.call(ssh_argv)
    if exit_code != 0:
      raise RuntimeError('%s failed with code %d' % (ssh_argv, exit_code))

    # Copy via SCP.
    scp_argv = [
        'scp',
        os.path.join(ale_root, rel_path),
        '%s:%s/%s/%s' % (cask_host, cask_dir, dir_part, name_part),
    ]
    exit_code = subprocess.call(scp_argv)
    if exit_code != 0:
      raise RuntimeError('%s failed with code %d' % (scp_argv, exit_code))


def AllCaskStatus(config):
  """Run 'cask status' on every configured cask, serially over ssh."""
  # Hm this should be a flag
  cask_cmd = '/home/andy/bin/cask'

  # TODO:
  # - Verify that they are attached to this 'ale' repo.
  # - Print the counter -- important state used for backup
  # - And then, number of files, space used, and last backup timestamp for
  # information
  #   - is the backup timestamp local or remote?  Maybe reconcile them
  #
  # Should you have a --json flag?  Probably
  for cask_host, cask_dir in config:
    # check_call raises CalledProcessError on a non-zero exit.
    subprocess.check_call(['ssh', cask_host, cask_cmd, 'status', cask_dir])


def main(argv):
  """Command line dispatcher.  Returns an exit code, or None on success."""
  try:
    action = argv[1]
  except IndexError:
    action = 'help'  # no action given: fall back to usage text

  if action == 'help':
    print __doc__

  elif action == 'init':
    # Usage: ale init [DIR]  -- DIR defaults to the current directory.
    try:
      ale_root = argv[2]
    except IndexError:
      ale_root = os.getcwd()
    util.MakeDir(ale_root)

    ale_dir = os.path.join(ale_root, '.ale')
    if os.path.exists(ale_dir):
      util.stderr('%s already exists', ale_dir)
      return 1

    util.MakeDir(ale_dir)

    db_name = os.path.join(ale_dir, 'metadata.sqlite3')

    with util.SqliteCursor(db_name, create=True) as cursor:
      AleInit(cursor)

    util.stderr('Initialized %s', db_name)

  elif action == 'ls':
    # Turn this into a context manager?
    # enter/exit
    # exit closes it
    # with AleMetadata() as conn:
    #
    ale_root, db_name = FindAleRepoOrDie()
    with util.SqliteCursor(db_name) as cursor:
      List(cursor, argv[2:])

  elif action == 'freeze':
    ale_root, db_name = FindAleRepoOrDie()

    with util.SqliteCursor(db_name) as cursor:
      Freeze(cursor, ale_root, argv[2:])

  elif action == 'backup':
    ale_root, db_name = FindAleRepoOrDie()

    config = ParseConfigOrDie(ale_root)

    with util.SqliteCursor(db_name) as cursor:
      Backup(cursor, ale_root, config)

  elif action == 'scp':
    # Usage:
    #   ale scp HOST CASK_DIR [NAME SHA1_HEX]...
    # i.e. the trailing args are flattened (rel_path, sha1) PAIRS, hence the
    # even-count check below.
    ale_root, db_name = FindAleRepoOrDie()

    cask_host = argv[2]
    cask_dir = argv[3]
    rest = argv[4:]
    n = len(rest)
    if n % 2 != 0:
      raise RuntimeError('scp takes an even number of args')
    path_sha1_pairs = []
    for i in xrange(n/2):
      m = i * 2
      name = rest[m]
      sha1_hex = rest[m+1]
      # Scp() expects raw sha1 bytes, not hex.
      path_sha1_pairs.append((name, sha1_hex.decode('hex')))

    Scp(ale_root, cask_host, cask_dir, path_sha1_pairs)

  elif action == 'casks':
    ale_root, db_name = FindAleRepoOrDie()
    config = ParseConfigOrDie(ale_root)
    AllCaskStatus(config)

  elif action == 'debug':
    # Debugging only.  Usage: ale debug SUBACTION
    subaction = argv[2]

    ale_root, db_name = FindAleRepoOrDie()
    with util.SqliteCursor(db_name) as cursor:
      if subaction == 'freeze-log':
        for result in cursor.execute('SELECT * FROM DEBUG_freeze_log'):
          print result
      else:
        raise RuntimeError('Invalid debug action %r' % subaction)

  else:
    raise RuntimeError('Invalid action %r' % action)


if __name__ == '__main__':
  try:
    sys.exit(main(sys.argv))
  except RuntimeError as e:
    # 'as' form: the old 'except RuntimeError, e' comma syntax is deprecated
    # and Python 2-only; 'as' works in 2.6+ as well.
    print >>sys.stderr, 'FATAL: %s' % e
    sys.exit(1)