diff --git a/canvasapi/util.py b/canvasapi/util.py index 21c2a02b..a4cc08ed 100644 --- a/canvasapi/util.py +++ b/canvasapi/util.py @@ -133,37 +133,34 @@ def obj_or_id(parameter, param_name, object_types): raise TypeError(message) -def obj_or_str(obj, attr, object_types): +def obj_or_str(parameter, param_name, object_types): """ - Accepts an object. If the object has the attribute, return the + Accepts either an object or a string. If it is a string, return it directly. + If it is an object and the object is of correct type, return the object's corresponding string. Otherwise, throw an exception. - :param obj: object from which to retrieve attribute - :type obj: object - :param attr: name of the attribute to retrieve - :type attr: str + :param parameter: object from which to retrieve attribute + :type parameter: str or object + :param param_name: name of the attribute to retrieve + :type param_name: str :param object_types: tuple containing the types of the object being passed in :type object_types: tuple :rtype: str """ - try: - return str(getattr(obj, attr)) - except (AttributeError, TypeError): - if not isinstance(attr, str): - raise TypeError( - "Atttibute parameter {} must be of type string".format(attr) - ) - for obj_type in object_types: - if isinstance(obj, obj_type): - try: - return str(getattr(obj, attr)) - except AttributeError: - raise AttributeError("{} object does not have {} attribute").format( - obj, attr - ) + if isinstance(parameter, str): + return parameter - obj_type_list = ",".join([obj_type.__name__ for obj_type in object_types]) - raise TypeError("Parameter {} must be of type {}.".format(obj, obj_type_list)) + for obj_type in object_types: + if isinstance(parameter, obj_type): + try: + return str(getattr(parameter, param_name)) + except AttributeError: + raise AttributeError("{} object does not have {} attribute").format( + parameter, param_name + ) + + obj_type_list = ",".join([obj_type.__name__ for obj_type in object_types]) + raise TypeError("Parameter {} must be of type {}.".format(parameter, obj_type_list)) def get_institution_url(base_url): diff --git a/tests/test_util.py b/tests/test_util.py index cd85648b..211ed597 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -431,6 +431,12 @@ def test_obj_or_id_nonuser_self(self, m): obj_or_id("self", "user_id", (CourseNickname,)) # obj_or_str() + def test_obj_or_str_str(self, m): + name = obj_or_str("test", "name", (User,)) + + self.assertIsInstance(name, str) + self.assertEqual(name, "test") + def test_obj_or_str_obj_attr(self, m): register_uris({"user": ["get_by_id"]}, m) @@ -467,8 +473,12 @@ def test_obj_or_str_invalid_attr_parameter(self, m): obj_or_str(user, user, (User,)) def test_obj_or_str_invalid_obj_type(self, m): + register_uris({"course": ["get_by_id"]}, m) + + course = self.canvas.get_course(1) + with self.assertRaises(TypeError): - obj_or_str("user", "name", (User,)) + obj_or_str(course, "name", (User,)) # get_institution_url() def test_get_institution_url(self, m):