# ubuntuone.u1sync.client
#
# Client/protocol end of u1sync
#
# Author: Lucio Torre <lucio.torre@canonical.com>
# Author: Tim Cole <tim.cole@canonical.com>
#
# Copyright 2009 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Pretty API for protocol client."""

from __future__ import with_statement

import os
import sys
import shutil
from Queue import Queue, Empty
from threading import Thread, Lock
import zlib
import urlparse
import ConfigParser
from cStringIO import StringIO
from OpenSSL import SSL

from twisted.internet import reactor, defer, ssl
from twisted.internet.defer import inlineCallbacks, returnValue
from ubuntuone.storageprotocol.hash import crc32
from ubuntuone.storageprotocol.context import get_ssl_context
from ubuntuone.oauthdesktop.config import get_config \
                                                 as get_oauth_config
from ubuntuone.oauthdesktop.auth import AuthorisationClient
from ubuntuone.u1sync.genericmerge import MergeNode
from ubuntuone.u1sync.utils import should_sync

from ubuntuone.storageprotocol.oauth import OAuthConsumer
from ubuntuone.storageprotocol.client import (
    StorageClientFactory, StorageClient)
from ubuntuone.storageprotocol import request
from ubuntuone.storageprotocol.dircontent_pb2 import \
    DirectoryContent, DIRECTORY
import uuid

CONSUMER_KEY = "ubuntuone"


def share_str(share_uuid):
    """Converts a share UUID to a form the protocol likes."""
    return str(share_uuid) if share_uuid is not None else request.ROOT
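
# For reference, a quick sketch of share_str() behaviour (the UUID literal
# below is an arbitrary illustrative value, not a real share):
#
#   share_str(None)  ->  request.ROOT  (the user's own volume)
#   share_str(uuid.UUID('00000000-0000-0000-0000-000000000000'))
#                    ->  '00000000-0000-0000-0000-000000000000'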


class SyncStorageClient(StorageClient):
    """Simple client that calls a callback on connection."""

    def connectionMade(self):
        """Setup and call callback."""
        StorageClient.connectionMade(self)
        if self.factory.current_protocol not in (None, self):
            self.factory.current_protocol.transport.loseConnection()
        self.factory.current_protocol = self
        self.factory.observer.connected()

    def connectionLost(self, reason=None):
        """Callback for established connection lost"""
        if self.factory.current_protocol is self:
            self.factory.current_protocol = None
            self.factory.observer.disconnected(reason)


class SyncClientFactory(StorageClientFactory):
    """A cmd protocol factory."""
    # no init: pylint: disable-msg=W0232

    protocol = SyncStorageClient

    def __init__(self, observer):
        """Create the factory"""
        self.observer = observer
        self.current_protocol = None

    def clientConnectionFailed(self, connector, reason):
        """We failed at connecting."""
        self.current_protocol = None
        self.observer.connection_failed(reason)


class UnsupportedOperationError(Exception):
    """The operation is unsupported by the protocol version."""


class ConnectionError(Exception):
    """A connection error."""


class AuthenticationError(Exception):
    """An authentication error."""


class NoSuchShareError(Exception):
    """Error when there is no such share available."""


class Client(object):
    """U1 storage client facade."""

    def __init__(self, realm):
        """Create the instance."""

        self.thread = Thread(target=self._run)
        self.thread.setDaemon(True)
        self.factory = SyncClientFactory(self)

        self._status_lock = Lock()
        self._status = "disconnected"
        self._status_reason = None
        self._status_waiting = []

        self.realm = realm

        oauth_config = get_oauth_config()
        if oauth_config.has_section(realm):
            config_section = realm
        elif self.realm.startswith("http://localhost") and \
             oauth_config.has_section("http://localhost"):
            config_section = "http://localhost"
        else:
            config_section = "default"

        def get_oauth_option(option):
            """Retrieves an option from oauth config."""
            try:
                return oauth_config.get(config_section, option)
            except ConfigParser.NoOptionError:
                return oauth_config.get("default", option)

        def get_oauth_url(option):
            """Retrieves an absolutized URL from the OAuth config."""
            suffix = get_oauth_option(option)
            return urlparse.urljoin(realm, suffix)

        self.consumer_key = CONSUMER_KEY
        self.consumer_secret = get_oauth_option("consumer_secret")

        self.request_token_url = get_oauth_url("request_token_url")
        self.user_authorisation_url = get_oauth_url("user_authorisation_url")
        self.access_token_url = get_oauth_url("access_token_url")

    def obtain_oauth_token(self, create_token):
        """Obtains an OAuth token, optionally creating one if required."""
        token_result = Queue()

        def have_token(token):
            """When a token is available."""
            token_result.put(token)

        def no_token():
            """When no token is available."""
            token_result.put(None)

        oauth_client = AuthorisationClient(realm=self.realm,
                                           request_token_url=
                                           self.request_token_url,
                                           user_authorisation_url=
                                           self.user_authorisation_url,
                                           access_token_url=
                                           self.access_token_url,
                                           consumer_key=self.consumer_key,
                                           consumer_secret=
                                           self.consumer_secret,
                                           callback_parent=have_token,
                                           callback_denied=no_token,
                                           do_login=create_token)

        def _obtain_token():
            """Obtains or creates a token."""
            if create_token:
                oauth_client.clear_token()
            oauth_client.ensure_access_token()

        reactor.callFromThread(_obtain_token)
        token = token_result.get()
        if token is None:
            raise AuthenticationError("Unable to obtain OAuth token.")
        return token

    def _change_status(self, status, reason=None):
        """Changes the client status.  Usually called from the reactor
        thread.

        """
        with self._status_lock:
            self._status = status
            self._status_reason = reason
            waiting = self._status_waiting
            if len(waiting) > 0:
                self._status_waiting = []
                for waiter in waiting:
                    waiter.put((status, reason))

    def _await_status_not(self, *ignore_statuses):
        """Blocks until the client status changes, returning the new status.
        Should never be called from the reactor thread.

        """
        with self._status_lock:
            status = self._status
            reason = self._status_reason
            while status in ignore_statuses:
                waiter = Queue()
                self._status_waiting.append(waiter)
                self._status_lock.release()
                try:
                    status, reason = waiter.get()
                finally:
                    self._status_lock.acquire()
            return (status, reason)
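
    # Status values used throughout this class (set via _change_status on the
    # reactor thread, awaited via _await_status_not on the calling thread):
    #   "disconnected"  - initial state; also set on connection loss/failure
    #   "connecting"    - set by _connect_inner until the factory reports back
    #   "connected"     - transport established but not yet authenticated
    #   "authenticated" - OAuth handshake completed (see oauth_from_token)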

    def connection_failed(self, reason):
        """Notification that connection failed."""
        self._change_status("disconnected", reason)

    def connected(self):
        """Notification that connection succeeded."""
        self._change_status("connected")

    def disconnected(self, reason):
        """Notification that we were disconnected."""
        self._change_status("disconnected", reason)

    def _run(self):
        """Run the reactor in bg."""
        reactor.run(installSignalHandlers=False)

    def start(self):
        """Start the reactor thread."""
        self.thread.start()

    def stop(self):
        """Shut down the reactor."""
        reactor.callWhenRunning(reactor.stop)
        self.thread.join(1.0)

    def defer_from_thread(self, function, *args, **kwargs):
        """Do twisted defer magic to get results and show exceptions."""

        queue = Queue()
        def runner():
            """inner."""
            # we do want to catch all exceptions here
            # broad except: pylint: disable-msg=W0703
            try:
                d = function(*args, **kwargs)
                if isinstance(d, defer.Deferred):
                    d.addCallbacks(lambda r: queue.put((r, None, None)),
                                   lambda f: queue.put((None, None, f)))
                else:
                    queue.put((d, None, None))
            except Exception:
                queue.put((None, sys.exc_info(), None))

        reactor.callFromThread(runner)
        while True:
            try:
                # poll with a timeout so that interrupts are still serviced
                result, exc_info, failure = queue.get(True, 1)
                break
            except Empty: # pylint: disable-msg=W0704
                pass
        if exc_info:
            raise exc_info[1], None, exc_info[2]
        elif failure:
            failure.raiseException()
        else:
            return result
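
    # defer_from_thread() is the bridge used by the methods below: it runs a
    # callable on the reactor thread and blocks the calling (non-reactor)
    # thread until the resulting Deferred fires.  A hypothetical direct use,
    # with illustrative variable names (list_shares is a real protocol call):
    #
    #   result = client.defer_from_thread(
    #       client.factory.current_protocol.list_shares)
    #   for share in result.shares:
    #       print share.name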

    def connect(self, host, port):
        """Connect to host/port."""
        def _connect():
            """Deferred part."""
            reactor.connectTCP(host, port, self.factory)
        self._connect_inner(_connect)

    def connect_ssl(self, host, port, no_verify):
        """Connect to host/port using SSL."""
        def _connect():
            """Deferred part."""
            ctx = get_ssl_context(no_verify)
            reactor.connectSSL(host, port, self.factory, ctx)

        self._connect_inner(_connect)

    def _connect_inner(self, _connect):
        """Helper function for connecting."""
        self._change_status("connecting")
        reactor.callFromThread(_connect)
        status, reason = self._await_status_not("connecting")
        if status != "connected":
            raise ConnectionError(reason.value)

    def disconnect(self):
        """Disconnect."""
        if self.factory.current_protocol is not None:
            reactor.callFromThread(
                self.factory.current_protocol.transport.loseConnection)
        self._await_status_not("connecting", "connected", "authenticated")

    def oauth_from_token(self, token):
        """Perform OAuth authorisation using an existing token."""

        consumer = OAuthConsumer(self.consumer_key, self.consumer_secret)

        def _auth_successful(value):
            """Callback for successful auth.  Changes status to
            authenticated."""
            self._change_status("authenticated")
            return value

        def _auth_failed(value):
            """Callback for failed auth.  Disconnects."""
            self.factory.current_protocol.transport.loseConnection()
            return value

        def _wrapped_authenticate():
            """Wrapped authenticate."""
            d = self.factory.current_protocol.oauth_authenticate(consumer,
                                                                 token)
            d.addCallbacks(_auth_successful, _auth_failed)
            return d

        try:
            self.defer_from_thread(_wrapped_authenticate)
        except request.StorageProtocolError, e:
            raise AuthenticationError(e)
        status, reason = self._await_status_not("connected")
        if status != "authenticated":
            raise AuthenticationError(reason.value)
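
    # The two OAuth helpers above give the usual login sequence once a
    # connection is up.  A rough sketch (the client variable and the
    # create_token choice are illustrative; error handling omitted):
    #
    #   token = client.obtain_oauth_token(create_token=False)
    #   client.oauth_from_token(token)
    #   # status is now "authenticated"; storage calls can follow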

    def get_root_info(self, share_uuid):
        """Returns (root UUID, writable) for the applicable volume."""
        if share_uuid is None:
            _get_root = self.factory.current_protocol.get_root
            root = self.defer_from_thread(_get_root)
            return (uuid.UUID(root), True)
        else:
            str_share_uuid = str(share_uuid)
            share = self._match_share(lambda s: str(s.id) == str_share_uuid)
            return (uuid.UUID(str(share.subtree)),
                    share.access_level == "Modify")

    def resolve_path(self, share_uuid, root_uuid, path):
        """Resolve a slash-separated path relative to the given root node."""

        @inlineCallbacks
        def _resolve_worker():
            """Path resolution worker."""
            node_uuid = root_uuid
            components = [c for c in path.strip('/').split('/') if c != '']

            # Walk the path one component at a time, starting at the root
            # node and descending through each matching directory entry.
            for name in components:
                hashes = yield self._get_node_hashes(share_uuid, [node_uuid])
                content_hash = hashes.get(node_uuid, None)
                if content_hash is None:
                    raise KeyError("Content hash not available")
                entries = yield self._get_raw_dir_entries(share_uuid,
                                                          node_uuid,
                                                          content_hash)
                match_name = name.decode('utf-8')
                match = None
                for entry in entries:
                    if match_name == entry.name:
                        match = entry
                        break

                if match is None:
                    raise KeyError("Path not found")

                node_uuid = uuid.UUID(match.node)

            returnValue(node_uuid)

        return self.defer_from_thread(_resolve_worker)

    def find_share(self, share_spec):
        """Finds a share matching the given UUID.  Looks at both share UUIDs
        and root node UUIDs."""
        share = self._match_share(lambda s: str(s.id) == share_spec or \
                                            str(s.subtree) == share_spec)
        return uuid.UUID(str(share.id))

    def _match_share(self, predicate):
        """Finds a share matching the given predicate."""
        _list_shares = self.factory.current_protocol.list_shares
        r = self.defer_from_thread(_list_shares)
        for share in r.shares:
            if predicate(share) and share.direction == "to_me":
                return share
        raise NoSuchShareError()

    def build_tree(self, share_uuid, root_uuid):
        """Builds and returns a tree representing the metadata for the given
        subtree in the given share.

        @param share_uuid: the share UUID or None for the user's volume
        @param root_uuid: the root UUID of the subtree (must be a directory)
        @return: a MergeNode tree

        """
        root = MergeNode(node_type=DIRECTORY, uuid=root_uuid)

        @inlineCallbacks
        def _get_root_content_hash():
            """Obtain the content hash for the root node."""
            result = yield self._get_node_hashes(share_uuid, [root_uuid])
            returnValue(result.get(root_uuid, None))

        root.content_hash = self.defer_from_thread(_get_root_content_hash)
        if root.content_hash is None:
            raise ValueError("No content available for node %s" % root_uuid)

        @inlineCallbacks
        def _get_children(parent_uuid, parent_content_hash):
            """Obtain a sequence of MergeNodes corresponding to a node's
            immediate children.

            """
            entries = yield self._get_raw_dir_entries(share_uuid,
                                                      parent_uuid,
                                                      parent_content_hash)
            children = {}
            for entry in entries:
                if should_sync(entry.name):
                    child = MergeNode(node_type=entry.node_type,
                                      uuid=uuid.UUID(entry.node))
                    children[entry.name] = child

            child_uuids = [child.uuid for child in children.itervalues()]
            content_hashes = yield self._get_node_hashes(share_uuid,
                                                         child_uuids)
            for child in children.itervalues():
                child.content_hash = content_hashes.get(child.uuid, None)

            returnValue(children)

        need_children = [root]
        while need_children:
            node = need_children.pop()
            if node.content_hash is not None:
                children = self.defer_from_thread(_get_children, node.uuid,
                                                  node.content_hash)
                node.children = children
                for child in children.itervalues():
                    if child.node_type == DIRECTORY:
                        need_children.append(child)

        return root
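
    # The MergeNode tree returned by build_tree() can be walked with a simple
    # stack, much like the loop above.  A minimal sketch (everything except
    # the MergeNode attributes used in this module is illustrative):
    #
    #   pending = [("", tree)]
    #   while pending:
    #       path, node = pending.pop()
    #       print path or "/", node.uuid, node.content_hash
    #       children = getattr(node, "children", None) or {}
    #       for name, child in children.iteritems():
    #           pending.append((path + "/" + name, child))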

    def _get_raw_dir_entries(self, share_uuid, node_uuid, content_hash):
        """Gets raw dir entries for the given directory."""
        d = self.factory.current_protocol.get_content(share_str(share_uuid),
                                                      str(node_uuid),
                                                      content_hash)
        d.addCallback(lambda c: zlib.decompress(c.data))

        def _parse_content(raw_content):
            """Parses directory content into a list of entry objects."""
            unserialized_content = DirectoryContent()
            unserialized_content.ParseFromString(raw_content)
            return list(unserialized_content.entries)

        d.addCallback(_parse_content)
        return d

    def download_string(self, share_uuid, node_uuid, content_hash):
        """Reads a file from the server into a string."""
        output = StringIO()
        self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
                             content_hash=content_hash, output=output)
        return output.getvalue()

    def download_file(self, share_uuid, node_uuid, content_hash, filename):
        """Downloads a file from the server."""
        partial_filename = "%s.u1partial" % filename
        output = open(partial_filename, "w")

        def rename_file():
            """Renames the temporary file to the final name."""
            output.close()
            os.rename(partial_filename, filename)

        def delete_file():
            """Deletes the temporary file."""
            output.close()
            os.unlink(partial_filename)

        self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
                             content_hash=content_hash, output=output,
                             on_success=rename_file, on_failure=delete_file)
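
    # A short usage sketch for the download helpers above (the node and its
    # content hash come from build_tree() or resolve_path(); the local
    # filename is illustrative):
    #
    #   client.download_file(share_uuid=None, node_uuid=node.uuid,
    #                        content_hash=node.content_hash,
    #                        filename="/tmp/example.txt")
    #
    # download_string() takes the same arguments minus filename and returns
    # the file body as a string.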

    def _download_inner(self, share_uuid, node_uuid, content_hash, output,
                        on_success=lambda: None, on_failure=lambda: None):
        """Helper function for content downloads."""
        dec = zlib.decompressobj()

        def write_data(data):
            """Helper which writes data to the output file."""
            uncompressed_data = dec.decompress(data)
            output.write(uncompressed_data)

        def finish_download(value):
            """Helper which finishes the download."""
            uncompressed_data = dec.flush()
            output.write(uncompressed_data)
            on_success()
            return value

        def abort_download(value):
            """Helper which aborts the download."""
            on_failure()
            return value

        def _download():
            """Async helper."""
            _get_content = self.factory.current_protocol.get_content
            d = _get_content(share_str(share_uuid), str(node_uuid),
                             content_hash, callback=write_data)
            d.addCallbacks(finish_download, abort_download)
            return d

        self.defer_from_thread(_download)

    def create_directory(self, share_uuid, parent_uuid, name):
        """Creates a directory on the server."""
        r = self.defer_from_thread(self.factory.current_protocol.make_dir,
                                   share_str(share_uuid), str(parent_uuid),
                                   name)
        return uuid.UUID(r.new_id)

    def create_file(self, share_uuid, parent_uuid, name):
        """Creates a file on the server."""
        r = self.defer_from_thread(self.factory.current_protocol.make_file,
                                   share_str(share_uuid), str(parent_uuid),
                                   name)
        return uuid.UUID(r.new_id)

    def create_symlink(self, share_uuid, parent_uuid, name, target):
        """Creates a symlink on the server."""
        raise UnsupportedOperationError("Protocol does not support symlinks")

    def upload_string(self, share_uuid, node_uuid, old_content_hash,
                      content_hash, content):
        """Uploads a string to the server as file content."""
        content_crc32 = crc32(content, 0)
        compressed_content = zlib.compress(content, 9)
        compressed = StringIO(compressed_content)
        self.defer_from_thread(self.factory.current_protocol.put_content,
                               share_str(share_uuid), str(node_uuid),
                               old_content_hash, content_hash,
                               content_crc32, len(content),
                               len(compressed_content), compressed)

    def upload_file(self, share_uuid, node_uuid, old_content_hash,
                    content_hash, filename):
        """Uploads a file to the server."""
        parent_dir = os.path.split(filename)[0]
        unique_filename = os.path.join(parent_dir, "." + str(uuid.uuid4()))

        class StagingFile(object):
            """An object which tracks data being compressed for staging."""
            def __init__(self, stream):
                """Initialize a compression object."""
                self.crc32 = 0
                self.enc = zlib.compressobj(9)
                self.size = 0
                self.compressed_size = 0
                self.stream = stream

            def write(self, bytes):
                """Compress bytes, keeping track of length and crc32."""
                self.size += len(bytes)
                self.crc32 = crc32(bytes, self.crc32)
                compressed_bytes = self.enc.compress(bytes)
                self.compressed_size += len(compressed_bytes)
                self.stream.write(compressed_bytes)

            def finish(self):
                """Finish staging compressed data."""
                compressed_bytes = self.enc.flush()
                self.compressed_size += len(compressed_bytes)
                self.stream.write(compressed_bytes)

        with open(unique_filename, "w+") as compressed:
            os.unlink(unique_filename)
            with open(filename, "r") as original:
                staging = StagingFile(compressed)
                shutil.copyfileobj(original, staging)
            staging.finish()
            compressed.seek(0)
            self.defer_from_thread(self.factory.current_protocol.put_content,
                                   share_str(share_uuid), str(node_uuid),
                                   old_content_hash, content_hash,
                                   staging.crc32,
                                   staging.size, staging.compressed_size,
                                   compressed)
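
    # Both upload helpers require the caller to pass the previous and new
    # content hashes; the old hash lets the server detect conflicting
    # changes.  An illustrative call, assuming new_hash was computed locally
    # with the protocol's content hashing scheme:
    #
    #   client.upload_file(share_uuid=None, node_uuid=node.uuid,
    #                      old_content_hash=node.content_hash,
    #                      content_hash=new_hash,
    #                      filename="/tmp/example.txt")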

    def move(self, share_uuid, parent_uuid, name, node_uuid):
        """Moves a file on the server."""
        self.defer_from_thread(self.factory.current_protocol.move,
                               share_str(share_uuid), str(node_uuid),
                               str(parent_uuid), name)

    def unlink(self, share_uuid, node_uuid):
        """Unlinks a file on the server."""
        self.defer_from_thread(self.factory.current_protocol.unlink,
                               share_str(share_uuid), str(node_uuid))

    def _get_node_hashes(self, share_uuid, node_uuids):
        """Fetches hashes for the given nodes."""
        share = share_str(share_uuid)
        queries = [(share, str(node_uuid), request.UNKNOWN_HASH) \
                   for node_uuid in node_uuids]
        d = self.factory.current_protocol.query(queries)

        def _collect_hashes(multi_result):
            """Accumulate hashes from query replies."""
            hashes = {}
            for (success, value) in multi_result:
                if success:
                    for node_state in value.response:
                        node_uuid = uuid.UUID(node_state.node)
                        hashes[node_uuid] = node_state.hash
            return hashes

        d.addCallback(_collect_hashes)
        return d

    def get_incoming_shares(self):
        """Returns a list of incoming shares as
        (name, uuid, other_visible_name, accepted, access_level) tuples.

        """
        _list_shares = self.factory.current_protocol.list_shares
        r = self.defer_from_thread(_list_shares)
        return [(s.name, s.id, s.other_visible_name,
                 s.accepted, s.access_level) \
                for s in r.shares if s.direction == "to_me"]
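

# A minimal end-to-end sketch of this facade, kept under an import guard so
# importing the module is unaffected.  The realm, host and port below are
# placeholders, not real Ubuntu One endpoints; substitute the values your
# deployment actually uses.
if __name__ == "__main__":
    EXAMPLE_REALM = "https://example.com"  # placeholder OAuth realm
    EXAMPLE_HOST = "server.example.com"    # placeholder storage server
    EXAMPLE_PORT = 443                     # placeholder port

    demo_client = Client(realm=EXAMPLE_REALM)
    demo_client.start()
    try:
        demo_client.connect_ssl(EXAMPLE_HOST, EXAMPLE_PORT, no_verify=False)
        token = demo_client.obtain_oauth_token(create_token=False)
        demo_client.oauth_from_token(token)
        root_uuid, writable = demo_client.get_root_info(None)
        tree = demo_client.build_tree(None, root_uuid)
        print "root:", root_uuid, "writable:", writable
        print "top-level entries:", sorted(tree.children.keys())
    finally:
        demo_client.disconnect()
        demo_client.stop()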
